diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-06-21 12:21:21 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-06-23 10:41:39 -0400 |
| commit | 62be25cdfaab377319602a1852a1fddcbf6acd45 (patch) | |
| tree | 803838d317de872ba941264f8ae64e2d4dadc9ae /lib/sqlalchemy | |
| parent | 56e817bb0ef4eaca189b42b930a6e99ee4ed0671 (diff) | |
| download | sqlalchemy-62be25cdfaab377319602a1852a1fddcbf6acd45.tar.gz | |
Propose using RETURNING for bulk updates, deletes
This patch makes several improvements in the area of
bulk updates and deletes as well as the new session mechanics.
RETURNING is now used for an UPDATE or DELETE statement
emitted for a dialect that supports "full returning"
in order to satisfy the "fetch" strategy; this currently
includes PostgreSQL and SQL Server. The Oracle dialect
does not support RETURNING for more than one row,
so a new dialect capability "full_returning" is added
in addition to the existing "implicit_returning", indicating
this dialect supports RETURNING for zero or more rows,
not just a single identity row.
The "fetch" strategy will gracefully degrade to
the previous SELECT mechanics for dialects that do not
support RETURNING.
Additionally, the "fetch" strategy will attempt to use
evaluation for the VALUES that were UPDATEd, rather
than just expiring the updated attributes. Values should
be evaluatable in all cases where the value is not
a SQL expression.
The new approach also incurs some changes in the
session.execute mechanics, where do_orm_execute() event
handlers can now be chained to each return results;
this is in turn used by the handler to detect on a
per-bind basis if the fetch strategy needs to
do a SELECT or if it can do RETURNING. A test suite is
added to test_horizontal_shard that breaks up a single
UPDATE or DELETE operation among multiple backends
where some are SQLite and don't support RETURNING and
others are PostgreSQL and do.
The session event mechanics are corrected
in terms of the "orm pre execute" hook, which now
receives a flag "is_reentrant" so that the two
ORM implementations for this can skip their work
if they are being called inside of ORMExecuteState.invoke(),
where previously bulk update/delete were calling its
SELECT a second time.
In order for "fetch" to get the correct identity when
called as pre-execute, it also requests the identity_token
for each mapped instance which is now added as an optional
capability of a SELECT for ORM columns. the identity_token
that's placed by horizontal_sharding is now made available
within each result row, so that even when fetching a
merged result of plain rows we can tell which row belongs
to which identity token.
The evaluator that takes place within the ORM bulk update and delete for
synchronize_session="evaluate" now supports the IN and NOT IN operators.
Tuple IN is also supported.
Fixes: #1653
Change-Id: I2292b56ae004b997cef0ba4d3fc350ae1dd5efc1
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 41 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 250 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 22 |
11 files changed, 360 insertions, 101 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 5aaecf23a..4b211bde7 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2402,6 +2402,9 @@ class MSDialect(default.DefaultDialect): max_identifier_length = 128 schema_name = "dbo" + implicit_returning = True + full_returning = True + colspecs = { sqltypes.DateTime: _MSDateTime, sqltypes.Date: _MSDate, @@ -2567,11 +2570,10 @@ class MSDialect(default.DefaultDialect): "features may not function properly." % ".".join(str(x) for x in self.server_version_info) ) - if ( - self.server_version_info >= MS_2005_VERSION - and "implicit_returning" not in self.__dict__ - ): - self.implicit_returning = True + + if self.server_version_info < MS_2005_VERSION: + self.implicit_returning = self.full_returning = False + if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True if self.deprecate_large_types is None: diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index f3e775354..c2d9af4d2 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2510,6 +2510,9 @@ class PGDialect(default.DefaultDialect): inspector = PGInspector isolation_level = None + implicit_returning = True + full_returning = True + construct_arguments = [ ( schema.Index, @@ -2555,10 +2558,10 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - self.implicit_returning = self.server_version_info > ( - 8, - 2, - ) and self.__dict__.get("implicit_returning", True) + + if self.server_version_info <= (8, 2): + self.full_returning = self.implicit_returning = False + self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: self.colspecs = self.colspecs.copy() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py 
index 4d516e97c..1a8dbb4cd 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -67,6 +67,7 @@ class DefaultDialect(interfaces.Dialect): preexecute_autoincrement_sequences = False postfetch_lastrowid = True implicit_returning = False + full_returning = False cte_follows_insert = False diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index b29bc22d4..ead52a3f8 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -1259,6 +1259,10 @@ class IteratorResult(Result): return list(itertools.islice(self.iterator, 0, size)) +def null_result(): + return IteratorResult(SimpleResultMetaData([]), iter([])) + + class ChunkedIteratorResult(IteratorResult): """An :class:`.IteratorResult` that works from an iterator-producing callable. diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 0983807cb..9d7266d1a 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -220,7 +220,6 @@ def execute_and_instances(orm_context): update_options = active_options = orm_context.update_delete_options session = orm_context.session - # orm_query = orm_context.orm_query def iter_for_shard(shard_id, load_options, update_options): execution_options = dict(orm_context.local_execution_options) diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index f380229e1..77237f089 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -193,8 +193,17 @@ class ORMCompileState(CompileState): @classmethod def orm_pre_session_exec( - cls, session, statement, params, execution_options, bind_arguments + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_reentrant_invoke, ): + if is_reentrant_invoke: + return statement, execution_options + load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options ) @@ -220,7 +229,7 @@ 
class ORMCompileState(CompileState): if load_options._autoflush: session._autoflush() - return execution_options + return statement, execution_options @classmethod def orm_setup_cursor_result( @@ -2259,9 +2268,20 @@ class _ColumnEntity(_QueryEntity): ) if _entity: - _ORMColumnEntity( - compile_state, column, _entity, parent_bundle=parent_bundle - ) + if "identity_token" in column._annotations: + _IdentityTokenEntity( + compile_state, + column, + _entity, + parent_bundle=parent_bundle, + ) + else: + _ORMColumnEntity( + compile_state, + column, + _entity, + parent_bundle=parent_bundle, + ) else: _RawColumnEntity( compile_state, column, parent_bundle=parent_bundle @@ -2462,3 +2482,14 @@ class _ORMColumnEntity(_ColumnEntity): compile_state.primary_columns.append(column) self._fetch_column = column + + +class _IdentityTokenEntity(_ORMColumnEntity): + def setup_compile_state(self, compile_state): + pass + + def row_processor(self, context, result): + def getter(row): + return context.load_options._refresh_identity_token + + return getter, self._label_name, self._extra_entities diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 51bc8e426..caa9ffe10 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -35,6 +35,10 @@ _straight_ops = set( ) ) +_extended_ops = { + operators.in_op: (lambda a, b: a in b), + operators.notin_op: (lambda a, b: a not in b), +} _notimplemented_ops = set( getattr(operators, op) @@ -43,9 +47,8 @@ _notimplemented_ops = set( "notlike_op", "ilike_op", "notilike_op", + "startswith_op", "between_op", - "in_op", - "notin_op", "endswith_op", "concat_op", ) @@ -136,6 +139,17 @@ class EvaluatorCompiler(object): return False return True + elif clause.operator is operators.comma_op: + + def evaluate(obj): + values = [] + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is None: + return None + values.append(value) + return tuple(values) + else: raise UnevaluatableError( 
"Cannot evaluate clauselist with operator %s" % clause.operator @@ -158,6 +172,16 @@ class EvaluatorCompiler(object): def evaluate(obj): return eval_left(obj) != eval_right(obj) + elif operator in _extended_ops: + + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is None or right_val is None: + return None + + return _extended_ops[operator](left_val, right_val) + elif operator in _straight_ops: def evaluate(obj): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index bec6da74d..ef0e9a49b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2240,7 +2240,6 @@ class Mapper( "entity_namespace": self, "parententity": self, "parentmapper": self, - "compile_state_plugin": "orm", } if self.persist_selectable is not self.local_table: # joined table inheritance, with polymorphic selectable, @@ -2250,7 +2249,6 @@ class Mapper( "entity_namespace": self, "parententity": self, "parentmapper": self, - "compile_state_plugin": "orm", } )._set_propagate_attrs( {"compile_state_plugin": "orm", "plugin_subject": self} @@ -2260,6 +2258,23 @@ class Mapper( {"compile_state_plugin": "orm", "plugin_subject": self} ) + @util.memoized_property + def select_identity_token(self): + return ( + expression.null() + ._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "identity_token": True, + } + ) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) + @property def selectable(self): """The :func:`_expression.select` construct this diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 8393eaf74..bd8efe77f 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -28,6 +28,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. 
import util +from ..engine import result as _result from ..future import select as future_select from ..sql import coercions from ..sql import expression @@ -1672,8 +1673,17 @@ class BulkUDCompileState(CompileState): @classmethod def orm_pre_session_exec( - cls, session, statement, params, execution_options, bind_arguments + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_reentrant_invoke, ): + if is_reentrant_invoke: + return statement, execution_options + sync = execution_options.get("synchronize_session", None) if sync is None: sync = statement._execution_options.get( @@ -1706,6 +1716,17 @@ class BulkUDCompileState(CompileState): if update_options._autoflush: session._autoflush() + statement = statement._annotate( + {"synchronize_session": update_options._synchronize_session} + ) + + # this stage of the execution is called before the do_orm_execute event + # hook. meaning for an extension like horizontal sharding, this step + # happens before the extension splits out into multiple backends and + # runs only once. if we do pre_sync_fetch, we execute a SELECT + # statement, which the horizontal sharding extension splits amongst the + # shards and combines the results together. + if update_options._synchronize_session == "evaluate": update_options = cls._do_pre_synchronize_evaluate( session, @@ -1725,19 +1746,31 @@ class BulkUDCompileState(CompileState): update_options, ) - return util.immutabledict(execution_options).union( - dict(_sa_orm_update_options=update_options) + return ( + statement, + util.immutabledict(execution_options).union( + dict(_sa_orm_update_options=update_options) + ), ) @classmethod def orm_setup_cursor_result( cls, session, statement, execution_options, bind_arguments, result ): + + # this stage of the execution is called after the + # do_orm_execute event hook. 
meaning for an extension like + # horizontal sharding, this step happens *within* the horizontal + # sharding event handler which calls session.execute() re-entrantly + # and will occur for each backend individually. + # the sharding extension then returns its own merged result from the + # individual ones we return here. + update_options = execution_options["_sa_orm_update_options"] if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, update_options) + cls._do_post_synchronize_evaluate(session, result, update_options) elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, update_options) + cls._do_post_synchronize_fetch(session, result, update_options) return result @@ -1767,18 +1800,6 @@ class BulkUDCompileState(CompileState): def eval_condition(obj): return True - # TODO: something more robust for this conditional - if statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values(mapper, statement) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - value_evaluators[key] = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError as err: util.raise_( sa_exc.InvalidRequestError( @@ -1789,13 +1810,35 @@ class BulkUDCompileState(CompileState): from_=err, ) - # TODO: detect when the where clause is a trivial primary key match + if statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values(mapper, statement) + value_evaluators = {} + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = 
_evaluator + + # TODO: detect when the where clause is a trivial primary key match. matched_objects = [ obj for (cls, pk, identity_token,), obj in session.identity_map.items() if issubclass(cls, target_cls) and eval_condition(obj) - and identity_token == update_options._refresh_identity_token + and ( + update_options._refresh_identity_token is None + # TODO: coverage for the case where horiziontal sharding + # invokes an update() or delete() given an explicit identity + # token up front + or identity_token == update_options._refresh_identity_token + ) ] return update_options + { "_matched_objects": matched_objects, @@ -1868,29 +1911,56 @@ class BulkUDCompileState(CompileState): ): mapper = update_options._subject_mapper - if mapper: - primary_table = mapper.local_table - else: - primary_table = statement._raw_columns[0] - - # note this creates a Select() *without* the ORM plugin. - # we don't want that here. - select_stmt = future_select(*primary_table.primary_key) + select_stmt = future_select( + *(mapper.primary_key + (mapper.select_identity_token,)) + ) select_stmt._where_criteria = statement._where_criteria - matched_rows = session.execute( - select_stmt, params, execution_options, bind_arguments - ).fetchall() + def skip_for_full_returning(orm_context): + bind = orm_context.session.get_bind(**orm_context.bind_arguments) + if bind.dialect.full_returning: + return _result.null_result() + else: + return None + + result = session.execute( + select_stmt, + params, + execution_options, + bind_arguments, + _add_event=skip_for_full_returning, + ) + matched_rows = result.fetchall() + + value_evaluators = _EMPTY_DICT if statement.__visit_name__ == "update": + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) resolved_values = cls._get_resolved_values(mapper, statement) resolved_keys_as_propnames = cls._resolved_keys_as_propnames( mapper, resolved_values ) + + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, 
resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + else: resolved_keys_as_propnames = _EMPTY_DICT return update_options + { + "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, "_resolved_keys_as_propnames": resolved_keys_as_propnames, } @@ -1925,15 +1995,23 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): elif statement._values: new_stmt._values = self._resolved_values + if ( + statement._annotations.get("synchronize_session", None) == "fetch" + and compiler.dialect.full_returning + ): + new_stmt = new_stmt.returning(*mapper.primary_key) + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) return self @classmethod - def _do_post_synchronize_evaluate(cls, session, update_options): + def _do_post_synchronize_evaluate(cls, session, result, update_options): states = set() evaluated_keys = list(update_options._value_evaluators.keys()) + values = update_options._resolved_keys_as_propnames + attrib = set(k for k, v in values) for obj in update_options._matched_objects: state, dict_ = ( @@ -1941,9 +2019,15 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): attributes.instance_dict(obj), ) - assert ( - state.identity_token == update_options._refresh_identity_token - ) + # the evaluated states were gathered across all identity tokens. + # however the post_sync events are called per identity token, + # so filter. 
+ if ( + update_options._refresh_identity_token is not None + and state.identity_token + != update_options._refresh_identity_token + ): + continue # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) @@ -1954,38 +2038,64 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): state._commit(dict_, list(to_evaluate)) - # expire attributes with pending changes - # (there was no autoflush, so they are overwritten) - state._expire_attributes( - dict_, set(evaluated_keys).difference(to_evaluate) - ) + to_expire = attrib.intersection(dict_).difference(to_evaluate) + if to_expire: + state._expire_attributes(dict_, to_expire) + states.add(state) session._register_altered(states) @classmethod - def _do_post_synchronize_fetch(cls, session, update_options): + def _do_post_synchronize_fetch(cls, session, result, update_options): target_mapper = update_options._subject_mapper - states = set( - [ - attributes.instance_state(session.identity_map[identity_key]) - for identity_key in [ - target_mapper.identity_key_from_primary_key( - list(primary_key), - identity_token=update_options._refresh_identity_token, - ) - for primary_key in update_options._matched_rows + states = set() + evaluated_keys = list(update_options._value_evaluators.keys()) + + if result.returns_rows: + matched_rows = [ + tuple(row) + (update_options._refresh_identity_token,) + for row in result.all() + ] + else: + matched_rows = update_options._matched_rows + + objs = [ + session.identity_map[identity_key] + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key), identity_token=identity_token, + ) + for primary_key, identity_token in [ + (row[0:-1], row[-1]) for row in matched_rows ] - if identity_key in session.identity_map + if update_options._refresh_identity_token is None + or identity_token == update_options._refresh_identity_token ] - ) + if identity_key in session.identity_map + ] values = 
update_options._resolved_keys_as_propnames attrib = set(k for k, v in values) - for state in states: - to_expire = attrib.intersection(state.dict) + + for obj in objs: + state, dict_ = ( + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + + to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: + dict_[key] = update_options._value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) + + state._commit(dict_, list(to_evaluate)) + + to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: - session._expire_state(state, to_expire) + state._expire_attributes(dict_, to_expire) + + states.add(state) session._register_altered(states) @@ -1995,14 +2105,24 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) - self.mapper = statement.table._annotations.get("parentmapper", None) + self.mapper = mapper = statement.table._annotations.get( + "parentmapper", None + ) + + if ( + mapper + and statement._annotations.get("synchronize_session", None) + == "fetch" + and compiler.dialect.full_returning + ): + statement = statement.returning(*mapper.primary_key) DeleteDMLState.__init__(self, statement, compiler, **kw) return self @classmethod - def _do_post_synchronize_evaluate(cls, session, update_options): + def _do_post_synchronize_evaluate(cls, session, result, update_options): session._remove_newly_deleted( [ @@ -2012,15 +2132,25 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): ) @classmethod - def _do_post_synchronize_fetch(cls, session, update_options): + def _do_post_synchronize_fetch(cls, session, result, update_options): target_mapper = update_options._subject_mapper - for primary_key in update_options._matched_rows: + if result.returns_rows: + matched_rows = [ + tuple(row) + (update_options._refresh_identity_token,) + for row in result.all() + ] + else: + matched_rows = 
update_options._matched_rows + + for row in matched_rows: + primary_key = row[0:-1] + identity_token = row[-1] + # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key), - identity_token=update_options._refresh_identity_token, + list(primary_key), identity_token=identity_token, ) if identity_key in session.identity_map: session._remove_newly_deleted( diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5ad8bcf2f..a398da793 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -116,6 +116,8 @@ class ORMExecuteState(util.MemoizedSlots): "_merged_execution_options", "bind_arguments", "_compile_state_cls", + "_starting_event_idx", + "_events_todo", ) def __init__( @@ -126,6 +128,7 @@ class ORMExecuteState(util.MemoizedSlots): execution_options, bind_arguments, compile_state_cls, + events_todo, ): self.session = session self.statement = statement @@ -133,6 +136,10 @@ class ORMExecuteState(util.MemoizedSlots): self._execution_options = execution_options self.bind_arguments = bind_arguments self._compile_state_cls = compile_state_cls + self._events_todo = list(events_todo) + + def _remaining_events(self): + return self._events_todo[self._starting_event_idx + 1 :] def invoke_statement( self, @@ -200,7 +207,11 @@ class ORMExecuteState(util.MemoizedSlots): _execution_options = self._execution_options return self.session.execute( - statement, _params, _execution_options, _bind_arguments + statement, + _params, + _execution_options, + _bind_arguments, + _parent_execute_state=self, ) @property @@ -1376,6 +1387,8 @@ class Session(_SessionClassMethods): params=None, execution_options=util.immutabledict(), bind_arguments=None, + _parent_execute_state=None, + _add_event=None, **kw ): r"""Execute a SQL expression construct or string statement within @@ -1521,8 +1534,16 @@ class Session(_SessionClassMethods): compile_state_cls = None if 
compile_state_cls is not None: - execution_options = compile_state_cls.orm_pre_session_exec( - self, statement, params, execution_options, bind_arguments + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + _parent_execute_state is not None, ) else: bind_arguments.setdefault("clause", statement) @@ -1531,22 +1552,28 @@ class Session(_SessionClassMethods): execution_options, {"future_result": True} ) - if self.dispatch.do_orm_execute: - # run this event whether or not we are in ORM mode - skip_events = bind_arguments.get("_sa_skip_events", False) - if not skip_events: - orm_exec_state = ORMExecuteState( - self, - statement, - params, - execution_options, - bind_arguments, - compile_state_cls, - ) - for fn in self.dispatch.do_orm_execute: - result = fn(orm_exec_state) - if result: - return result + if _parent_execute_state: + events_todo = _parent_execute_state._remaining_events() + else: + events_todo = self.dispatch.do_orm_execute + if _add_event: + events_todo = list(events_todo) + [_add_event] + + if events_todo: + orm_exec_state = ORMExecuteState( + self, + statement, + params, + execution_options, + bind_arguments, + compile_state_cls, + events_todo, + ) + for idx, fn in enumerate(events_todo): + orm_exec_state._starting_event_idx = idx + result = fn(orm_exec_state) + if result: + return result bind = self.get_bind(**bind_arguments) @@ -1729,7 +1756,12 @@ class Session(_SessionClassMethods): self._add_bind(table, bind) def get_bind( - self, mapper=None, clause=None, bind=None, _sa_skip_events=None + self, + mapper=None, + clause=None, + bind=None, + _sa_skip_events=None, + _sa_skip_for_implicit_returning=False, ): """Return a "bind" to which this :class:`.Session` is bound. 
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 2d51e7c9b..163276ca9 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -312,12 +312,30 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def full_returning(self): + """target platform supports RETURNING completely, including + multiple rows returned. + + """ + + return exclusions.only_if( + lambda config: config.db.dialect.full_returning, + "%(database)s %(does_support)s 'RETURNING of multiple rows'", + ) + + @property def returning(self): - """target platform supports RETURNING.""" + """target platform supports RETURNING for at least one row. + + .. seealso:: + + :attr:`.Requirements.full_returning` + + """ return exclusions.only_if( lambda config: config.db.dialect.implicit_returning, - "%(database)s %(does_support)s 'returning'", + "%(database)s %(does_support)s 'RETURNING of a single row'", ) @property |
