summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--lib/sqlalchemy/orm/persistence.py15
-rw-r--r--lib/sqlalchemy/sql/crud.py36
-rw-r--r--lib/sqlalchemy/sql/dml.py50
-rw-r--r--lib/sqlalchemy/sql/util.py1
-rw-r--r--test/orm/test_cycles.py14
-rw-r--r--test/orm/test_query.py48
-rw-r--r--test/sql/test_update.py72
8 files changed, 196 insertions, 41 deletions
diff --git a/.gitignore b/.gitignore
index 55066f843..81fd2d9ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,4 @@ coverage.xml
sqlnet.log
/mapping_setup.py
/test.py
+/.cache/
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index d89a93dd3..57ac0f08e 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -1258,10 +1258,13 @@ class BulkUpdate(BulkUD):
"Invalid expression type: %r" % key)
def _do_exec(self):
- values = dict(
- (self._resolve_string_to_expr(k), v)
- for k, v in self.values.items()
- )
+ if isinstance(self.values, (list, tuple)):
+ values = tuple((self._resolve_string_to_expr(k), v)
+ for k, v in self.values)
+ else:
+ values = dict((self._resolve_string_to_expr(k), v)
+ for k, v in self.values.items())
+
update_stmt = sql.update(self.primary_table,
self.context.whereclause, values,
**self.update_kwargs)
@@ -1311,7 +1314,9 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
def _additional_evaluators(self, evaluator_compiler):
self.value_evaluators = {}
- for key, value in self.values.items():
+ values = (self.values.items() if hasattr(self.values, 'items')
+ else self.values)
+ for key, value in values:
key = self._resolve_key_to_attrname(key)
if key is not None:
self.value_evaluators[key] = evaluator_compiler.process(
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 72b66c036..8b236cc56 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -14,6 +14,7 @@ from .. import exc
from . import elements
import operator
+
REQUIRED = util.symbol('REQUIRED', """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
@@ -61,15 +62,25 @@ def _get_crud_params(compiler, stmt, **kw):
_column_as_key, _getattr_col_key, _col_bind_name = \
_key_getters_for_crud_column(compiler)
+ # We have to keep parameters' order if we are doing an update and the
+ # statement paramenters are a list or tuple of pairs. It would also work
+ # without isupdate check, but adding it shortcircuits the boolean operation
+ # resulting in false for all inserts.
+ if stmt._preserve_parameter_order:
+ stmt_parameters = util.OrderedDict(stmt_parameters)
+ dict_type = util.OrderedDict
+ else:
+ dict_type = dict
+
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
- parameters = {}
+ parameters = dict_type()
else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = dict_type((_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or
+ key not in stmt_parameters)
# create a list of column assignment clauses as tuples
values = []
@@ -97,7 +108,8 @@ def _get_crud_params(compiler, stmt, **kw):
_scan_cols(
compiler, stmt, parameters,
_getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ _col_bind_name, check_columns, values, kw,
+ keep_order=stmt._preserve_parameter_order)
if parameters and stmt_parameters:
check = set(parameters).intersection(
@@ -202,7 +214,7 @@ def _scan_insert_from_select_cols(
def _scan_cols(
compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
+ _column_as_key, _col_bind_name, check_columns, values, kw, keep_order):
need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid = \
@@ -210,6 +222,16 @@ def _scan_cols(
cols = stmt.table.columns
+ if keep_order:
+ # Order columns with parameters first, preserving their original order,
+ # and then the rest of the columns
+ keys = tuple(parameters.keys()) if parameters else tuple()
+ table_cols = tuple(cols)
+ cols = sorted(table_cols,
+ key=(lambda x: keys.index(_getattr_col_key(x))
+ if _getattr_col_key(x) in keys
+ else len(keys) + table_cols.index(x)))
+
for c in cols:
col_key = _getattr_col_key(c)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 6756f1554..7243e56e1 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -15,6 +15,7 @@ from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes
from .. import util
from .. import exc
+from sqlalchemy.sql import schema
class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement):
@@ -32,24 +33,26 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement):
def _process_colparams(self, parameters):
def process_single(p):
if isinstance(p, (list, tuple)):
- return dict(
- (c.key, pval)
- for c, pval in zip(self.table.c, p)
- )
+ return dict((c.key, pval) for c, pval in zip(self.table.c, p))
else:
return p
- if (isinstance(parameters, (list, tuple)) and parameters and
- isinstance(parameters[0], (list, tuple, dict))):
+ if parameters and isinstance(parameters, (list, tuple)):
+ p0 = parameters[0]
+ is_lt = isinstance(p0, (list, tuple))
+ # If it's an ordered dict in the form of value pairs return it
+ if is_lt and len(p0) == 2 and isinstance(p0[0], schema.Column):
+ return parameters, False, True
- if not self._supports_multi_parameters:
- raise exc.InvalidRequestError(
- "This construct does not support "
- "multiple parameter sets.")
+ if is_lt or isinstance(p0, dict):
+ if not self._supports_multi_parameters:
+ raise exc.InvalidRequestError(
+ "This construct does not support "
+ "multiple parameter sets.")
- return [process_single(p) for p in parameters], True
- else:
- return process_single(parameters), False
+ return [process_single(p) for p in parameters], True, False
+
+ return process_single(parameters), False, False
def params(self, *arg, **kw):
"""Set the parameters for the statement.
@@ -178,12 +181,14 @@ class ValuesBase(UpdateBase):
_supports_multi_parameters = False
_has_multi_parameters = False
+ _preserve_parameter_order = False
select = None
def __init__(self, table, values, prefixes):
self.table = _interpret_as_from(table)
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(values)
+
+ self.parameters, self._has_multi_parameters, \
+ self._preserve_parameter_order = self._process_colparams(values)
if prefixes:
self._setup_prefixes(prefixes)
@@ -315,12 +320,14 @@ class ValuesBase(UpdateBase):
v = {}
if self.parameters is None:
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(v)
+ self.parameters, self._has_multi_parameters, \
+ self._preserve_parameter_order = self._process_colparams(v)
else:
if self._has_multi_parameters:
self.parameters = list(self.parameters)
- p, self._has_multi_parameters = self._process_colparams(v)
+ p, self._has_multi_parameters, \
+ self._preserve_parameter_order = self._process_colparams(v)
+
if not self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
@@ -329,7 +336,8 @@ class ValuesBase(UpdateBase):
self.parameters.extend(p)
else:
self.parameters = self.parameters.copy()
- p, self._has_multi_parameters = self._process_colparams(v)
+ p, self._has_multi_parameters, \
+ self._preserve_parameter_order = self._process_colparams(v)
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
@@ -548,8 +556,8 @@ class Insert(ValuesBase):
raise exc.InvalidRequestError(
"This construct already inserts value expressions")
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(
+ self.parameters, self._has_multi_parameters, \
+ self._preserve_parameter_order = self._process_colparams(
dict((_column_as_key(n), Null()) for n in names))
self.select_names = names
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index cbd74faac..bf1bfd310 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -436,7 +436,6 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
return pairs
-
class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""Clones and modifies clauses based on column correspondence.
diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py
index c95b8d152..56386e8d2 100644
--- a/test/orm/test_cycles.py
+++ b/test/orm/test_cycles.py
@@ -1181,9 +1181,10 @@ class PostUpdateBatchingTest(fixtures.MappedTest):
testing.db,
sess.flush,
CompiledSQL(
- "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, "
- "c3_id=:c3_id WHERE parent.id = :parent_id",
- lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id, 'c1_id': c12.id, 'c3_id': c31.id}
+ "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, c3_id=:c3_id "
+ "WHERE parent.id = :parent_id",
+ lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id,
+ 'c1_id': c12.id, 'c3_id': c31.id}
)
)
@@ -1193,8 +1194,9 @@ class PostUpdateBatchingTest(fixtures.MappedTest):
testing.db,
sess.flush,
CompiledSQL(
- "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, "
- "c3_id=:c3_id WHERE parent.id = :parent_id",
- lambda ctx: {'c2_id': None, 'parent_id': p1.id, 'c1_id': None, 'c3_id': None}
+ "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, c3_id=:c3_id "
+ "WHERE parent.id = :parent_id",
+ lambda ctx: {'c2_id': None, 'parent_id': p1.id,
+ 'c1_id': None, 'c3_id': None}
)
)
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index a373f1482..458637453 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -3867,7 +3867,7 @@ class SessionBindTest(QueryTest):
def _assert_bind_args(self, session):
get_bind = mock.Mock(side_effect=session.get_bind)
with mock.patch.object(session, "get_bind", get_bind):
- yield
+ yield get_bind
for call_ in get_bind.mock_calls:
is_(call_[1][0], inspect(self.classes.User))
is_not_(call_[2]['clause'], None)
@@ -3903,6 +3903,52 @@ class SessionBindTest(QueryTest):
session.query(User).filter(User.id == 15).update(
{"name": "foob"}, synchronize_session=False)
+ def test_bulk_update_unordered_dict(self):
+ User = self.classes.User
+ session = Session()
+
+ # Do an update using unordered dict and check that the parametes used
+ # are unordered
+ with self._assert_bind_args(session) as mock_args:
+ session.query(User).filter(User.id == 15).update(
+ {'name': 'foob', 'id': 123})
+ # Confirm that parameters are a dict instead of tuple or list
+ params_type = type(mock_args.mock_calls[0][2]['clause'].parameters)
+ assert params_type is dict
+
+ def test_bulk_update_ordered_dict(self):
+ User = self.classes.User
+ session = Session()
+
+ # Do an update using an ordered dict and check that the parametes used
+ # are unordered
+ with self._assert_bind_args(session) as mock_args:
+ session.query(User).filter(User.id == 15).update(
+ util.OrderedDict((('name', 'foob'), ('id', 123))))
+ params_type = type(mock_args.mock_calls[0][2]['clause'].parameters)
+ assert params_type is dict
+
+ def test_bulk_update_with_order(self):
+ User = self.classes.User
+ session = Session()
+
+ # Do update using a tuple and check that order is preserved
+ with self._assert_bind_args(session) as mock_args:
+ session.query(User).filter(User.id == 15).update(
+ (('id', 123), ('name', 'foob')))
+ cols = [c[0].name for c
+ in mock_args.mock_calls[0][2]['clause'].parameters]
+ assert ['id', 'name'] == cols
+
+ # Now invert the order and use a list instead, and check that order is
+ # also preserved
+ with self._assert_bind_args(session) as mock_args:
+ session.query(User).filter(User.id == 15).update(
+ [('id', 123), ('name', 'foob')])
+ cols = [c[0].name for c
+ in mock_args.mock_calls[0][2]['clause'].parameters]
+ assert ['id', 'name'] == cols
+
def test_bulk_delete_no_sync(self):
User = self.classes.User
session = Session()
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index 58c86613b..059c3ad6d 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -4,6 +4,7 @@ from sqlalchemy.dialects import mysql
from sqlalchemy.engine import default
from sqlalchemy.testing import AssertsCompiledSQL, eq_, fixtures
from sqlalchemy.testing.schema import Table, Column
+from sqlalchemy import util
class _UpdateFromTestBase(object):
@@ -165,6 +166,77 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
table1.c.name: table1.c.name + 'lala',
table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho'))
}
+
+ self.assert_compile(
+ update(
+ table1,
+ (table1.c.myid == func.hoho(4)) & (
+ table1.c.name == literal('foo') +
+ table1.c.name +
+ literal('lala')),
+ values=values),
+ 'UPDATE mytable '
+ 'SET '
+ 'myid=do_stuff(mytable.myid, :param_1), '
+ 'name=(mytable.name || :name_1) '
+ 'WHERE '
+ 'mytable.myid = hoho(:hoho_1) AND '
+ 'mytable.name = :param_2 || mytable.name || :param_3')
+
+ def test_update_12(self):
+ table1 = self.tables.mytable
+
+ # Confirm that we can pass values as tuple value pairs
+ values = (
+ (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))),
+ (table1.c.name, table1.c.name + 'lala'))
+ self.assert_compile(
+ update(
+ table1,
+ (table1.c.myid == func.hoho(4)) & (
+ table1.c.name == literal('foo') +
+ table1.c.name +
+ literal('lala')),
+ values=values),
+ 'UPDATE mytable '
+ 'SET '
+ 'myid=do_stuff(mytable.myid, :param_1), '
+ 'name=(mytable.name || :name_1) '
+ 'WHERE '
+ 'mytable.myid = hoho(:hoho_1) AND '
+ 'mytable.name = :param_2 || mytable.name || :param_3')
+
+ def test_update_13(self):
+ table1 = self.tables.mytable
+
+ # Confirm that we can pass values as list value pairs
+ values = [
+ (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))),
+ (table1.c.name, table1.c.name + 'lala')]
+ self.assert_compile(
+ update(
+ table1,
+ (table1.c.myid == func.hoho(4)) & (
+ table1.c.name == literal('foo') +
+ table1.c.name +
+ literal('lala')),
+ values=values),
+ 'UPDATE mytable '
+ 'SET '
+ 'myid=do_stuff(mytable.myid, :param_1), '
+ 'name=(mytable.name || :name_1) '
+ 'WHERE '
+ 'mytable.myid = hoho(:hoho_1) AND '
+ 'mytable.name = :param_2 || mytable.name || :param_3')
+
+ def test_update_14(self):
+ table1 = self.tables.mytable
+
+ # Confirm that ordered dicts are treated as normal dicts
+ values = util.OrderedDict((
+ (table1.c.name, table1.c.name + 'lala'),
+ (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho')))))
+
self.assert_compile(
update(
table1,