diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 7 |
4 files changed, 27 insertions, 19 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 768d4f83a..691195772 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1971,8 +1971,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - crud_params = crud._get_crud_params(self, update_stmt, keep_order=True, - **kw) + crud_params = crud._get_crud_params(self, update_stmt, **kw) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 614f9413b..235889ad9 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -14,6 +14,8 @@ from .. import exc from . import elements import operator +from sqlalchemy.sql import util as sql_util + REQUIRED = util.symbol('REQUIRED', """ Placeholder for the value within a :class:`.BindParameter` which is required to be present when the statement is passed @@ -26,7 +28,7 @@ values present. """) -def _get_crud_params(compiler, stmt, keep_order=False, **kw): +def _get_crud_params(compiler, stmt, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -61,15 +63,22 @@ def _get_crud_params(compiler, stmt, keep_order=False, **kw): _column_as_key, _getattr_col_key, _col_bind_name = \ _key_getters_for_crud_column(compiler) + # We have to keep parameters' order if we are doing an update and the + # statement paramenters are a list or tuple of pairs. It would also work + # without isupdate check, but adding it shortcircuits the boolean operation + # resulting in false for all inserts. + keep_order = (compiler.isupdate + and sql_util.is_value_pair_dict(stmt.parameters)) + dict_type = util.OrderedDict if keep_order else dict # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: - parameters = util.OrderedDict() + parameters = dict_type() else: - parameters = util.OrderedDict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = dict_type((_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or + key not in stmt_parameters) # create a list of column assignment clauses as tuples values = [] diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 983fed2b5..c8407e3fd 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -15,7 +15,7 @@ from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes from .. import util from .. import exc -from sqlalchemy.sql import schema +from sqlalchemy.sql import util as sql_util class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): @@ -31,25 +31,18 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): _prefixes = () def _process_colparams(self, parameters): - def is_value_pair_dict(params): - # Check if params is a value list/tuple representing a dictionary - return ( - isinstance(params, (list, tuple)) and - all(isinstance(p, (list, tuple)) and len(p) == 2 and - isinstance(p[0], schema.Column) for p in params)) - def process_single(p): if isinstance(p, (list, tuple)): - if is_value_pair_dict(p): + if sql_util.is_value_pair_dict(p): return util.OrderedDict(p) - return util.OrderedDict( + return dict( (c.key, pval) for c, pval in zip(self.table.c, p) ) else: return p - if (not is_value_pair_dict(parameters) and + if (not sql_util.is_value_pair_dict(parameters) and isinstance(parameters, (list, tuple)) and parameters and isinstance(parameters[0], (list, tuple, dict))): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index cbd74faac..c73f710af 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -29,6 +29,7 @@ join_condition = util.langhelpers.public_factory( from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate from .elements import _find_columns from .ddl import sort_tables +from sqlalchemy.sql import schema def find_join_source(clauses, join_to): @@ -436,6 +437,12 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, return pairs +def is_value_pair_dict(params): + """Check if params is a value list/tuple representing a dictionary.""" + return (isinstance(params, (list, tuple)) and + all(isinstance(p, (list, tuple)) and len(p) == 2 and + isinstance(p[0], schema.Column) for p in params)) + class ClauseAdapter(visitors.ReplacingCloningVisitor): """Clones and modifies clauses based on column correspondence. |
