diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-10-09 17:18:00 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-10-09 17:18:00 -0400 |
commit | 58ad3b5ed78249f1faf48774bf6dedd424a1f435 (patch) | |
tree | 46f10805d70064abfc405f49035d7c1e4e6648d8 | |
parent | 78a7bbdb3b0906a35528bdc829a08f0644d6fd7b (diff) | |
parent | 44e5a31ccee5335962602327132a4196fb1c7911 (diff) | |
download | sqlalchemy-pr200.tar.gz |
Merge remote-tracking branch 'origin/pr/200' into pr200pr200
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 36 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 50 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 1 | ||||
-rw-r--r-- | test/orm/test_cycles.py | 14 | ||||
-rw-r--r-- | test/orm/test_query.py | 48 | ||||
-rw-r--r-- | test/sql/test_update.py | 72 |
8 files changed, 196 insertions, 41 deletions
diff --git a/.gitignore b/.gitignore index 55066f843..81fd2d9ed 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ coverage.xml sqlnet.log /mapping_setup.py /test.py +/.cache/ diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index d89a93dd3..57ac0f08e 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1258,10 +1258,13 @@ class BulkUpdate(BulkUD): "Invalid expression type: %r" % key) def _do_exec(self): - values = dict( - (self._resolve_string_to_expr(k), v) - for k, v in self.values.items() - ) + if isinstance(self.values, (list, tuple)): + values = tuple((self._resolve_string_to_expr(k), v) + for k, v in self.values) + else: + values = dict((self._resolve_string_to_expr(k), v) + for k, v in self.values.items()) + update_stmt = sql.update(self.primary_table, self.context.whereclause, values, **self.update_kwargs) @@ -1311,7 +1314,9 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): def _additional_evaluators(self, evaluator_compiler): self.value_evaluators = {} - for key, value in self.values.items(): + values = (self.values.items() if hasattr(self.values, 'items') + else self.values) + for key, value in values: key = self._resolve_key_to_attrname(key) if key is not None: self.value_evaluators[key] = evaluator_compiler.process( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 72b66c036..8b236cc56 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -14,6 +14,7 @@ from .. import exc from . import elements import operator + REQUIRED = util.symbol('REQUIRED', """ Placeholder for the value within a :class:`.BindParameter` which is required to be present when the statement is passed @@ -61,15 +62,25 @@ def _get_crud_params(compiler, stmt, **kw): _column_as_key, _getattr_col_key, _col_bind_name = \ _key_getters_for_crud_column(compiler) + # We have to keep parameters' order if we are doing an update and the + # statement paramenters are a list or tuple of pairs. It would also work + # without isupdate check, but adding it shortcircuits the boolean operation + # resulting in false for all inserts. + if stmt._preserve_parameter_order: + stmt_parameters = util.OrderedDict(stmt_parameters) + dict_type = util.OrderedDict + else: + dict_type = dict + # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: - parameters = {} + parameters = dict_type() else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = dict_type((_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or + key not in stmt_parameters) # create a list of column assignment clauses as tuples values = [] @@ -97,7 +108,8 @@ def _get_crud_params(compiler, stmt, **kw): _scan_cols( compiler, stmt, parameters, _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + _col_bind_name, check_columns, values, kw, + keep_order=stmt._preserve_parameter_order) if parameters and stmt_parameters: check = set(parameters).intersection( @@ -202,7 +214,7 @@ def _scan_insert_from_select_cols( def _scan_cols( compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): + _column_as_key, _col_bind_name, check_columns, values, kw, keep_order): need_pks, implicit_returning, \ implicit_return_defaults, postfetch_lastrowid = \ @@ -210,6 +222,16 @@ def _scan_cols( cols = stmt.table.columns + if keep_order: + # Order columns with parameters first, preserving their original order, + # and then the rest of the columns + keys = tuple(parameters.keys()) if parameters else tuple() + table_cols = tuple(cols) + cols = sorted(table_cols, + key=(lambda x: keys.index(_getattr_col_key(x)) + if _getattr_col_key(x) in keys + else len(keys) + table_cols.index(x))) + for c in cols: col_key = _getattr_col_key(c) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 6756f1554..7243e56e1 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -15,6 +15,7 @@ from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes from .. import util from .. import exc +from sqlalchemy.sql import schema class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): @@ -32,24 +33,26 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): def _process_colparams(self, parameters): def process_single(p): if isinstance(p, (list, tuple)): - return dict( - (c.key, pval) - for c, pval in zip(self.table.c, p) - ) + return dict((c.key, pval) for c, pval in zip(self.table.c, p)) else: return p - if (isinstance(parameters, (list, tuple)) and parameters and - isinstance(parameters[0], (list, tuple, dict))): + if parameters and isinstance(parameters, (list, tuple)): + p0 = parameters[0] + is_lt = isinstance(p0, (list, tuple)) + # If it's an ordered dict in the form of value pairs return it + if is_lt and len(p0) == 2 and isinstance(p0[0], schema.Column): + return parameters, False, True - if not self._supports_multi_parameters: - raise exc.InvalidRequestError( - "This construct does not support " - "multiple parameter sets.") + if is_lt or isinstance(p0, dict): + if not self._supports_multi_parameters: + raise exc.InvalidRequestError( + "This construct does not support " + "multiple parameter sets.") - return [process_single(p) for p in parameters], True - else: - return process_single(parameters), False + return [process_single(p) for p in parameters], True, False + + return process_single(parameters), False, False def params(self, *arg, **kw): """Set the parameters for the statement. @@ -178,12 +181,14 @@ class ValuesBase(UpdateBase): _supports_multi_parameters = False _has_multi_parameters = False + _preserve_parameter_order = False select = None def __init__(self, table, values, prefixes): self.table = _interpret_as_from(table) - self.parameters, self._has_multi_parameters = \ - self._process_colparams(values) + + self.parameters, self._has_multi_parameters, \ + self._preserve_parameter_order = self._process_colparams(values) if prefixes: self._setup_prefixes(prefixes) @@ -315,12 +320,14 @@ class ValuesBase(UpdateBase): v = {} if self.parameters is None: - self.parameters, self._has_multi_parameters = \ - self._process_colparams(v) + self.parameters, self._has_multi_parameters, \ + self._preserve_parameter_order = self._process_colparams(v) else: if self._has_multi_parameters: self.parameters = list(self.parameters) - p, self._has_multi_parameters = self._process_colparams(v) + p, self._has_multi_parameters, \ + self._preserve_parameter_order = self._process_colparams(v) + if not self._has_multi_parameters: raise exc.ArgumentError( "Can't mix single-values and multiple values " @@ -329,7 +336,8 @@ class ValuesBase(UpdateBase): self.parameters.extend(p) else: self.parameters = self.parameters.copy() - p, self._has_multi_parameters = self._process_colparams(v) + p, self._has_multi_parameters, \ + self._preserve_parameter_order = self._process_colparams(v) if self._has_multi_parameters: raise exc.ArgumentError( "Can't mix single-values and multiple values " @@ -548,8 +556,8 @@ class Insert(ValuesBase): raise exc.InvalidRequestError( "This construct already inserts value expressions") - self.parameters, self._has_multi_parameters = \ - self._process_colparams( + self.parameters, self._has_multi_parameters, \ + self._preserve_parameter_order = self._process_colparams( dict((_column_as_key(n), Null()) for n in names)) self.select_names = names diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index cbd74faac..bf1bfd310 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -436,7 +436,6 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, return pairs - class ClauseAdapter(visitors.ReplacingCloningVisitor): """Clones and modifies clauses based on column correspondence. diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index c95b8d152..56386e8d2 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -1181,9 +1181,10 @@ class PostUpdateBatchingTest(fixtures.MappedTest): testing.db, sess.flush, CompiledSQL( - "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, " - "c3_id=:c3_id WHERE parent.id = :parent_id", - lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id, 'c1_id': c12.id, 'c3_id': c31.id} + "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, c3_id=:c3_id " + "WHERE parent.id = :parent_id", + lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id, + 'c1_id': c12.id, 'c3_id': c31.id} ) ) @@ -1193,8 +1194,9 @@ class PostUpdateBatchingTest(fixtures.MappedTest): testing.db, sess.flush, CompiledSQL( - "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, " - "c3_id=:c3_id WHERE parent.id = :parent_id", - lambda ctx: {'c2_id': None, 'parent_id': p1.id, 'c1_id': None, 'c3_id': None} + "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, c3_id=:c3_id " + "WHERE parent.id = :parent_id", + lambda ctx: {'c2_id': None, 'parent_id': p1.id, + 'c1_id': None, 'c3_id': None} ) ) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index a373f1482..458637453 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -3867,7 +3867,7 @@ class SessionBindTest(QueryTest): def _assert_bind_args(self, session): get_bind = mock.Mock(side_effect=session.get_bind) with mock.patch.object(session, "get_bind", get_bind): - yield + yield get_bind for call_ in get_bind.mock_calls: is_(call_[1][0], inspect(self.classes.User)) is_not_(call_[2]['clause'], None) @@ -3903,6 +3903,52 @@ class SessionBindTest(QueryTest): session.query(User).filter(User.id == 15).update( {"name": "foob"}, synchronize_session=False) + def test_bulk_update_unordered_dict(self): + User = self.classes.User + session = Session() + + # Do an update using unordered dict and check that the parametes used + # are unordered + with self._assert_bind_args(session) as mock_args: + session.query(User).filter(User.id == 15).update( + {'name': 'foob', 'id': 123}) + # Confirm that parameters are a dict instead of tuple or list + params_type = type(mock_args.mock_calls[0][2]['clause'].parameters) + assert params_type is dict + + def test_bulk_update_ordered_dict(self): + User = self.classes.User + session = Session() + + # Do an update using an ordered dict and check that the parametes used + # are unordered + with self._assert_bind_args(session) as mock_args: + session.query(User).filter(User.id == 15).update( + util.OrderedDict((('name', 'foob'), ('id', 123)))) + params_type = type(mock_args.mock_calls[0][2]['clause'].parameters) + assert params_type is dict + + def test_bulk_update_with_order(self): + User = self.classes.User + session = Session() + + # Do update using a tuple and check that order is preserved + with self._assert_bind_args(session) as mock_args: + session.query(User).filter(User.id == 15).update( + (('id', 123), ('name', 'foob'))) + cols = [c[0].name for c + in mock_args.mock_calls[0][2]['clause'].parameters] + assert ['id', 'name'] == cols + + # Now invert the order and use a list instead, and check that order is + # also preserved + with self._assert_bind_args(session) as mock_args: + session.query(User).filter(User.id == 15).update( + [('id', 123), ('name', 'foob')]) + cols = [c[0].name for c + in mock_args.mock_calls[0][2]['clause'].parameters] + assert ['id', 'name'] == cols + def test_bulk_delete_no_sync(self): User = self.classes.User session = Session() diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 58c86613b..059c3ad6d 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -4,6 +4,7 @@ from sqlalchemy.dialects import mysql from sqlalchemy.engine import default from sqlalchemy.testing import AssertsCompiledSQL, eq_, fixtures from sqlalchemy.testing.schema import Table, Column +from sqlalchemy import util class _UpdateFromTestBase(object): @@ -165,6 +166,77 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): table1.c.name: table1.c.name + 'lala', table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho')) } + + self.assert_compile( + update( + table1, + (table1.c.myid == func.hoho(4)) & ( + table1.c.name == literal('foo') + + table1.c.name + + literal('lala')), + values=values), + 'UPDATE mytable ' + 'SET ' + 'myid=do_stuff(mytable.myid, :param_1), ' + 'name=(mytable.name || :name_1) ' + 'WHERE ' + 'mytable.myid = hoho(:hoho_1) AND ' + 'mytable.name = :param_2 || mytable.name || :param_3') + + def test_update_12(self): + table1 = self.tables.mytable + + # Confirm that we can pass values as tuple value pairs + values = ( + (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))), + (table1.c.name, table1.c.name + 'lala')) + self.assert_compile( + update( + table1, + (table1.c.myid == func.hoho(4)) & ( + table1.c.name == literal('foo') + + table1.c.name + + literal('lala')), + values=values), + 'UPDATE mytable ' + 'SET ' + 'myid=do_stuff(mytable.myid, :param_1), ' + 'name=(mytable.name || :name_1) ' + 'WHERE ' + 'mytable.myid = hoho(:hoho_1) AND ' + 'mytable.name = :param_2 || mytable.name || :param_3') + + def test_update_13(self): + table1 = self.tables.mytable + + # Confirm that we can pass values as list value pairs + values = [ + (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))), + (table1.c.name, table1.c.name + 'lala')] + self.assert_compile( + update( + table1, + (table1.c.myid == func.hoho(4)) & ( + table1.c.name == literal('foo') + + table1.c.name + + literal('lala')), + values=values), + 'UPDATE mytable ' + 'SET ' + 'myid=do_stuff(mytable.myid, :param_1), ' + 'name=(mytable.name || :name_1) ' + 'WHERE ' + 'mytable.myid = hoho(:hoho_1) AND ' + 'mytable.name = :param_2 || mytable.name || :param_3') + + def test_update_14(self): + table1 = self.tables.mytable + + # Confirm that ordered dicts are treated as normal dicts + values = util.OrderedDict(( + (table1.c.name, table1.c.name + 'lala'), + (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))))) + self.assert_compile( update( table1, |