diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 72b66c036..8b236cc56 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -14,6 +14,7 @@ from .. import exc from . import elements import operator + REQUIRED = util.symbol('REQUIRED', """ Placeholder for the value within a :class:`.BindParameter` which is required to be present when the statement is passed @@ -61,15 +62,25 @@ def _get_crud_params(compiler, stmt, **kw): _column_as_key, _getattr_col_key, _col_bind_name = \ _key_getters_for_crud_column(compiler) + # We have to keep parameters' order if we are doing an update and the + # statement paramenters are a list or tuple of pairs. It would also work + # without isupdate check, but adding it shortcircuits the boolean operation + # resulting in false for all inserts. + if stmt._preserve_parameter_order: + stmt_parameters = util.OrderedDict(stmt_parameters) + dict_type = util.OrderedDict + else: + dict_type = dict + # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: - parameters = {} + parameters = dict_type() else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = dict_type((_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or + key not in stmt_parameters) # create a list of column assignment clauses as tuples values = [] @@ -97,7 +108,8 @@ def _get_crud_params(compiler, stmt, **kw): _scan_cols( compiler, stmt, parameters, _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + _col_bind_name, check_columns, values, kw, + keep_order=stmt._preserve_parameter_order) if parameters and stmt_parameters: check = set(parameters).intersection( @@ -202,7 +214,7 @@ def _scan_insert_from_select_cols( def _scan_cols( compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): + _column_as_key, _col_bind_name, check_columns, values, kw, keep_order): need_pks, implicit_returning, \ implicit_return_defaults, postfetch_lastrowid = \ @@ -210,6 +222,16 @@ def _scan_cols( cols = stmt.table.columns + if keep_order: + # Order columns with parameters first, preserving their original order, + # and then the rest of the columns + keys = tuple(parameters.keys()) if parameters else tuple() + table_cols = tuple(cols) + cols = sorted(table_cols, + key=(lambda x: keys.index(_getattr_col_key(x)) + if _getattr_col_key(x) in keys + else len(keys) + table_cols.index(x))) + for c in cols: col_key = _getattr_col_key(c) |