diff options
Diffstat (limited to 'lib/sqlalchemy/orm/bulk_persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/bulk_persistence.py | 1459 |
1 files changed, 1141 insertions, 318 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 225292d17..3ed34a57a 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -15,24 +15,32 @@ specifically outside of the flush() process. from __future__ import annotations from typing import Any +from typing import cast from typing import Dict from typing import Iterable +from typing import Optional +from typing import overload from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes +from . import context from . import evaluator from . import exc as orm_exc +from . import loading from . import persistence from .base import NO_VALUE from .context import AbstractORMCompileState +from .context import FromStatement +from .context import ORMFromStatementCompileState +from .context import QueryContext from .. import exc as sa_exc -from .. import sql from .. import util from ..engine import Dialect from ..engine import result as _result from ..sql import coercions +from ..sql import dml from ..sql import expression from ..sql import roles from ..sql import select @@ -48,16 +56,24 @@ from ..util.typing import Literal if TYPE_CHECKING: from .mapper import Mapper + from .session import _BindArguments from .session import ORMExecuteState + from .session import Session from .session import SessionTransaction from .state import InstanceState + from ..engine import Connection + from ..engine import cursor + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter _O = TypeVar("_O", bound=object) -_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"] +_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"] +_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"] +@overload def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], @@ -65,7 
+81,36 @@ def _bulk_insert( isstates: bool, return_defaults: bool, render_nulls: bool, + use_orm_insert_stmt: Literal[None] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., ) -> None: + ... + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., +) -> cursor.CursorResult[Any]: + ... + + +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, +) -> Optional[cursor.CursorResult[Any]]: base_mapper = mapper.base_mapper if session_transaction.session.connection_callable: @@ -81,13 +126,27 @@ def _bulk_insert( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) + + return_result: Optional[cursor.CursorResult[Any]] = None + for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue + is_joined_inh_supertable = super_mapper is not mapper + bookkeeping = ( + is_joined_inh_supertable + or return_defaults + or ( + use_orm_insert_stmt is not None + and bool(use_orm_insert_stmt._returning) + ) + ) + records = ( ( None, @@ -112,18 +171,25 @@ def _bulk_insert( table, ((None, mapping, mapper, connection) for mapping in mappings), bulk=True, - return_defaults=return_defaults, + return_defaults=bookkeeping, 
render_nulls=render_nulls, ) ) - persistence._emit_insert_statements( + result = persistence._emit_insert_statements( base_mapper, None, super_mapper, table, records, - bookkeeping=return_defaults, + bookkeeping=bookkeeping, + use_orm_insert_stmt=use_orm_insert_stmt, + execution_options=execution_options, ) + if use_orm_insert_stmt is not None: + if not use_orm_insert_stmt._returning or return_result is None: + return_result = result + elif result.returns_rows: + return_result = return_result.splice_horizontally(result) if return_defaults and isstates: identity_cls = mapper._identity_class @@ -134,14 +200,43 @@ def _bulk_insert( tuple([dict_[key] for key in identity_props]), ) + if use_orm_insert_stmt is not None: + assert return_result is not None + return return_result + +@overload def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, update_changed_only: bool, + use_orm_update_stmt: Literal[None] = ..., ) -> None: + ... + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = ..., +) -> _result.Result[Any]: + ... 
+ + +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = None, +) -> Optional[_result.Result[Any]]: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys @@ -161,7 +256,8 @@ def _bulk_update( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) if session_transaction.session.connection_callable: raise NotImplementedError( @@ -172,7 +268,7 @@ def _bulk_update( connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue records = persistence._collect_update_commands( @@ -193,8 +289,8 @@ def _bulk_update( for mapping in mappings ), bulk=True, + use_orm_update_stmt=use_orm_update_stmt, ) - persistence._emit_update_statements( base_mapper, None, @@ -202,10 +298,125 @@ def _bulk_update( table, records, bookkeeping=False, + use_orm_update_stmt=use_orm_update_stmt, ) + if use_orm_update_stmt is not None: + return _result.null_result() + + +def _expand_composites(mapper, mappings): + composite_attrs = mapper.composites + if not composite_attrs: + return + + composite_keys = set(composite_attrs.keys()) + populators = { + key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() + for key in composite_keys + } + for mapping in mappings: + for key in composite_keys.intersection(mapping): + populators[key](mapping) + class ORMDMLState(AbstractORMCompileState): + is_dml_returning = True + from_statement_ctx: Optional[ORMFromStatementCompileState] = None + + @classmethod + def _get_orm_crud_kv_pairs( + cls, mapper, statement, kv_iterator, needs_to_be_cacheable + ): + 
+ core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, str): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + yield ( + coercions.expect(roles.DMLColumnRole, k), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v, + ) + else: + yield from core_get_crud_kv_pairs( + statement, + desc._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + k_anno["entity_namespace"], k_anno["proxy_key"] + ) + yield from core_get_crud_kv_pairs( + statement, + attr._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + else: + yield ( + k, + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + + @classmethod + def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_multi_crud_kv_pairs( + statement, kv_iterator + ) + + return [ + dict( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, statement, value_dict.items(), False + ) + ) + for value_dict in kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): + assert ( + needs_to_be_cacheable + ), "no test coverage for needs_to_be_cacheable=False" + + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_crud_kv_pairs( + statement, kv_iterator, needs_to_be_cacheable + ) + + return list( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, + statement, + kv_iterator, + needs_to_be_cacheable, + ) + ) + @classmethod def 
get_entity_description(cls, statement): ext_info = statement.table._annotations["parententity"] @@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState): ] ] + def _setup_orm_returning( + self, + compiler, + orm_level_statement, + dml_level_statement, + use_supplemental_cols=True, + dml_mapper=None, + ): + """establish ORM column handlers for an INSERT, UPDATE, or DELETE + which uses explicit returning(). + + called within compilation level create_for_statement. + + The _return_orm_returning() method then receives the Result + after the statement was executed, and applies ORM loading to the + state that we first established here. + + """ + + if orm_level_statement._returning: + + fs = FromStatement( + orm_level_statement._returning, dml_level_statement + ) + fs = fs.options(*orm_level_statement._with_options) + self.select_statement = fs + self.from_statement_ctx = ( + fsc + ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + fsc.setup_dml_returning_compile_state(dml_mapper) + + dml_level_statement = dml_level_statement._generate() + dml_level_statement._returning = () + + cols_to_return = [c for c in fsc.primary_columns if c is not None] + + # since we are splicing result sets together, make sure there + # are columns of some kind returned in each result set + if not cols_to_return: + cols_to_return.extend(dml_mapper.primary_key) + + if use_supplemental_cols: + dml_level_statement = dml_level_statement.return_defaults( + supplemental_cols=cols_to_return + ) + else: + dml_level_statement = dml_level_statement.returning( + *cols_to_return + ) + + return dml_level_statement + + @classmethod + def _return_orm_returning( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + if compile_state.from_statement_ctx: + load_options = execution_options.get( + "_sa_orm_load_options", 
QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state.from_statement_ctx, + compile_state.select_statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + else: + return result + class BulkUDCompileState(ORMDMLState): class default_update_options(Options): - _synchronize_session: _SynchronizeSessionArgument = "evaluate" - _is_delete_using = False - _is_update_from = False - _autoflush = True - _subject_mapper = None + _dml_strategy: _DMLStrategyArgument = "auto" + _synchronize_session: _SynchronizeSessionArgument = "auto" + _can_use_returning: bool = False + _is_delete_using: bool = False + _is_update_from: bool = False + _autoflush: bool = True + _subject_mapper: Optional[Mapper[Any]] = None _resolved_values = EMPTY_DICT - _resolved_keys_as_propnames = EMPTY_DICT - _value_evaluators = EMPTY_DICT - _matched_objects = None + _eval_condition = None _matched_rows = None _refresh_identity_token = None @@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState): execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", - {"synchronize_session", "is_delete_using", "is_update_from"}, + { + "synchronize_session", + "is_delete_using", + "is_update_from", + "dml_strategy", + }, execution_options, statement._execution_options, ) - sync = update_options._synchronize_session - if sync is not None: - if sync not in ("evaluate", "fetch", False): - raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are 'evaluate', 'fetch', False" - ) - bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState): update_options += {"_subject_mapper": plugin_subject.mapper} + if not isinstance(params, list): + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "orm"} + 
elif update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "bulk"} + elif update_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + sync = update_options._synchronize_session + if sync is not None: + if sync not in ("auto", "evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'auto', 'evaluate', 'fetch', False" + ) + if update_options._dml_strategy == "bulk" and sync == "fetch": + raise sa_exc.InvalidRequestError( + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates (i.e. multiple parameter sets)" + ) + if update_options._autoflush: session._autoflush() + if update_options._dml_strategy == "orm": + + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. 
statement = statement._annotate( { "synchronize_session": update_options._synchronize_session, "is_delete_using": update_options._is_delete_using, "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, } ) - # this stage of the execution is called before the do_orm_execute event - # hook. meaning for an extension like horizontal sharding, this step - # happens before the extension splits out into multiple backends and - # runs only once. if we do pre_sync_fetch, we execute a SELECT - # statement, which the horizontal sharding extension splits amongst the - # shards and combines the results together. - - if update_options._synchronize_session == "evaluate": - update_options = cls._do_pre_synchronize_evaluate( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "fetch": - update_options = cls._do_pre_synchronize_fetch( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - return ( statement, util.immutabledict(execution_options).union( @@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState): # individual ones we return here. 
update_options = execution_options["_sa_orm_update_options"] - if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, result, update_options) - elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, result, update_options) + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate( + session, statement, result, update_options + ) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch( + session, statement, result, update_options + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_bulk_evaluate( + session, params, result, update_options + ) + return result - return result + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) @classmethod def _adjust_for_extra_criteria(cls, global_attributes, ext_info): @@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState): primary_key_convert = [ lookup[bpk] for bpk in mapper.base_mapper.primary_key ] - return [tuple(row[idx] for idx in primary_key_convert) for row in rows] @classmethod - def _do_pre_synchronize_evaluate( + def _get_matched_objects_on_criteria(cls, update_options, states): + mapper = update_options._subject_mapper + eval_condition = update_options._eval_condition + + raw_data = [ + (state.obj(), state, state.dict) + for state in states + if state.mapper.isa(mapper) and not state.expired + ] + + identity_token = update_options._refresh_identity_token + if identity_token is not None: + raw_data = [ + (obj, state, dict_) + for obj, state, dict_ in raw_data + if state.identity_token == identity_token + ] + + result = [] + for obj, state, dict_ in raw_data: + evaled_condition = eval_condition(obj) + + # caution: don't use "in ()" or == here, _EXPIRED_OBJECT + # evaluates as
True for all comparisons + if ( + evaled_condition is True + or evaled_condition is evaluator._EXPIRED_OBJECT + ): + result.append( + ( + obj, + state, + dict_, + evaled_condition is evaluator._EXPIRED_OBJECT, + ) + ) + return result + + @classmethod + def _eval_condition_from_statement(cls, update_options, statement): + mapper = update_options._subject_mapper + target_cls = mapper.class_ + + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + crit = () + if statement._where_criteria: + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria(global_attributes, mapper) + + if crit: + eval_condition = evaluator_compiler.process(*crit) + else: + + def eval_condition(obj): + return True + + return eval_condition + + @classmethod + def _do_pre_synchronize_auto( cls, session, statement, @@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState): bind_arguments, update_options, ): - mapper = update_options._subject_mapper - target_cls = mapper.class_ + """setup auto sync strategy + + + "auto" checks if we can use "evaluate" first, then falls back + to "fetch" + + evaluate is vastly more efficient for the common case + where session is empty, only has a few objects, and the UPDATE + statement can potentially match thousands/millions of rows. - value_evaluators = resolved_keys_as_propnames = EMPTY_DICT + OTOH more complex criteria that fails to work with "evaluate" + we would hope usually correlates with fewer net rows. 
+ + """ try: - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - crit = () - if statement._where_criteria: - crit += statement._where_criteria + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) - global_attributes = {} - for opt in statement._with_options: - if opt._is_criteria_option: - opt.get_global_criteria(global_attributes) + except evaluator.UnevaluatableError: + pass + else: + return update_options + { + "_eval_condition": eval_condition, + "_synchronize_session": "evaluate", + } - if global_attributes: - crit += cls._adjust_for_extra_criteria( - global_attributes, mapper - ) + update_options += {"_synchronize_session": "fetch"} + return cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) - if crit: - eval_condition = evaluator_compiler.process(*crit) - else: + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): - def eval_condition(obj): - return True + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( @@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState): "synchronize_session execution option." 
% err ) from err - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - # TODO: detect when the where clause is a trivial primary key match. - matched_objects = [ - state.obj() - for state in session.identity_map.all_states() - if state.mapper.isa(mapper) - and not state.expired - and eval_condition(state.obj()) - and ( - update_options._refresh_identity_token is None - # TODO: coverage for the case where horizontal sharding - # invokes an update() or delete() given an explicit identity - # token up front - or state.identity_token - == update_options._refresh_identity_token - ) - ] return update_options + { - "_matched_objects": matched_objects, - "_value_evaluators": value_evaluators, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_eval_condition": eval_condition, } @classmethod @@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState): def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] for k, v in resolved_values: - if isinstance(k, attributes.QueryableAttribute): - values.append((k.key, v)) - continue - elif hasattr(k, "__clause_element__"): - k = k.__clause_element__() - if mapper and isinstance(k, expression.ColumnElement): try: attr = mapper._columntoproperty[k] @@ -599,7 +992,8 @@ class 
BulkUDCompileState(ORMDMLState): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k + "Attribute name not found, can't be " + "synchronized back to objects: %r" % k ) return values @@ -622,14 +1016,43 @@ class BulkUDCompileState(ORMDMLState): ) select_stmt._where_criteria = statement._where_criteria + # conditionally run the SELECT statement for pre-fetch, testing the + # "bind" for if we can use RETURNING or not using the do_orm_execute + # event. If RETURNING is available, the do_orm_execute event + # will cancel the SELECT from being actually run. + # + # The way this is organized seems strange, why don't we just + # call can_use_returning() before invoking the statement and get + # answer?, why does this go through the whole execute phase using an + # event? Answer: because we are integrating with extensions such + # as the horizontal sharding extension that "multiplexes" an individual + # statement run through multiple engines, and it uses + # do_orm_execute() to do that.
+ + can_use_returning = None + def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - if cls.can_use_returning( + nonlocal can_use_returning + + per_bind_result = cls.can_use_returning( bind.dialect, mapper, is_update_from=update_options._is_update_from, is_delete_using=update_options._is_delete_using, - ): + ) + + if can_use_returning is not None: + if can_use_returning != per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't mix multiple " + "backends where some support RETURNING and others " + "don't" + ) + else: + can_use_returning = per_bind_result + + if per_bind_result: return _result.null_result() else: return None @@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState): ) matched_rows = result.fetchall() - value_evaluators = EMPTY_DICT - - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - target_cls = mapper.class_ - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - value_evaluators = {} - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - else: - resolved_keys_as_propnames = EMPTY_DICT - return update_options + { - "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, - 
"_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_can_use_returning": can_use_returning, } @CompileState.plugin_for("orm", "insert") -class ORMInsert(ORMDMLState, InsertDMLState): +class BulkORMInsert(ORMDMLState, InsertDMLState): + class default_insert_options(Options): + _dml_strategy: _DMLStrategyArgument = "auto" + _render_nulls: bool = False + _return_defaults: bool = False + _subject_mapper: Optional[Mapper[Any]] = None + + select_statement: Optional[FromStatement] = None + @classmethod def orm_pre_session_exec( cls, @@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState): bind_arguments, is_reentrant_invoke, ): + + ( + insert_options, + execution_options, + ) = BulkORMInsert.default_insert_options.from_execution_options( + "_sa_orm_insert_options", + {"dml_strategy"}, + execution_options, + statement._execution_options, + ) bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState): else: bind_arguments["mapper"] = plugin_subject.mapper + insert_options += {"_subject_mapper": plugin_subject.mapper} + + if not params: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "orm"} + elif insert_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "bulk"} + elif insert_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + if insert_options._dml_strategy != "raw": + # for ORM object loading, like ORMContext, we have to disable + # result set adapt_to_context, because we will be generating a + # new statement with specific columns that's cached inside of + # an ORMFromStatementCompileState, which we will re-use 
for + # each result. + if not execution_options: + execution_options = context._orm_load_exec_options + else: + execution_options = execution_options.union( + context._orm_load_exec_options + ) + + statement = statement._annotate( + {"dml_strategy": insert_options._dml_strategy} + ) + return ( statement, - util.immutabledict(execution_options), + util.immutabledict(execution_options).union( + {"_sa_orm_insert_options": insert_options} + ), ) @classmethod - def orm_setup_cursor_result( + def orm_execute_statement( cls, - session, - statement, - params, - execution_options, - bind_arguments, - result, - ): - return result + session: Session, + statement: dml.Insert, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + insert_options = execution_options.get( + "_sa_orm_insert_options", cls.default_insert_options + ) + + if insert_options._dml_strategy not in ( + "raw", + "bulk", + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM insert strategy " + "are 'raw', 'orm', 'bulk', 'auto" + ) + + result: _result.Result[Any] + + if insert_options._dml_strategy == "raw": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return result + + if insert_options._dml_strategy == "bulk": + mapper = insert_options._subject_mapper + + if ( + statement._post_values_clause is not None + and mapper._multiple_persistence_tables + ): + raise sa_exc.InvalidRequestError( + "bulk INSERT with a 'post values' clause " + "(typically upsert) not supported for multi-table " + f"mapper {mapper}" + ) + + assert mapper is not None + assert session._transaction is not None + result = _bulk_insert( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + return_defaults=insert_options._return_defaults, + 
render_nulls=insert_options._render_nulls, + use_orm_insert_stmt=statement, + execution_options=execution_options, + ) + elif insert_options._dml_strategy == "orm": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + else: + raise AssertionError() + + if not bool(statement._returning): + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: + + self = cast( + BulkORMInsert, + super().create_for_statement(statement, compiler, **kw), + ) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + if not toplevel: + return self + + mapper = statement._propagate_attrs["plugin_subject"] + dml_strategy = statement._annotations.get("dml_strategy", "raw") + if dml_strategy == "bulk": + self._setup_for_bulk_insert(compiler) + elif dml_strategy == "orm": + self._setup_for_orm_insert(compiler, mapper) + + return self + + @classmethod + def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict): + return { + col.key if col is not None else k: v + for col, k, v in ( + (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() + ) + } + + def _setup_for_orm_insert(self, compiler, mapper): + statement = orm_level_statement = cast(dml.Insert, self.statement) + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=False, + ) + self.statement = statement + + def _setup_for_bulk_insert(self, compiler): + """establish an INSERT statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_insert_statement(). 
+ + """ + statement = orm_level_statement = cast(dml.Insert, self.statement) + an = statement._annotations + + emit_insert_table, emit_insert_mapper = ( + an["_emit_insert_table"], + an["_emit_insert_mapper"], + ) + + statement = statement._clone() + + statement.table = emit_insert_table + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_insert_table + } + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=True, + dml_mapper=emit_insert_mapper, + ) + + self.statement = statement @CompileState.plugin_for("orm", "update") @@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self = cls.__new__(cls) + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if dml_strategy == "bulk": + self._setup_for_bulk_update(statement, compiler) + elif dml_strategy in ("orm", "unspecified"): + self._setup_for_orm_update(statement, compiler) + + return self + + def _setup_for_orm_update(self, statement, compiler, **kw): + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper self.extra_criteria_entities = {} - self._resolved_values = cls._get_resolved_values(mapper, statement) + self._resolved_values = self._get_resolved_values(mapper, statement) extra_criteria_attributes = {} @@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if statement._values: self._resolved_values = dict(self._resolved_values) - new_stmt = sql.Update.__new__(sql.Update) - new_stmt.__dict__.update(statement.__dict__) + new_stmt = statement._clone() new_stmt.table = mapper.local_table # note if the statement has _multi_values, these @@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): elif statement._values: new_stmt._values = self._resolved_values - new_crit = 
cls._adjust_for_extra_criteria( + new_crit = self._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: @@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): UpdateDMLState.__init__(self, new_stmt, compiler, **kw) - if compiler._annotations.get( + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, mapper, is_multitable=self.is_multitable - ): - if new_stmt._returning: - raise sa_exc.InvalidRequestError( - "Can't use synchronize_session='fetch' " - "with explicit returning()" + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable ) - self.statement = self.statement.returning( - *mapper.local_table.primary_key ) - return self + if synchronize_session == "fetch" and can_use_returning: + use_supplemental_cols = True + + # NOTE: we might want to RETURNING the actual columns to be + # synchronized also. however this is complicated and difficult + # to align against the behavior of "evaluate". Additionally, + # in a large number (if not the majority) of cases, we have the + # "evaluate" answer, usually a fixed value, in memory already and + # there's no need to re-fetch the same value + # over and over again. so perhaps if it could be RETURNING just + # the elements that were based on a SQL expression and not + # a constant. 
For now it doesn't quite seem worth it + new_stmt = new_stmt.return_defaults( + *(list(mapper.local_table.primary_key)) + ) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + def _setup_for_bulk_update(self, statement, compiler, **kw): + """establish an UPDATE statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_update_statement(). + + """ + statement = cast(dml.Update, statement) + an = statement._annotations + + emit_update_table, _ = ( + an["_emit_update_table"], + an["_emit_update_mapper"], + ) + + statement = statement._clone() + statement.table = emit_update_table + + UpdateDMLState.__init__(self, statement, compiler, **kw) + + if self._ordered_values: + raise sa_exc.InvalidRequestError( + "bulk ORM UPDATE does not support ordered_values() for " + "custom UPDATE statements with bulk parameter sets. Use a " + "non-bulk UPDATE statement or use values()." 
+ ) + + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_update_table + } + self.statement = statement + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Update, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy not in ("orm", "auto", "bulk"): + raise sa_exc.ArgumentError( + "Valid strategies for ORM UPDATE strategy " + "are 'orm', 'auto', 'bulk'" + ) + + result: _result.Result[Any] + + if update_options._dml_strategy == "bulk": + if statement._where_criteria: + raise sa_exc.InvalidRequestError( + "WHERE clause with bulk ORM UPDATE not " + "supported right now. Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + mapper = update_options._subject_mapper + assert mapper is not None + assert session._transaction is not None + result = _bulk_update( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + update_changed_only=False, + use_orm_update_stmt=statement, + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + else: + return super().orm_execute_statement( + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) @classmethod def can_use_returning( @@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): return True @classmethod - def _get_crud_kv_pairs(cls, statement, kv_iterator): - plugin_subject = statement._propagate_attrs["plugin_subject"] - - core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs - - if not 
plugin_subject or not plugin_subject.mapper: - return core_get_crud_kv_pairs(statement, kv_iterator) - - mapper = plugin_subject.mapper - - values = [] - - for k, v in kv_iterator: - k = coercions.expect(roles.DMLColumnRole, k) + def _do_post_synchronize_bulk_evaluate( + cls, session, params, result, update_options + ): + if not params: + return - if isinstance(k, str): - desc = _entity_namespace_key(mapper, k, default=NO_VALUE) - if desc is NO_VALUE: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - else: - values.extend( - core_get_crud_kv_pairs( - statement, desc._bulk_update_tuples(v) - ) - ) - elif "entity_namespace" in k._annotations: - k_anno = k._annotations - attr = _entity_namespace_key( - k_anno["entity_namespace"], k_anno["proxy_key"] - ) - values.extend( - core_get_crud_kv_pairs( - statement, attr._bulk_update_tuples(v) - ) - ) - else: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - return values + mapper = update_options._subject_mapper + pk_keys = [prop.key for prop in mapper._identity_key_props] - @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): + identity_map = session.identity_map - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) - for obj in update_options._matched_objects: - - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), + for param in params: + identity_key = mapper.identity_key_from_primary_key( + (param[key] for key in pk_keys), + update_options._refresh_identity_token, ) - - # the evaluated states were gathered across all identity tokens. - # however the post_sync events are called per identity token, - # so filter. 
- if ( - update_options._refresh_identity_token is not None - and state.identity_token - != update_options._refresh_identity_token - ): + state = identity_map.fast_get_state(identity_key) + if not state: continue + evaluated_keys = set(param).difference(pk_keys) + + dict_ = state.dict # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + dict_[key] = param[key] state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) - to_expire = attrib.intersection(dict_).difference(to_evaluate) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = evaluated_keys.intersection(dict_).difference( + to_evaluate + ) if to_expire: state._expire_attributes(dict_, to_expire) - states.add(state) - session._register_altered(states) + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + ) @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ 
tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows @@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if identity_key in session.identity_map ] - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) + if not objs: + return - for obj in objs: - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), - ) + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [ + ( + obj, + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + for obj in objs + ], + ) + + @classmethod + def _apply_update_set_values_to_objects( + cls, session, update_options, statement, matched_objects + ): + """apply values to objects derived from an update statement, e.g. + UPDATE..SET <values> + + """ + mapper = update_options._subject_mapper + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + evaluated_keys = list(value_evaluators.keys()) + attrib = set(k for k, v in resolved_keys_as_propnames) + + states = set() + for obj, state, dict_ in matched_objects: to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + # only run eval for attributes that are present. 
+ dict_[key] = value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: state._expire_attributes(dict_, to_expire) @@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper @@ -1002,31 +1743,97 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): if opt._is_criteria_option: opt.get_global_criteria(extra_criteria_attributes) + new_stmt = statement._clone() + new_stmt.table = mapper.local_table + new_crit = cls._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: - statement = statement.where(*new_crit) + new_stmt = new_stmt.where(*new_crit) # do this first as we need to determine if there is # DELETE..FROM - DeleteDMLState.__init__(self, statement, compiler, **kw) + DeleteDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False - if compiler._annotations.get( + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, - mapper, - is_multitable=self.is_multitable, - is_delete_using=compiler._annotations.get( - "is_delete_using", False - ), - ): - self.statement = statement.returning(*statement.table.primary_key) + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # 
unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ) + ) + + if can_use_returning: + use_supplemental_cols = True + + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt return self @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Delete, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + "Bulk ORM DELETE not supported right now. 
" + "Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + + if update_options._dml_strategy not in ( + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM DELETE strategy are 'orm', 'auto'" + ) + + return super().orm_execute_statement( + session, statement, params, execution_options, bind_arguments, conn + ) + + @classmethod def can_use_returning( cls, dialect: Dialect, @@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): return True @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): - - session._remove_newly_deleted( - [ - attributes.instance_state(obj) - for obj in update_options._matched_objects - ] + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), ) + to_delete = [] + + for _, state, dict_, is_partially_expired in matched_objects: + if is_partially_expired: + state._expire(dict_, session.identity_map._modified) + else: + to_delete.append(state) + + if to_delete: + session._remove_newly_deleted(to_delete) + @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows |
