diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-06-03 17:38:35 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-06-06 13:31:54 -0400 |
| commit | 3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6 (patch) | |
| tree | f3dc26609070c1a357a366592c791a3ec0655483 /lib/sqlalchemy | |
| parent | 14bc09203a8b5b2bc001f764ad7cce6a184975cc (diff) | |
| download | sqlalchemy-3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6.tar.gz | |
Convert bulk update/delete to new execution model
This reorganizes the BulkUD model in sqlalchemy.orm.persistence
to be based on the CompileState concept and to allow plain
update() / delete() to be passed to session.execute() where
the ORM synchronize session logic will take place.
Also gets "synchronize_session='fetch'" working with horizontal
sharding.
Adding a few more result.scalar_one() types of methods
as scalar_one() seems like what is normally desired.
Fixes: #5160
Change-Id: I8001ebdad089da34119eb459709731ba6c0ba975
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 85 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 95 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/hybrid.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/events.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 498 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 115 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 114 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 58 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/compat.py | 1 |
21 files changed, 667 insertions, 436 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 1d832e4af..d03d79df7 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1630,6 +1630,15 @@ class CursorResult(BaseCursorResult, Result): def _raw_row_iterator(self): return self._fetchiter_impl() + def merge(self, *others): + merged_result = super(CursorResult, self).merge(*others) + setup_rowcounts = not self._metadata.returns_rows + if setup_rowcounts: + merged_result.rowcount = sum( + result.rowcount for result in (self,) + others + ) + return merged_result + def close(self): """Close this :class:`_engine.CursorResult`. diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 600229037..b29bc22d4 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -951,7 +951,7 @@ class Result(InPlaceGenerative): """ return self._allrows() - def _only_one_row(self, raise_for_second_row, raise_for_none): + def _only_one_row(self, raise_for_second_row, raise_for_none, scalar): onerow = self._fetchone_impl row = onerow(hard_close=True) @@ -1010,27 +1010,43 @@ class Result(InPlaceGenerative): # if we checked for second row then that would have # closed us :) self._soft_close(hard=True) - post_creational_filter = self._post_creational_filter - if post_creational_filter: - row = post_creational_filter(row) - return row + if not scalar: + post_creational_filter = self._post_creational_filter + if post_creational_filter: + row = post_creational_filter(row) + + if scalar and row: + return row[0] + else: + return row def first(self): """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the :meth:`.Result.scalar` method, + or combine :meth:`.Result.scalars` and :meth:`.Result.first`. + .. comment: A warning is emitted if additional rows remain. :return: a :class:`.Row` object if no filters are applied, or None if no rows remain. When filters are applied, such as :meth:`_engine.Result.mappings` - or :meth:`._engine.Result.scalar`, different kinds of objects + or :meth:`._engine.Result.scalars`, different kinds of objects may be returned. + .. seealso:: + + :meth:`_result.Result.scalar` + + :meth:`_result.Result.one` + """ - return self._only_one_row(False, False) + return self._only_one_row(False, False, False) def one_or_none(self): """Return at most one result or raise an exception. @@ -1055,15 +1071,50 @@ class Result(InPlaceGenerative): :meth:`_result.Result.one` """ - return self._only_one_row(True, False) + return self._only_one_row(True, False, False) + + def scalar_one(self): + """Return exactly one scalar result or raise an exception. + + This is equvalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one`. + + .. seealso:: + + :meth:`.Result.one` + + :meth:`.Result.scalars` + + """ + return self._only_one_row(True, True, True) + + def scalar_one_or_none(self): + """Return exactly one or no scalar result. + + This is equvalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one_or_none`. + + .. seealso:: + + :meth:`.Result.one_or_none` + + :meth:`.Result.scalars` + + """ + return self._only_one_row(True, False, True) def one(self): - """Return exactly one result or raise an exception. + """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no rows, or :class:`.MultipleResultsFound` if multiple rows would be returned. + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the :meth:`.Result.scalar_one` method, + or combine :meth:`.Result.scalars` and :meth:`.Result.one`. + .. versionadded:: 1.4 :return: The first :class:`.Row`. @@ -1079,24 +1130,26 @@ class Result(InPlaceGenerative): :meth:`_result.Result.one_or_none` + :meth:`_result.Result.scalar_one` + """ - return self._only_one_row(True, True) + return self._only_one_row(True, True, False) def scalar(self): """Fetch the first column of the first row, and close the result set. + Returns None if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + After calling this method, the object is fully closed, e.g. the :meth:`_engine.CursorResult.close` method will have been called. - :return: a Python scalar value , or None if no rows remain + :return: a Python scalar value , or None if no rows remain. """ - row = self.first() - if row is not None: - return row[0] - else: - return None + return self._only_one_row(False, False, True) class FrozenResult(object): diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index c3ac71c10..0983807cb 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -50,58 +50,6 @@ class ShardedQuery(Query): """ return self.execution_options(_sa_shard_id=shard_id) - def _execute_crud(self, stmt, mapper): - def exec_for_shard(shard_id): - conn = self.session.connection( - mapper=mapper, - shard_id=shard_id, - clause=stmt, - close_with_result=True, - ) - result = conn._execute_20( - stmt, self.load_options._params, self._execution_options - ) - return result - - if self._shard_id is not None: - return exec_for_shard(self._shard_id) - else: - rowcount = 0 - results = [] - # TODO: this will have to be the new object - for shard_id in self.execute_chooser(self): - result = exec_for_shard(shard_id) - rowcount += result.rowcount - results.append(result) - - return ShardedResult(results, rowcount) - - -class ShardedResult(object): - """A value object that represents multiple :class:`_engine.CursorResult` - objects. - - This is used by the :meth:`.ShardedQuery._execute_crud` hook to return - an object that takes the place of the single :class:`_engine.CursorResult`. - - Attribute include ``result_proxies``, which is a sequence of the - actual :class:`_engine.CursorResult` objects, - as well as ``aggregate_rowcount`` - or ``rowcount``, which is the sum of all the individual rowcount values. - - .. versionadded:: 1.3 - """ - - __slots__ = ("result_proxies", "aggregate_rowcount") - - def __init__(self, result_proxies, aggregate_rowcount): - self.result_proxies = result_proxies - self.aggregate_rowcount = aggregate_rowcount - - @property - def rowcount(self): - return self.aggregate_rowcount - class ShardedSession(Session): def __init__( @@ -259,37 +207,40 @@ class ShardedSession(Session): def execute_and_instances(orm_context): - if orm_context.bind_arguments.get("_horizontal_shard", False): - return None - params = orm_context.parameters - load_options = orm_context.load_options + if orm_context.is_select: + load_options = active_options = orm_context.load_options + update_options = None + if params is None: + params = active_options._params + + else: + load_options = None + update_options = active_options = orm_context.update_delete_options + session = orm_context.session # orm_query = orm_context.orm_query - if params is None: - params = load_options._params - - def iter_for_shard(shard_id, load_options): + def iter_for_shard(shard_id, load_options, update_options): execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) - bind_arguments["_horizontal_shard"] = True bind_arguments["shard_id"] = shard_id - load_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_load_options"] = load_options + if orm_context.is_select: + load_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_load_options"] = load_options + else: + update_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_update_options"] = update_options - return session.execute( - orm_context.statement, - orm_context.parameters, - execution_options, - bind_arguments, + return orm_context.invoke_statement( + bind_arguments=bind_arguments, execution_options=execution_options ) - if load_options._refresh_identity_token is not None: - shard_id = load_options._refresh_identity_token + if active_options._refresh_identity_token is not None: + shard_id = active_options._refresh_identity_token elif "_sa_shard_id" in orm_context.merged_execution_options: shard_id = orm_context.merged_execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: @@ -298,11 +249,11 @@ def execute_and_instances(orm_context): shard_id = None if shard_id is not None: - return iter_for_shard(shard_id, load_options) + return iter_for_shard(shard_id, load_options, update_options) else: partial = [] for shard_id in session.execute_chooser(orm_context): - result_ = iter_for_shard(shard_id, load_options) + result_ = iter_for_shard(shard_id, load_options, update_options) partial.append(result_) return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 9f73b5d31..efd8d7d6b 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -777,7 +777,7 @@ things it can be used for. from .. import util from ..orm import attributes from ..orm import interfaces - +from ..sql import elements HYBRID_METHOD = util.symbol("HYBRID_METHOD") """Symbol indicating an :class:`InspectionAttr` that's @@ -1144,6 +1144,9 @@ class ExprComparator(Comparator): return self.hybrid.info def _bulk_update_tuples(self, value): + if isinstance(value, elements.BindParameter): + value = value.value + if isinstance(self.expression, attributes.QueryableAttribute): return self.expression._bulk_update_tuples(value) elif self.hybrid.update_expr is not None: diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index bd4074ea1..a16db66f6 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -189,7 +189,7 @@ class ORMCompileState(CompileState): @classmethod def orm_pre_session_exec( - cls, session, statement, execution_options, bind_arguments + cls, session, statement, params, execution_options, bind_arguments ): load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options @@ -216,6 +216,8 @@ class ORMCompileState(CompileState): if load_options._autoflush: session._autoflush() + return execution_options + @classmethod def orm_setup_cursor_result( cls, session, statement, execution_options, bind_arguments, result diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 6be4f0dff..027f2521b 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -420,6 +420,9 @@ class CompositeProperty(DescriptorProperty): return CompositeProperty.CompositeBundle(self.prop, clauses) def _bulk_update_tuples(self, value): + if isinstance(value, sql.elements.BindParameter): + value = value.value + if value is None: values = [None for key in self.prop._attribute_keys] elif isinstance(value, self.prop.composite_class): diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index be7aa272e..217aa76c7 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1764,7 +1764,7 @@ class SessionEvents(event.Events): lambda update_context: ( update_context.session, update_context.query, - update_context.context, + None, update_context.result, ), ) @@ -1782,12 +1782,13 @@ class SessionEvents(event.Events): was called upon. * ``values`` The "values" dictionary that was passed to :meth:`_query.Query.update`. - * ``context`` The :class:`.QueryContext` object, corresponding - to the invocation of an ORM query. * ``result`` the :class:`_engine.CursorResult` returned as a result of the bulk UPDATE operation. + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + .. seealso:: :meth:`.QueryEvents.before_compile_update` @@ -1802,7 +1803,7 @@ class SessionEvents(event.Events): lambda delete_context: ( delete_context.session, delete_context.query, - delete_context.context, + None, delete_context.result, ), ) @@ -1818,12 +1819,13 @@ class SessionEvents(event.Events): * ``query`` -the :class:`_query.Query` object that this update operation was called upon. - * ``context`` The :class:`.QueryContext` object, corresponding - to the invocation of an ORM query. * ``result`` the :class:`_engine.CursorResult` returned as a result of the bulk DELETE operation. + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + .. seealso:: :meth:`.QueryEvents.before_compile_delete` diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 4166e6d2a..c4cb89c03 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2235,14 +2235,28 @@ class Mapper( @HasMemoized.memoized_instancemethod def __clause_element__(self): - return self.selectable._annotate( - { - "entity_namespace": self, - "parententity": self, - "parentmapper": self, - "compile_state_plugin": "orm", - } - )._set_propagate_attrs( + + annotations = { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "compile_state_plugin": "orm", + } + if self.persist_selectable is not self.local_table: + # joined table inheritance, with polymorphic selectable, + # etc. + annotations["dml_table"] = self.local_table._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "compile_state_plugin": "orm", + } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + return self.selectable._annotate(annotations)._set_propagate_attrs( {"compile_state_plugin": "orm", "plugin_subject": self} ) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 163ebf22a..19d43d354 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -28,11 +28,15 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..future import select as future_select from ..sql import coercions from ..sql import expression from ..sql import operators from ..sql import roles -from ..sql.base import _from_objects +from ..sql.base import CompileState +from ..sql.base import Options +from ..sql.dml import DeleteDMLState +from ..sql.dml import UpdateDMLState from ..sql.elements import BooleanClauseList @@ -1650,243 +1654,193 @@ def _sort_states(mapper, states): ) -class BulkUD(object): - """Handle bulk update and deletes via a :class:`_query.Query`.""" +_EMPTY_DICT = util.immutabledict() - def __init__(self, query): - self.query = query.enable_eagerloads(False) - self._validate_query_state() - def _validate_query_state(self): - for attr, methname, notset, op in ( - ("_limit_clause", "limit()", None, operator.is_), - ("_offset_clause", "offset()", None, operator.is_), - ("_order_by_clauses", "order_by()", (), operator.eq), - ("_group_by_clauses", "group_by()", (), operator.eq), - ("_distinct", "distinct()", False, operator.is_), - ( - "_from_obj", - "join(), outerjoin(), select_from(), or from_self()", - (), - operator.eq, - ), - ( - "_legacy_setup_joins", - "join(), outerjoin(), select_from(), or from_self()", - (), - operator.eq, - ), - ): - if not op(getattr(self.query, attr), notset): - raise sa_exc.InvalidRequestError( - "Can't call Query.update() or Query.delete() " - "when %s has been called" % (methname,) - ) - - @property - def session(self): - return self.query.session +class BulkUDCompileState(CompileState): + class default_update_options(Options): + _synchronize_session = "evaluate" + _autoflush = True + _subject_mapper = None + _resolved_values = _EMPTY_DICT + _resolved_keys_as_propnames = _EMPTY_DICT + _value_evaluators = _EMPTY_DICT + _matched_objects = None + _matched_rows = None + _refresh_identity_token = None @classmethod - def _factory(cls, lookup, synchronize_session, *arg): - try: - klass = lookup[synchronize_session] - except KeyError as err: - util.raise_( - sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) for x in lookup))) - ), - replace_context=err, + def orm_pre_session_exec( + cls, session, statement, params, execution_options, bind_arguments + ): + sync = execution_options.get("synchronize_session", None) + if sync is None: + sync = statement._execution_options.get( + "synchronize_session", None ) - else: - return klass(*arg) - - def exec_(self): - self._do_before_compile() - self._do_pre() - self._do_pre_synchronize() - self._do_exec() - self._do_post_synchronize() - self._do_post() - - def _execute_stmt(self, stmt): - self.result = self.query._execute_crud(stmt, self.mapper) - self.rowcount = self.result.rowcount - - def _do_before_compile(self): - raise NotImplementedError() - @util.preload_module("sqlalchemy.orm.context") - def _do_pre(self): - query_context = util.preloaded.orm_context - query = self.query - - self.compile_state = ( - self.context - ) = compile_state = query._compile_state() - - self.mapper = compile_state._entity_zero() - - if isinstance( - compile_state._entities[0], query_context._RawColumnEntity, - ): - # check for special case of query(table) - tables = set() - for ent in compile_state._entities: - if not isinstance(ent, query_context._RawColumnEntity,): - tables.clear() - break - else: - tables.update(_from_objects(ent.column)) + update_options = execution_options.get( + "_sa_orm_update_options", + BulkUDCompileState.default_update_options, + ) - if len(tables) != 1: - raise sa_exc.InvalidRequestError( - "This operation requires only one Table or " - "entity be specified as the target." + if sync is not None: + if sync not in ("evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'evaluate', 'fetch', False" ) - else: - self.primary_table = tables.pop() + update_options += {"_synchronize_session": sync} + bind_arguments["clause"] = statement + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" else: - self.primary_table = compile_state._only_entity_zero( - "This operation requires only one Table or " - "entity be specified as the target." - ).mapper.local_table + bind_arguments["mapper"] = plugin_subject.mapper - session = query.session + update_options += {"_subject_mapper": plugin_subject.mapper} - if query.load_options._autoflush: + if update_options._autoflush: session._autoflush() - def _do_pre_synchronize(self): - pass + if update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) - def _do_post_synchronize(self): - pass + return util.immutabledict(execution_options).union( + dict(_sa_orm_update_options=update_options) + ) + @classmethod + def orm_setup_cursor_result( + cls, session, statement, execution_options, bind_arguments, result + ): + update_options = execution_options["_sa_orm_update_options"] + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate(session, update_options) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch(session, update_options) -class BulkEvaluate(BulkUD): - """BulkUD which does the 'evaluate' method of session state resolution.""" + return result - def _additional_evaluators(self, evaluator_compiler): - pass + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + mapper = update_options._subject_mapper + target_cls = mapper.class_ - def _do_pre_synchronize(self): - query = self.query - target_cls = self.compile_state._mapper_zero().class_ + value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - if query._where_criteria: + if statement._where_criteria: eval_condition = evaluator_compiler.process( - *query._where_criteria + *statement._where_criteria ) else: def eval_condition(obj): return True - self._additional_evaluators(evaluator_compiler) + # TODO: something more robust for this conditional + if statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values(mapper, statement) + value_evaluators = {} + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + for key, value in resolved_keys_as_propnames: + value_evaluators[key] = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) except evaluator.UnevaluatableError as err: util.raise_( sa_exc.InvalidRequestError( 'Could not evaluate current criteria in Python: "%s". ' "Specify 'fetch' or False for the " - "synchronize_session parameter." % err + "synchronize_session execution option." % err ), from_=err, ) # TODO: detect when the where clause is a trivial primary key match - self.matched_objects = [ + matched_objects = [ obj - for ( - cls, - pk, - identity_token, - ), obj in query.session.identity_map.items() - if issubclass(cls, target_cls) and eval_condition(obj) + for (cls, pk, identity_token,), obj in session.identity_map.items() + if issubclass(cls, target_cls) + and eval_condition(obj) + and identity_token == update_options._refresh_identity_token ] - - -class BulkFetch(BulkUD): - """BulkUD which does the 'fetch' method of session state resolution.""" - - def _do_pre_synchronize(self): - query = self.query - session = query.session - select_stmt = self.compile_state.statement.with_only_columns( - self.primary_table.primary_key - ) - self.matched_rows = session.execute( - select_stmt, mapper=self.mapper, params=query.load_options._params - ).fetchall() - - -class BulkUpdate(BulkUD): - """BulkUD which handles UPDATEs.""" - - def __init__(self, query, values, update_kwargs): - super(BulkUpdate, self).__init__(query) - self.values = values - self.update_kwargs = update_kwargs + return update_options + { + "_matched_objects": matched_objects, + "_value_evaluators": value_evaluators, + "_resolved_keys_as_propnames": resolved_keys_as_propnames, + } @classmethod - def factory(cls, query, synchronize_session, values, update_kwargs): - return BulkUD._factory( - { - "evaluate": BulkUpdateEvaluate, - "fetch": BulkUpdateFetch, - False: BulkUpdate, - }, - synchronize_session, - query, - values, - update_kwargs, - ) - - def _do_before_compile(self): - if self.query.dispatch.before_compile_update: - for fn in self.query.dispatch.before_compile_update: - new_query = fn(self.query, self) - if new_query is not None: - self.query = new_query + def _get_resolved_values(cls, mapper, statement): + if statement._multi_values: + return [] + elif statement._ordered_values: + iterator = statement._ordered_values + elif statement._values: + iterator = statement._values.items() + else: + return [] - @property - def _resolved_values(self): values = [] - for k, v in ( - self.values.items() - if hasattr(self.values, "items") - else self.values - ): - if self.mapper: - if isinstance(k, util.string_types): - desc = sql.util._entity_namespace_key(self.mapper, k) - values.extend(desc._bulk_update_tuples(v)) - elif isinstance(k, attributes.QueryableAttribute): - values.extend(k._bulk_update_tuples(v)) + if iterator: + for k, v in iterator: + if mapper: + if isinstance(k, util.string_types): + desc = sql.util._entity_namespace_key(mapper, k) + values.extend(desc._bulk_update_tuples(v)) + elif isinstance(k, attributes.QueryableAttribute): + values.extend(k._bulk_update_tuples(v)) + else: + values.append((k, v)) else: values.append((k, v)) - else: - values.append((k, v)) return values - @property - def _resolved_values_keys_as_propnames(self): + @classmethod + def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] - for k, v in self._resolved_values: + for k, v in resolved_values: if isinstance(k, attributes.QueryableAttribute): values.append((k.key, v)) continue elif hasattr(k, "__clause_element__"): k = k.__clause_element__() - if self.mapper and isinstance(k, expression.ColumnElement): + if mapper and isinstance(k, expression.ColumnElement): try: - attr = self.mapper._columntoproperty[k] + attr = mapper._columntoproperty[k] except orm_exc.UnmappedColumnError: pass else: @@ -1897,87 +1851,99 @@ class BulkUpdate(BulkUD): ) return values - def _do_exec(self): - values = self._resolved_values + @classmethod + def _do_pre_synchronize_fetch( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + mapper = update_options._subject_mapper - if not self.update_kwargs.get("preserve_parameter_order", False): - values = dict(values) + if mapper: + primary_table = mapper.local_table + else: + primary_table = statement._raw_columns[0] - update_stmt = sql.update( - self.primary_table, **self.update_kwargs - ).values(values) + # note this creates a Select() *without* the ORM plugin. + # we don't want that here. + select_stmt = future_select(*primary_table.primary_key) + select_stmt._where_criteria = statement._where_criteria - update_stmt._where_criteria = self.compile_state._where_criteria + matched_rows = session.execute( + select_stmt, params, execution_options, bind_arguments + ).fetchall() - self._execute_stmt(update_stmt) + if statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + else: + resolved_keys_as_propnames = _EMPTY_DICT - def _do_post(self): - session = self.query.session - session.dispatch.after_bulk_update(self) + return update_options + { + "_matched_rows": matched_rows, + "_resolved_keys_as_propnames": resolved_keys_as_propnames, + } -class BulkDelete(BulkUD): - """BulkUD which handles DELETEs.""" +@CompileState.plugin_for("orm", "update") +class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): - def __init__(self, query): - super(BulkDelete, self).__init__(query) + self = cls.__new__(cls) - @classmethod - def factory(cls, query, synchronize_session): - return BulkUD._factory( - { - "evaluate": BulkDeleteEvaluate, - "fetch": BulkDeleteFetch, - False: BulkDelete, - }, - synchronize_session, - query, + self.mapper = mapper = statement.table._annotations.get( + "parentmapper", None ) - def _do_before_compile(self): - if self.query.dispatch.before_compile_delete: - for fn in self.query.dispatch.before_compile_delete: - new_query = fn(self.query, self) - if new_query is not None: - self.query = new_query + self._resolved_values = cls._get_resolved_values(mapper, statement) - def _do_exec(self): - delete_stmt = sql.delete(self.primary_table,) - delete_stmt._where_criteria = self.compile_state._where_criteria + if not statement._preserve_parameter_order and statement._values: + self._resolved_values = dict(self._resolved_values) - self._execute_stmt(delete_stmt) + new_stmt = sql.Update.__new__(sql.Update) + new_stmt.__dict__.update(statement.__dict__) + new_stmt.table = mapper.local_table - def _do_post(self): - session = self.query.session - session.dispatch.after_bulk_delete(self) + # note if the statement has _multi_values, these + # are passed through to the new statement, which will then raise + # InvalidRequestError because UPDATE doesn't support multi_values + # right now. + if statement._ordered_values: + new_stmt._ordered_values = self._resolved_values + elif statement._values: + new_stmt._values = self._resolved_values + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) -class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): - """BulkUD which handles UPDATEs using the "evaluate" - method of session resolution.""" + return self - def _additional_evaluators(self, evaluator_compiler): - self.value_evaluators = {} - values = self._resolved_values_keys_as_propnames - for key, value in values: - self.value_evaluators[key] = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) + @classmethod + def _do_post_synchronize_evaluate(cls, session, update_options): - def _do_post_synchronize(self): - session = self.query.session states = set() - evaluated_keys = list(self.value_evaluators.keys()) - for obj in self.matched_objects: + evaluated_keys = list(update_options._value_evaluators.keys()) + for obj in update_options._matched_objects: + state, dict_ = ( attributes.instance_state(obj), attributes.instance_dict(obj), ) + assert ( + state.identity_token == update_options._refresh_identity_token + ) + # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: - dict_[key] = self.value_evaluators[key](obj) + dict_[key] = update_options._value_evaluators[key](obj) state.manager.dispatch.refresh(state, None, to_evaluate) @@ -1991,39 +1957,25 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): states.add(state) session._register_altered(states) - -class BulkDeleteEvaluate(BulkEvaluate, BulkDelete): - """BulkUD which handles DELETEs using the "evaluate" - method of session resolution.""" - - def _do_post_synchronize(self): - self.query.session._remove_newly_deleted( - [attributes.instance_state(obj) for obj in self.matched_objects] - ) - - -class BulkUpdateFetch(BulkFetch, BulkUpdate): - """BulkUD which handles UPDATEs using the "fetch" - method of session resolution.""" - - def _do_post_synchronize(self): - session = self.query.session - target_mapper = self.compile_state._mapper_zero() + @classmethod + def _do_post_synchronize_fetch(cls, session, update_options): + target_mapper = update_options._subject_mapper states = set( [ attributes.instance_state(session.identity_map[identity_key]) for identity_key in [ target_mapper.identity_key_from_primary_key( - list(primary_key) + list(primary_key), + identity_token=update_options._refresh_identity_token, ) - for primary_key in self.matched_rows + for primary_key in update_options._matched_rows ] if identity_key in session.identity_map ] ) - values = self._resolved_values_keys_as_propnames + values = update_options._resolved_keys_as_propnames attrib = set(k for k, v in values) for state in states: to_expire = attrib.intersection(state.dict) @@ -2032,18 +1984,38 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): session._register_altered(states) -class BulkDeleteFetch(BulkFetch, BulkDelete): - """BulkUD which handles DELETEs using the "fetch" - method of session resolution.""" +@CompileState.plugin_for("orm", "delete") +class BulkORMDelete(DeleteDMLState, BulkUDCompileState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + self = cls.__new__(cls) + + self.mapper = statement.table._annotations.get("parentmapper", None) + + DeleteDMLState.__init__(self, statement, compiler, **kw) + + return self + + @classmethod + def _do_post_synchronize_evaluate(cls, session, update_options): + + session._remove_newly_deleted( + [ + attributes.instance_state(obj) + for obj in update_options._matched_objects + ] + ) + + @classmethod + def _do_post_synchronize_fetch(cls, session, update_options): + target_mapper = update_options._subject_mapper - def _do_post_synchronize(self): - session = self.query.session - target_mapper = self.compile_state._mapper_zero() - for primary_key in self.matched_rows: + for primary_key in update_options._matched_rows: # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key) + list(primary_key), + identity_token=update_options._refresh_identity_token, ) if identity_key in session.identity_map: session._remove_newly_deleted( diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5137f9b1d..284ea9d72 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -19,12 +19,12 @@ database to return iterable result sets. """ import itertools +import operator from . import attributes from . import exc as orm_exc from . import interfaces from . import loading -from . import persistence from .base import _assertions from .context import _column_descriptions from .context import _legacy_determine_last_joined_entity @@ -2825,15 +2825,6 @@ class Query( return result - def _execute_crud(self, stmt, mapper): - conn = self.session.connection( - mapper=mapper, clause=stmt, close_with_result=True - ) - - return conn._execute_20( - stmt, self.load_options._params, self._execution_options - ) - def __str__(self): statement = self._statement_20() @@ -3178,9 +3169,27 @@ class Query( """ - delete_op = persistence.BulkDelete.factory(self, synchronize_session) - delete_op.exec_() - return delete_op.rowcount + bulk_del = BulkDelete(self,) + if self.dispatch.before_compile_delete: + for fn in self.dispatch.before_compile_delete: + new_query = fn(bulk_del.query, bulk_del) + if new_query is not None: + bulk_del.query = new_query + + self = bulk_del.query + + delete_ = sql.delete(*self._raw_columns) + delete_._where_criteria = self._where_criteria + result = self.session.execute( + delete_, + self.load_options._params, + execution_options={"synchronize_session": synchronize_session}, + ) + bulk_del.result = result + self.session.dispatch.after_bulk_delete(bulk_del) + result.close() + + return result.rowcount def update(self, values, synchronize_session="evaluate", update_args=None): r"""Perform a bulk update query. @@ -3313,11 +3322,27 @@ class Query( """ update_args = update_args or {} - update_op = persistence.BulkUpdate.factory( - self, synchronize_session, values, update_args + + bulk_ud = BulkUpdate(self, values, update_args) + + if self.dispatch.before_compile_update: + for fn in self.dispatch.before_compile_update: + new_query = fn(bulk_ud.query, bulk_ud) + if new_query is not None: + bulk_ud.query = new_query + self = bulk_ud.query + + upd = sql.update(*self._raw_columns, **update_args).values(values) + upd._where_criteria = self._where_criteria + result = self.session.execute( + upd, + self.load_options._params, + execution_options={"synchronize_session": synchronize_session}, ) - update_op.exec_() - return update_op.rowcount + bulk_ud.result = result + self.session.dispatch.after_bulk_update(bulk_ud) + result.close() + return result.rowcount def _compile_state(self, for_statement=False, **kw): """Create an out-of-compiler ORMCompileState object. @@ -3427,3 +3452,59 @@ class AliasOption(interfaces.LoaderOption): def process_compile_state(self, compile_state): pass + + +class BulkUD(object): + """State used for the orm.Query version of update() / delete(). + + This object is now specific to Query only. + + """ + + def __init__(self, query): + self.query = query.enable_eagerloads(False) + self._validate_query_state() + self.mapper = self.query._entity_from_pre_ent_zero() + + def _validate_query_state(self): + for attr, methname, notset, op in ( + ("_limit_clause", "limit()", None, operator.is_), + ("_offset_clause", "offset()", None, operator.is_), + ("_order_by_clauses", "order_by()", (), operator.eq), + ("_group_by_clauses", "group_by()", (), operator.eq), + ("_distinct", "distinct()", False, operator.is_), + ( + "_from_obj", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ( + "_legacy_setup_joins", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ): + if not op(getattr(self.query, attr), notset): + raise sa_exc.InvalidRequestError( + "Can't call Query.update() or Query.delete() " + "when %s has been called" % (methname,) + ) + + @property + def session(self): + return self.query.session + + +class BulkUpdate(BulkUD): + """BulkUD which handles UPDATEs.""" + + def __init__(self, query, values, update_kwargs): + super(BulkUpdate, self).__init__(query) + self.values = values + self.update_kwargs = update_kwargs + + +class BulkDelete(BulkUD): + """BulkUD which handles DELETEs.""" diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index ee42419a2..5ad8bcf2f 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -33,7 +33,9 @@ from .. import future from .. import util from ..inspection import inspect from ..sql import coercions +from ..sql import dml from ..sql import roles +from ..sql import selectable from ..sql import visitors from ..sql.base import CompileState @@ -113,16 +115,24 @@ class ORMExecuteState(util.MemoizedSlots): "_execution_options", "_merged_execution_options", "bind_arguments", + "_compile_state_cls", ) def __init__( - self, session, statement, parameters, execution_options, bind_arguments + self, + session, + statement, + parameters, + execution_options, + bind_arguments, + compile_state_cls, ): self.session = session self.statement = statement self.parameters = parameters self._execution_options = execution_options self.bind_arguments = bind_arguments + self._compile_state_cls = compile_state_cls def invoke_statement( self, @@ -194,6 +204,38 @@ class ORMExecuteState(util.MemoizedSlots): ) @property + def is_orm_statement(self): + """return True if the operation is an ORM statement. + + This indictes that the select(), update(), or delete() being + invoked contains ORM entities as subjects. For a statement + that does not have ORM entities and instead refers only to + :class:`.Table` metadata, it is invoked as a Core SQL statement + and no ORM-level automation takes place. + + """ + return self._compile_state_cls is not None + + @property + def is_select(self): + """return True if this is a SELECT operation.""" + return isinstance(self.statement, selectable.Select) + + @property + def is_update(self): + """return True if this is an UPDATE operation.""" + return isinstance(self.statement, dml.Update) + + @property + def is_delete(self): + """return True if this is a DELETE operation.""" + return isinstance(self.statement, dml.Delete) + + @property + def _is_crud(self): + return isinstance(self.statement, (dml.Update, dml.Delete)) + + @property def execution_options(self): """Placeholder for execution options. @@ -270,11 +312,31 @@ class ORMExecuteState(util.MemoizedSlots): def load_options(self): """Return the load_options that will be used for this execution.""" + if not self.is_select: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against a SELECT statement " + "so there are no load options." + ) return self._execution_options.get( "_sa_orm_load_options", context.QueryContext.default_load_options ) @property + def update_delete_options(self): + """Return the update_delete_options that will be used for this + execution.""" + + if not self._is_crud: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against an UPDATE or DELETE " + "statement so there are no update options." + ) + return self._execution_options.get( + "_sa_orm_update_options", + persistence.BulkUDCompileState.default_update_options, + ) + + @property def user_defined_options(self): """The sequence of :class:`.UserDefinedOptions` that have been associated with the statement being invoked. @@ -1455,35 +1517,37 @@ class Session(_SessionClassMethods): compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) + else: + compile_state_cls = None - compile_state_cls.orm_pre_session_exec( - self, statement, execution_options, bind_arguments + if compile_state_cls is not None: + execution_options = compile_state_cls.orm_pre_session_exec( + self, statement, params, execution_options, bind_arguments ) - - if self.dispatch.do_orm_execute: - skip_events = bind_arguments.pop("_sa_skip_events", False) - - if not skip_events: - orm_exec_state = ORMExecuteState( - self, - statement, - params, - execution_options, - bind_arguments, - ) - for fn in self.dispatch.do_orm_execute: - result = fn(orm_exec_state) - if result: - return result - else: - compile_state_cls = None bind_arguments.setdefault("clause", statement) if statement._is_future: execution_options = util.immutabledict().merge_with( execution_options, {"future_result": True} ) + if self.dispatch.do_orm_execute: + # run this event whether or not we are in ORM mode + skip_events = bind_arguments.get("_sa_skip_events", False) + if not skip_events: + orm_exec_state = ORMExecuteState( + self, + statement, + params, + execution_options, + bind_arguments, + compile_state_cls, + ) + for fn in self.dispatch.do_orm_execute: + result = fn(orm_exec_state) + if result: + return result + bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind, close_with_result=True) @@ -1601,8 +1665,8 @@ class Session(_SessionClassMethods): self.__binds[insp] = bind elif insp.is_mapper: self.__binds[insp.class_] = bind - for selectable in insp._all_tables: - self.__binds[selectable] = bind + for _selectable in insp._all_tables: + self.__binds[_selectable] = bind else: raise sa_exc.ArgumentError( "Not an acceptable bind target: %s" % key @@ -1664,7 +1728,9 @@ class Session(_SessionClassMethods): """ self._add_bind(table, bind) - def get_bind(self, mapper=None, clause=None, bind=None): + def get_bind( + self, mapper=None, clause=None, bind=None, _sa_skip_events=None + ): """Return a "bind" to which this :class:`.Session` is bound. The "bind" is usually an instance of :class:`_engine.Engine`, diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f14319089..5dd3b519a 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -446,10 +446,14 @@ class CompileState(object): plugin_name = statement._propagate_attrs.get( "compile_state_plugin", "default" ) - else: - plugin_name = "default" + klass = cls.plugins.get( + (plugin_name, statement.__visit_name__), None + ) + if klass is None: + klass = cls.plugins[("default", statement.__visit_name__)] - klass = cls.plugins[(plugin_name, statement.__visit_name__)] + else: + klass = cls.plugins[("default", statement.__visit_name__)] if klass is cls: return cls(statement, compiler, **kw) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index db43e42a6..4c6a0317a 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -755,6 +755,16 @@ class AnonymizedFromClauseImpl(StrictFromClauseImpl): return element.alias(name=name, flat=flat) +class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, **kw): + if "dml_table" in element._annotations: + return element._annotations["dml_table"] + else: + return element + + class DMLSelectImpl(_NoTextCoercion, RoleImpl): __slots__ = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index f4160b552..2519438d1 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3215,6 +3215,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isupdate = True + if not self.compile_state: + self.compile_state = compile_state extra_froms = compile_state._extra_froms is_multitable = bool(extra_froms) @@ -3342,6 +3344,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isdelete = True + if not self.compile_state: + self.compile_state = compile_state extra_froms = compile_state._extra_froms diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 467a764d6..a82641d77 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -19,6 +19,7 @@ from .base import CompileState from .base import DialectKWArgs from .base import Executable from .base import HasCompileState +from .elements import BooleanClauseList from .elements import ClauseElement from .elements import Null from .selectable import HasCTE @@ -150,7 +151,6 @@ class UpdateDMLState(DMLState): def __init__(self, statement, compiler, **kw): self.statement = statement - self.isupdate = True self._preserve_parameter_order = statement._preserve_parameter_order if statement._ordered_values is not None: @@ -447,7 +447,9 @@ class ValuesBase(UpdateBase): _returning = () def __init__(self, table, values, prefixes): - self.table = coercions.expect(roles.FromClauseRole, table) + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) if values is not None: self.values.non_generative(self, values) if prefixes: @@ -949,6 +951,28 @@ class DMLWhereBase(object): coercions.expect(roles.WhereHavingRole, whereclause), ) + def filter(self, *criteria): + """A synonym for the :meth:`_dml.DMLWhereBase.where` method.""" + + return self.where(*criteria) + + @property + def whereclause(self): + """Return the completed WHERE clause for this :class:`.DMLWhereBase` + statement. + + This assembles the current collection of WHERE criteria + into a single :class:`_expression.BooleanClauseList` construct. + + + .. versionadded:: 1.4 + + """ + + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) + class Update(DMLWhereBase, ValuesBase): """Represent an Update construct. @@ -1266,7 +1290,9 @@ class Delete(DMLWhereBase, UpdateBase): """ self._bind = bind - self.table = coercions.expect(roles.FromClauseRole, table) + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) self._returning = returning if prefixes: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 5a55fe5f2..3d94ec9ff 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -184,10 +184,15 @@ class CompoundElementRole(SQLRole): ) +# TODO: are we using this? class DMLRole(StatementRole): pass +class DMLTableRole(FromClauseRole): + _role_name = "subject table for an INSERT, UPDATE or DELETE" + + class DMLColumnRole(SQLRole): _role_name = "SET/VALUES column expression or string key" diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d6845e05f..a95fc561a 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -789,7 +789,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._reset_column_collection() -class Join(FromClause): +class Join(roles.DMLTableRole, FromClause): """represent a ``JOIN`` construct between two :class:`_expression.FromClause` elements. @@ -1406,7 +1406,7 @@ class AliasedReturnsRows(NoInit, FromClause): return self.element.bind -class Alias(AliasedReturnsRows): +class Alias(roles.DMLTableRole, AliasedReturnsRows): """Represents an table or selectable alias (AS). Represents an alias, as typically applied to any table or @@ -1987,7 +1987,7 @@ class FromGrouping(GroupedElement, FromClause): self.element = state["element"] -class TableClause(Immutable, FromClause): +class TableClause(roles.DMLTableRole, Immutable, FromClause): """Represents a minimal "table" construct. This is a lightweight table object that has only a name, a diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 388097e45..68281f33d 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -10,6 +10,7 @@ from .. import util from ..inspection import inspect from ..util import collections_abc from ..util import HasMemoized +from ..util import py37 SKIP_TRAVERSE = util.symbol("skip_traverse") COMPARE_FAILED = False @@ -562,23 +563,38 @@ class _CacheKey(ExtendedInternalTraversal): ) def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + if py37: + # in py37 we can assume two dictionaries created in the same + # insert ordering will retain that sorting + return ( + attrname, + tuple( + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k, + obj[k]._gen_cache_key(anon_map, bindparams), + ) + for k in obj + ), + ) + else: + expr_values = {k for k in obj if hasattr(k, "__clause_element__")} + if expr_values: + # expr values can't be sorted deterministically right now, + # so no cache + anon_map[NO_CACHE] = True + return () - expr_values = {k for k in obj if hasattr(k, "__clause_element__")} - if expr_values: - # expr values can't be sorted deterministically right now, - # so no cache - anon_map[NO_CACHE] = True - return () - - str_values = expr_values.symmetric_difference(obj) + str_values = expr_values.symmetric_difference(obj) - return ( - attrname, - tuple( - (k, obj[k]._gen_cache_key(anon_map, bindparams)) - for k in sorted(str_values) - ), - ) + return ( + attrname, + tuple( + (k, obj[k]._gen_cache_key(anon_map, bindparams)) + for k in sorted(str_values) + ), + ) def visit_dml_multi_values( self, attrname, obj, parent, anon_map, bindparams @@ -1130,6 +1146,18 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for lv, rv in zip(left, right): if not self._compare_dml_values_or_ce(lv, rv, **kw): return COMPARE_FAILED + elif isinstance(right, collections_abc.Sequence): + return COMPARE_FAILED + elif py37: + # dictionaries guaranteed to support insert ordering in + # py37 so that we can compare the keys in order. without + # this, we can't compare SQL expression keys because we don't + # know which key is which + for (lk, lv), (rk, rv) in zip(left.items(), right.items()): + if not self._compare_dml_values_or_ce(lk, rk, **kw): + return COMPARE_FAILED + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED else: for lk in left: lv = left[lk] diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 0ea9f067e..54da06a3d 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -403,10 +403,6 @@ class AssertsCompiledSQL(object): LABEL_STYLE_TABLENAME_PLUS_COL ) clause = compile_state.statement - elif isinstance(clause, orm.persistence.BulkUD): - with mock.patch.object(clause, "_execute_stmt") as stmt_mock: - clause.exec_() - clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: kw["compile_kwargs"] = compile_kwargs diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 55a6cdcf9..273570357 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -65,6 +65,7 @@ from .compat import pickle # noqa from .compat import print_ # noqa from .compat import py2k # noqa from .compat import py36 # noqa +from .compat import py37 # noqa from .compat import py3k # noqa from .compat import quote_plus # noqa from .compat import raise_ # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 247dbc13c..5c46395f9 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -15,6 +15,7 @@ import platform import sys +py37 = sys.version_info >= (3, 7) py36 = sys.version_info >= (3, 6) py3k = sys.version_info >= (3, 0) py2k = sys.version_info < (3, 0) |
