summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py3
-rw-r--r--lib/sqlalchemy/sql/crud.py21
-rw-r--r--lib/sqlalchemy/sql/dml.py15
-rw-r--r--lib/sqlalchemy/sql/util.py7
4 files changed, 27 insertions, 19 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 768d4f83a..691195772 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1971,8 +1971,7 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- crud_params = crud._get_crud_params(self, update_stmt, keep_order=True,
- **kw)
+ crud_params = crud._get_crud_params(self, update_stmt, **kw)
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 614f9413b..235889ad9 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -14,6 +14,8 @@ from .. import exc
from . import elements
import operator
+from sqlalchemy.sql import util as sql_util
+
REQUIRED = util.symbol('REQUIRED', """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
@@ -26,7 +28,7 @@ values present.
""")
-def _get_crud_params(compiler, stmt, keep_order=False, **kw):
+def _get_crud_params(compiler, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -61,15 +63,22 @@ def _get_crud_params(compiler, stmt, keep_order=False, **kw):
_column_as_key, _getattr_col_key, _col_bind_name = \
_key_getters_for_crud_column(compiler)
+ # We have to keep parameters' order if we are doing an update and the
+ # statement paramenters are a list or tuple of pairs. It would also work
+ # without isupdate check, but adding it shortcircuits the boolean operation
+ # resulting in false for all inserts.
+ keep_order = (compiler.isupdate
+ and sql_util.is_value_pair_dict(stmt.parameters))
+ dict_type = util.OrderedDict if keep_order else dict
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
- parameters = util.OrderedDict()
+ parameters = dict_type()
else:
- parameters = util.OrderedDict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = dict_type((_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or
+ key not in stmt_parameters)
# create a list of column assignment clauses as tuples
values = []
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 983fed2b5..c8407e3fd 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -15,7 +15,7 @@ from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes
from .. import util
from .. import exc
-from sqlalchemy.sql import schema
+from sqlalchemy.sql import util as sql_util
class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement):
@@ -31,25 +31,18 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement):
_prefixes = ()
def _process_colparams(self, parameters):
- def is_value_pair_dict(params):
- # Check if params is a value list/tuple representing a dictionary
- return (
- isinstance(params, (list, tuple)) and
- all(isinstance(p, (list, tuple)) and len(p) == 2 and
- isinstance(p[0], schema.Column) for p in params))
-
def process_single(p):
if isinstance(p, (list, tuple)):
- if is_value_pair_dict(p):
+ if sql_util.is_value_pair_dict(p):
return util.OrderedDict(p)
- return util.OrderedDict(
+ return dict(
(c.key, pval)
for c, pval in zip(self.table.c, p)
)
else:
return p
- if (not is_value_pair_dict(parameters) and
+ if (not sql_util.is_value_pair_dict(parameters) and
isinstance(parameters, (list, tuple)) and parameters and
isinstance(parameters[0], (list, tuple, dict))):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index cbd74faac..c73f710af 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -29,6 +29,7 @@ join_condition = util.langhelpers.public_factory(
from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
from .elements import _find_columns
from .ddl import sort_tables
+from sqlalchemy.sql import schema
def find_join_source(clauses, join_to):
@@ -436,6 +437,12 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
return pairs
+def is_value_pair_dict(params):
+ """Check if params is a value list/tuple representing a dictionary."""
+ return (isinstance(params, (list, tuple)) and
+ all(isinstance(p, (list, tuple)) and len(p) == 2 and
+ isinstance(p[0], schema.Column) for p in params))
+
class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""Clones and modifies clauses based on column correspondence.