summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-10-09 17:18:00 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-10-09 17:18:00 -0400
commit58ad3b5ed78249f1faf48774bf6dedd424a1f435 (patch)
tree46f10805d70064abfc405f49035d7c1e4e6648d8 /lib/sqlalchemy/sql/crud.py
parent78a7bbdb3b0906a35528bdc829a08f0644d6fd7b (diff)
parent44e5a31ccee5335962602327132a4196fb1c7911 (diff)
downloadsqlalchemy-pr200.tar.gz
Merge remote-tracking branch 'origin/pr/200' into pr200pr200
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py36
1 files changed, 29 insertions, 7 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 72b66c036..8b236cc56 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -14,6 +14,7 @@ from .. import exc
from . import elements
import operator
+
REQUIRED = util.symbol('REQUIRED', """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
@@ -61,15 +62,25 @@ def _get_crud_params(compiler, stmt, **kw):
_column_as_key, _getattr_col_key, _col_bind_name = \
_key_getters_for_crud_column(compiler)
+ # We have to keep parameters' order if we are doing an update and the
+ # statement paramenters are a list or tuple of pairs. It would also work
+ # without isupdate check, but adding it shortcircuits the boolean operation
+ # resulting in false for all inserts.
+ if stmt._preserve_parameter_order:
+ stmt_parameters = util.OrderedDict(stmt_parameters)
+ dict_type = util.OrderedDict
+ else:
+ dict_type = dict
+
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
- parameters = {}
+ parameters = dict_type()
else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = dict_type((_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or
+ key not in stmt_parameters)
# create a list of column assignment clauses as tuples
values = []
@@ -97,7 +108,8 @@ def _get_crud_params(compiler, stmt, **kw):
_scan_cols(
compiler, stmt, parameters,
_getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ _col_bind_name, check_columns, values, kw,
+ keep_order=stmt._preserve_parameter_order)
if parameters and stmt_parameters:
check = set(parameters).intersection(
@@ -202,7 +214,7 @@ def _scan_insert_from_select_cols(
def _scan_cols(
compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
+ _column_as_key, _col_bind_name, check_columns, values, kw, keep_order):
need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid = \
@@ -210,6 +222,16 @@ def _scan_cols(
cols = stmt.table.columns
+ if keep_order:
+ # Order columns with parameters first, preserving their original order,
+ # and then the rest of the columns
+ keys = tuple(parameters.keys()) if parameters else tuple()
+ table_cols = tuple(cols)
+ cols = sorted(table_cols,
+ key=(lambda x: keys.index(_getattr_col_key(x))
+ if _getattr_col_key(x) in keys
+ else len(keys) + table_cols.index(x)))
+
for c in cols:
col_key = _getattr_col_key(c)