# orm/bulk_persistence.py # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors """additional ORM persistence classes related to "bulk" operations, specifically outside of the flush() process. """ from __future__ import annotations from typing import Any from typing import cast from typing import Dict from typing import Iterable from typing import Optional from typing import overload from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes from . import context from . import evaluator from . import exc as orm_exc from . import loading from . import persistence from .base import NO_VALUE from .context import AbstractORMCompileState from .context import FromStatement from .context import ORMFromStatementCompileState from .context import QueryContext from .. import exc as sa_exc from .. 
import util from ..engine import Dialect from ..engine import result as _result from ..sql import coercions from ..sql import dml from ..sql import expression from ..sql import roles from ..sql import select from ..sql import sqltypes from ..sql.base import _entity_namespace_key from ..sql.base import CompileState from ..sql.base import Options from ..sql.dml import DeleteDMLState from ..sql.dml import InsertDMLState from ..sql.dml import UpdateDMLState from ..util import EMPTY_DICT from ..util.typing import Literal if TYPE_CHECKING: from ._typing import DMLStrategyArgument from ._typing import OrmExecuteOptionsParameter from ._typing import SynchronizeSessionArgument from .mapper import Mapper from .session import _BindArguments from .session import ORMExecuteState from .session import Session from .session import SessionTransaction from .state import InstanceState from ..engine import Connection from ..engine import cursor from ..engine.interfaces import _CoreAnyExecuteParams _O = TypeVar("_O", bound=object) @overload def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, return_defaults: bool, render_nulls: bool, use_orm_insert_stmt: Literal[None] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., ) -> None: ... @overload def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, return_defaults: bool, render_nulls: bool, use_orm_insert_stmt: Optional[dml.Insert] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., ) -> cursor.CursorResult[Any]: ... 
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = None,
    execution_options: Optional[OrmExecuteOptionsParameter] = None,
) -> Optional[cursor.CursorResult[Any]]:
    """Emit INSERT statements for a series of states or plain mappings.

    One round of ``persistence._collect_insert_commands()`` /
    ``persistence._emit_insert_statements()`` is run for every table in
    ``base_mapper._sorted_tables`` that is also present in
    ``mapper._pks_by_table``.

    :param mapper: mapper describing the objects being inserted.
    :param mappings: ``InstanceState`` objects when ``isstates`` is true,
     otherwise dictionaries, which are copied before use.
    :param session_transaction: transaction supplying the connection.
    :param isstates: whether ``mappings`` contains states.
    :param return_defaults: turns on "bookkeeping"; together with
     ``isstates`` it also causes ``state.key`` to be assigned from the
     state dictionary once all tables have been processed.
    :param render_nulls: passed along to ``_collect_insert_commands()``.
    :param use_orm_insert_stmt: optional ORM-level ``Insert``; when given,
     the (possibly horizontally spliced) result is returned.
    :param execution_options: passed along to ``_emit_insert_statements()``.
    :return: the result when ``use_orm_insert_stmt`` is given, else ``None``.
    :raises NotImplementedError: if the session has a
     ``connection_callable`` configured.
    """
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    # normalize the incoming collection into a list of dictionaries
    if not isstates:
        mappings = [dict(mapping) for mapping in mappings]
        _expand_composites(mapper, mappings)
    elif return_defaults:
        states = [(state, state.dict) for state in mappings]
        mappings = [state_dict for (_, state_dict) in states]
    else:
        mappings = [state.dict for state in mappings]

    connection = session_transaction.connection(base_mapper)

    return_result: Optional[cursor.CursorResult[Any]] = None

    mappers_to_run = [
        (tbl, tbl_mapper)
        for tbl, tbl_mapper in base_mapper._sorted_tables.items()
        if tbl in mapper._pks_by_table
    ]

    # "bookkeeping" is needed in two cases:
    #  * return_defaults - not used by new-style bulk inserts, only used
    #    for legacy
    #  * more than one table, mapper to run where we will be either
    #    horizontally splicing, or copying values between tables; we need
    #    the "bookkeeping" / deterministic returning order
    bookkeeping = bool(return_defaults) or len(mappers_to_run) > 1

    for table, super_mapper in mappers_to_run:
        # find bindparams in the statement. For bulk, we don't really know if
        # a key in the params applies to a different table since we are
        # potentially inserting for multiple tables here; looking at the
        # bindparam() is a lot more direct.   in most cases this will
        # use _generate_cache_key() which is memoized, although in practice
        # the ultimate statement that's executed is probably not the same
        # object so that memoization might not matter much.
        if use_orm_insert_stmt is None:
            extra_bp_names = ()
        else:
            extra_bp_names = [
                bp.key
                for bp in use_orm_insert_stmt._get_embedded_bindparams()
                if bp.key in mappings[0]
            ]

        collected = persistence._collect_insert_commands(
            table,
            ((None, mapping, mapper, connection) for mapping in mappings),
            bulk=True,
            return_defaults=bookkeeping,
            render_nulls=render_nulls,
            include_bulk_keys=extra_bp_names,
        )

        # re-emit each collected record with no state, and with the
        # mapper / connection of this call in place of the collected ones
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                _state,
                state_dict,
                params,
                _mp,
                _conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in collected
        )

        result = persistence._emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=bookkeeping,
            use_orm_insert_stmt=use_orm_insert_stmt,
            execution_options=execution_options,
        )

        if use_orm_insert_stmt is None:
            continue

        if not use_orm_insert_stmt._returning or return_result is None:
            return_result = result
        elif result.returns_rows:
            assert bookkeeping
            return_result = return_result.splice_horizontally(result)

    if return_defaults and isstates:
        identity_cls = mapper._identity_class
        identity_props = [prop.key for prop in mapper._identity_key_props]
        for state, state_dict in states:
            state.key = (
                identity_cls,
                tuple(state_dict[key] for key in identity_props),
            )

    if use_orm_insert_stmt is not None:
        assert return_result is not None
        return return_result

    return None


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Literal[None] = ...,
    enable_check_rowcount: bool = True,
) -> None:
    ...


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = ...,
    enable_check_rowcount: bool = True,
) -> _result.Result[Any]:
    ...
def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = None, enable_check_rowcount: bool = True, ) -> Optional[_result.Result[Any]]: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys if mapper._version_id_prop: search_keys = {mapper._version_id_prop.key}.union(search_keys) def _changed_dict(mapper, state): return { k: v for k, v in state.dict.items() if k in state.committed_state or k in search_keys } if isstates: if update_changed_only: mappings = [_changed_dict(mapper, state) for state in mappings] else: mappings = [state.dict for state in mappings] else: mappings = [dict(m) for m in mappings] _expand_composites(mapper, mappings) if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " "not supported in bulk_update()" ) connection = session_transaction.connection(base_mapper) # find bindparams in the statement. 
see _bulk_insert for similar # notes for the insert case extra_bp_names = ( [ b.key for b in use_orm_update_stmt._get_embedded_bindparams() if b.key in mappings[0] ] if use_orm_update_stmt is not None else () ) for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue records = persistence._collect_update_commands( None, table, ( ( None, mapping, mapper, connection, ( mapping[mapper._version_id_prop.key] if mapper._version_id_prop else None ), ) for mapping in mappings ), bulk=True, use_orm_update_stmt=use_orm_update_stmt, include_bulk_keys=extra_bp_names, ) persistence._emit_update_statements( base_mapper, None, super_mapper, table, records, bookkeeping=False, use_orm_update_stmt=use_orm_update_stmt, enable_check_rowcount=enable_check_rowcount, ) if use_orm_update_stmt is not None: return _result.null_result() def _expand_composites(mapper, mappings): composite_attrs = mapper.composites if not composite_attrs: return composite_keys = set(composite_attrs.keys()) populators = { key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() for key in composite_keys } for mapping in mappings: for key in composite_keys.intersection(mapping): populators[key](mapping) class ORMDMLState(AbstractORMCompileState): is_dml_returning = True from_statement_ctx: Optional[ORMFromStatementCompileState] = None @classmethod def _get_orm_crud_kv_pairs( cls, mapper, statement, kv_iterator, needs_to_be_cacheable ): core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs for k, v in kv_iterator: k = coercions.expect(roles.DMLColumnRole, k) if isinstance(k, str): desc = _entity_namespace_key(mapper, k, default=NO_VALUE) if desc is NO_VALUE: yield ( coercions.expect(roles.DMLColumnRole, k), coercions.expect( roles.ExpressionElementRole, v, type_=sqltypes.NullType(), is_crud=True, ) if needs_to_be_cacheable else v, ) else: yield from core_get_crud_kv_pairs( statement, desc._bulk_update_tuples(v), 
needs_to_be_cacheable, ) elif "entity_namespace" in k._annotations: k_anno = k._annotations attr = _entity_namespace_key( k_anno["entity_namespace"], k_anno["proxy_key"] ) yield from core_get_crud_kv_pairs( statement, attr._bulk_update_tuples(v), needs_to_be_cacheable, ) else: yield ( k, v if not needs_to_be_cacheable else coercions.expect( roles.ExpressionElementRole, v, type_=sqltypes.NullType(), is_crud=True, ), ) @classmethod def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): plugin_subject = statement._propagate_attrs["plugin_subject"] if not plugin_subject or not plugin_subject.mapper: return UpdateDMLState._get_multi_crud_kv_pairs( statement, kv_iterator ) return [ dict( cls._get_orm_crud_kv_pairs( plugin_subject.mapper, statement, value_dict.items(), False ) ) for value_dict in kv_iterator ] @classmethod def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): assert ( needs_to_be_cacheable ), "no test coverage for needs_to_be_cacheable=False" plugin_subject = statement._propagate_attrs["plugin_subject"] if not plugin_subject or not plugin_subject.mapper: return UpdateDMLState._get_crud_kv_pairs( statement, kv_iterator, needs_to_be_cacheable ) return list( cls._get_orm_crud_kv_pairs( plugin_subject.mapper, statement, kv_iterator, needs_to_be_cacheable, ) ) @classmethod def get_entity_description(cls, statement): ext_info = statement.table._annotations["parententity"] mapper = ext_info.mapper if ext_info.is_aliased_class: _label_name = ext_info.name else: _label_name = mapper.class_.__name__ return { "name": _label_name, "type": mapper.class_, "expr": ext_info.entity, "entity": ext_info.entity, "table": mapper.local_table, } @classmethod def get_returning_column_descriptions(cls, statement): def _ent_for_col(c): return c._annotations.get("parententity", None) def _attr_for_col(c, ent): if ent is None: return c proxy_key = c._annotations.get("proxy_key", None) if not proxy_key: return c else: return getattr(ent.entity, proxy_key, c) 
return [ { "name": c.key, "type": c.type, "expr": _attr_for_col(c, ent), "aliased": ent.is_aliased_class, "entity": ent.entity, } for c, ent in [ (c, _ent_for_col(c)) for c in statement._all_selected_columns ] ] def _setup_orm_returning( self, compiler, orm_level_statement, dml_level_statement, dml_mapper, *, use_supplemental_cols=True, ): """establish ORM column handlers for an INSERT, UPDATE, or DELETE which uses explicit returning(). called within compilation level create_for_statement. The _return_orm_returning() method then receives the Result after the statement was executed, and applies ORM loading to the state that we first established here. """ if orm_level_statement._returning: fs = FromStatement( orm_level_statement._returning, dml_level_statement, _adapt_on_names=False, ) fs = fs.options(*orm_level_statement._with_options) self.select_statement = fs self.from_statement_ctx = ( fsc ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) fsc.setup_dml_returning_compile_state(dml_mapper) dml_level_statement = dml_level_statement._generate() dml_level_statement._returning = () cols_to_return = [c for c in fsc.primary_columns if c is not None] # since we are splicing result sets together, make sure there # are columns of some kind returned in each result set if not cols_to_return: cols_to_return.extend(dml_mapper.primary_key) if use_supplemental_cols: dml_level_statement = dml_level_statement.return_defaults( # this is a little weird looking, but by passing # primary key as the main list of cols, this tells # return_defaults to omit server-default cols (and # actually all cols, due to some weird thing we should # clean up in crud.py). # Since we have cols_to_return, just return what we asked # for (plus primary key, which ORM persistence needs since # we likely set bookkeeping=True here, which is another # whole thing...). We dont want to clutter the # statement up with lots of other cols the user didn't # ask for. 
see #9685 *dml_mapper.primary_key, supplemental_cols=cols_to_return, ) else: dml_level_statement = dml_level_statement.returning( *cols_to_return ) return dml_level_statement @classmethod def _return_orm_returning( cls, session, statement, params, execution_options, bind_arguments, result, ): execution_context = result.context compile_state = execution_context.compiled.compile_state if compile_state.from_statement_ctx: load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options ) querycontext = QueryContext( compile_state.from_statement_ctx, compile_state.select_statement, params, session, load_options, execution_options, bind_arguments, ) return loading.instances(result, querycontext) else: return result class BulkUDCompileState(ORMDMLState): class default_update_options(Options): _dml_strategy: DMLStrategyArgument = "auto" _synchronize_session: SynchronizeSessionArgument = "auto" _can_use_returning: bool = False _is_delete_using: bool = False _is_update_from: bool = False _autoflush: bool = True _subject_mapper: Optional[Mapper[Any]] = None _resolved_values = EMPTY_DICT _eval_condition = None _matched_rows = None _identity_token = None @classmethod def can_use_returning( cls, dialect: Dialect, mapper: Mapper[Any], *, is_multitable: bool = False, is_update_from: bool = False, is_delete_using: bool = False, is_executemany: bool = False, ) -> bool: raise NotImplementedError() @classmethod def orm_pre_session_exec( cls, session, statement, params, execution_options, bind_arguments, is_pre_event, ): ( update_options, execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", { "synchronize_session", "autoflush", "identity_token", "is_delete_using", "is_update_from", "dml_strategy", }, execution_options, statement._execution_options, ) bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] except KeyError: assert False, "statement 
had 'orm' plugin but no plugin_subject" else: bind_arguments["mapper"] = plugin_subject.mapper update_options += {"_subject_mapper": plugin_subject.mapper} if not isinstance(params, list): if update_options._dml_strategy == "auto": update_options += {"_dml_strategy": "orm"} elif update_options._dml_strategy == "bulk": raise sa_exc.InvalidRequestError( 'Can\'t use "bulk" ORM insert strategy without ' "passing separate parameters" ) else: if update_options._dml_strategy == "auto": update_options += {"_dml_strategy": "bulk"} sync = update_options._synchronize_session if sync is not None: if sync not in ("auto", "evaluate", "fetch", False): raise sa_exc.ArgumentError( "Valid strategies for session synchronization " "are 'auto', 'evaluate', 'fetch', False" ) if update_options._dml_strategy == "bulk" and sync == "fetch": raise sa_exc.InvalidRequestError( "The 'fetch' synchronization strategy is not available " "for 'bulk' ORM updates (i.e. multiple parameter sets)" ) if not is_pre_event: if update_options._autoflush: session._autoflush() if update_options._dml_strategy == "orm": if update_options._synchronize_session == "auto": update_options = cls._do_pre_synchronize_auto( session, statement, params, execution_options, bind_arguments, update_options, ) elif update_options._synchronize_session == "evaluate": update_options = cls._do_pre_synchronize_evaluate( session, statement, params, execution_options, bind_arguments, update_options, ) elif update_options._synchronize_session == "fetch": update_options = cls._do_pre_synchronize_fetch( session, statement, params, execution_options, bind_arguments, update_options, ) elif update_options._dml_strategy == "bulk": if update_options._synchronize_session == "auto": update_options += {"_synchronize_session": "evaluate"} # indicators from the "pre exec" step that are then # added to the DML statement, which will also be part of the cache # key. 
The compile level create_for_statement() method will then # consume these at compiler time. statement = statement._annotate( { "synchronize_session": update_options._synchronize_session, "is_delete_using": update_options._is_delete_using, "is_update_from": update_options._is_update_from, "dml_strategy": update_options._dml_strategy, "can_use_returning": update_options._can_use_returning, } ) return ( statement, util.immutabledict(execution_options).union( {"_sa_orm_update_options": update_options} ), ) @classmethod def orm_setup_cursor_result( cls, session, statement, params, execution_options, bind_arguments, result, ): # this stage of the execution is called after the # do_orm_execute event hook. meaning for an extension like # horizontal sharding, this step happens *within* the horizontal # sharding event handler which calls session.execute() re-entrantly # and will occur for each backend individually. # the sharding extension then returns its own merged result from the # individual ones we return here. update_options = execution_options["_sa_orm_update_options"] if update_options._dml_strategy == "orm": if update_options._synchronize_session == "evaluate": cls._do_post_synchronize_evaluate( session, statement, result, update_options ) elif update_options._synchronize_session == "fetch": cls._do_post_synchronize_fetch( session, statement, result, update_options ) elif update_options._dml_strategy == "bulk": if update_options._synchronize_session == "evaluate": cls._do_post_synchronize_bulk_evaluate( session, params, result, update_options ) return result return cls._return_orm_returning( session, statement, params, execution_options, bind_arguments, result, ) @classmethod def _adjust_for_extra_criteria(cls, global_attributes, ext_info): """Apply extra criteria filtering. 
For all distinct single-table-inheritance mappers represented in the table being updated or deleted, produce additional WHERE criteria such that only the appropriate subtypes are selected from the total results. Additionally, add WHERE criteria originating from LoaderCriteriaOptions collected from the statement. """ return_crit = () adapter = ext_info._adapter if ext_info.is_aliased_class else None if ( "additional_entity_criteria", ext_info.mapper, ) in global_attributes: return_crit += tuple( ae._resolve_where_criteria(ext_info) for ae in global_attributes[ ("additional_entity_criteria", ext_info.mapper) ] if ae.include_aliases or ae.entity is ext_info ) if ext_info.mapper._single_table_criterion is not None: return_crit += (ext_info.mapper._single_table_criterion,) if adapter: return_crit = tuple(adapter.traverse(crit) for crit in return_crit) return return_crit @classmethod def _interpret_returning_rows(cls, mapper, rows): """translate from local inherited table columns to base mapper primary key columns. Joined inheritance mappers always establish the primary key in terms of the base table. When we UPDATE a sub-table, we can only get RETURNING for the sub-table's columns. Here, we create a lookup from the local sub table's primary key columns to the base table PK columns so that we can get identity key values from RETURNING that's against the joined inheritance sub-table. the complexity here is to support more than one level deep of inheritance, where we have to link columns to each other across the inheritance hierarchy. """ if mapper.local_table is not mapper.base_mapper.local_table: return rows # this starts as a mapping of # local_pk_col: local_pk_col. 
# we will then iteratively rewrite the "value" of the dict with # each successive superclass column local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} for mp in mapper.iterate_to_root(): if mp.inherits is None: break elif mp.local_table is mp.inherits.local_table: continue t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} for pk, super_ in local_pk_to_base_pk.items(): local_pk_to_base_pk[pk] = col_to_col[super_] lookup = { local_pk_to_base_pk[lpk]: idx for idx, lpk in enumerate(mapper.local_table.primary_key) } primary_key_convert = [ lookup[bpk] for bpk in mapper.base_mapper.primary_key ] return [tuple(row[idx] for idx in primary_key_convert) for row in rows] @classmethod def _get_matched_objects_on_criteria(cls, update_options, states): mapper = update_options._subject_mapper eval_condition = update_options._eval_condition raw_data = [ (state.obj(), state, state.dict) for state in states if state.mapper.isa(mapper) and not state.expired ] identity_token = update_options._identity_token if identity_token is not None: raw_data = [ (obj, state, dict_) for obj, state, dict_ in raw_data if state.identity_token == identity_token ] result = [] for obj, state, dict_ in raw_data: evaled_condition = eval_condition(obj) # caution: don't use "in ()" or == here, _EXPIRE_OBJECT # evaluates as True for all comparisons if ( evaled_condition is True or evaled_condition is evaluator._EXPIRED_OBJECT ): result.append( ( obj, state, dict_, evaled_condition is evaluator._EXPIRED_OBJECT, ) ) return result @classmethod def _eval_condition_from_statement(cls, update_options, statement): mapper = update_options._subject_mapper target_cls = mapper.class_ evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) crit = () if statement._where_criteria: crit += statement._where_criteria global_attributes = {} for opt in statement._with_options: if opt._is_criteria_option: 
opt.get_global_criteria(global_attributes) if global_attributes: crit += cls._adjust_for_extra_criteria(global_attributes, mapper) if crit: eval_condition = evaluator_compiler.process(*crit) else: # workaround for mypy https://github.com/python/mypy/issues/14027 def _eval_condition(obj): return True eval_condition = _eval_condition return eval_condition @classmethod def _do_pre_synchronize_auto( cls, session, statement, params, execution_options, bind_arguments, update_options, ): """setup auto sync strategy "auto" checks if we can use "evaluate" first, then falls back to "fetch" evaluate is vastly more efficient for the common case where session is empty, only has a few objects, and the UPDATE statement can potentially match thousands/millions of rows. OTOH more complex criteria that fails to work with "evaluate" we would hope usually correlates with fewer net rows. """ try: eval_condition = cls._eval_condition_from_statement( update_options, statement ) except evaluator.UnevaluatableError: pass else: return update_options + { "_eval_condition": eval_condition, "_synchronize_session": "evaluate", } update_options += {"_synchronize_session": "fetch"} return cls._do_pre_synchronize_fetch( session, statement, params, execution_options, bind_arguments, update_options, ) @classmethod def _do_pre_synchronize_evaluate( cls, session, statement, params, execution_options, bind_arguments, update_options, ): try: eval_condition = cls._eval_condition_from_statement( update_options, statement ) except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( 'Could not evaluate current criteria in Python: "%s". ' "Specify 'fetch' or False for the " "synchronize_session execution option." 
% err ) from err return update_options + { "_eval_condition": eval_condition, } @classmethod def _get_resolved_values(cls, mapper, statement): if statement._multi_values: return [] elif statement._ordered_values: return list(statement._ordered_values) elif statement._values: return list(statement._values.items()) else: return [] @classmethod def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] for k, v in resolved_values: if mapper and isinstance(k, expression.ColumnElement): try: attr = mapper._columntoproperty[k] except orm_exc.UnmappedColumnError: pass else: values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( "Attribute name not found, can't be " "synchronized back to objects: %r" % k ) return values @classmethod def _do_pre_synchronize_fetch( cls, session, statement, params, execution_options, bind_arguments, update_options, ): mapper = update_options._subject_mapper select_stmt = ( select(*(mapper.primary_key + (mapper.select_identity_token,))) .select_from(mapper) .options(*statement._with_options) ) select_stmt._where_criteria = statement._where_criteria # conditionally run the SELECT statement for pre-fetch, testing the # "bind" for if we can use RETURNING or not using the do_orm_execute # event. If RETURNING is available, the do_orm_execute event # will cancel the SELECT from being actually run. # # The way this is organized seems strange, why don't we just # call can_use_returning() before invoking the statement and get # answer?, why does this go through the whole execute phase using an # event? Answer: because we are integrating with extensions such # as the horizontal sharding extention that "multiplexes" an individual # statement run through multiple engines, and it uses # do_orm_execute() to do that. 
can_use_returning = None def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) nonlocal can_use_returning per_bind_result = cls.can_use_returning( bind.dialect, mapper, is_update_from=update_options._is_update_from, is_delete_using=update_options._is_delete_using, is_executemany=orm_context.is_executemany, ) if can_use_returning is not None: if can_use_returning != per_bind_result: raise sa_exc.InvalidRequestError( "For synchronize_session='fetch', can't mix multiple " "backends where some support RETURNING and others " "don't" ) elif orm_context.is_executemany and not per_bind_result: raise sa_exc.InvalidRequestError( "For synchronize_session='fetch', can't use multiple " "parameter sets in ORM mode, which this backend does not " "support with RETURNING" ) else: can_use_returning = per_bind_result if per_bind_result: return _result.null_result() else: return None result = session.execute( select_stmt, params, execution_options=execution_options, bind_arguments=bind_arguments, _add_event=skip_for_returning, ) matched_rows = result.fetchall() return update_options + { "_matched_rows": matched_rows, "_can_use_returning": can_use_returning, } @CompileState.plugin_for("orm", "insert") class BulkORMInsert(ORMDMLState, InsertDMLState): class default_insert_options(Options): _dml_strategy: DMLStrategyArgument = "auto" _render_nulls: bool = False _return_defaults: bool = False _subject_mapper: Optional[Mapper[Any]] = None _autoflush: bool = True select_statement: Optional[FromStatement] = None @classmethod def orm_pre_session_exec( cls, session, statement, params, execution_options, bind_arguments, is_pre_event, ): ( insert_options, execution_options, ) = BulkORMInsert.default_insert_options.from_execution_options( "_sa_orm_insert_options", {"dml_strategy", "autoflush"}, execution_options, statement._execution_options, ) bind_arguments["clause"] = statement try: plugin_subject = 
statement._propagate_attrs["plugin_subject"] except KeyError: assert False, "statement had 'orm' plugin but no plugin_subject" else: bind_arguments["mapper"] = plugin_subject.mapper insert_options += {"_subject_mapper": plugin_subject.mapper} if not params: if insert_options._dml_strategy == "auto": insert_options += {"_dml_strategy": "orm"} elif insert_options._dml_strategy == "bulk": raise sa_exc.InvalidRequestError( 'Can\'t use "bulk" ORM insert strategy without ' "passing separate parameters" ) else: if insert_options._dml_strategy == "auto": insert_options += {"_dml_strategy": "bulk"} if insert_options._dml_strategy != "raw": # for ORM object loading, like ORMContext, we have to disable # result set adapt_to_context, because we will be generating a # new statement with specific columns that's cached inside of # an ORMFromStatementCompileState, which we will re-use for # each result. if not execution_options: execution_options = context._orm_load_exec_options else: execution_options = execution_options.union( context._orm_load_exec_options ) if not is_pre_event and insert_options._autoflush: session._autoflush() statement = statement._annotate( {"dml_strategy": insert_options._dml_strategy} ) return ( statement, util.immutabledict(execution_options).union( {"_sa_orm_insert_options": insert_options} ), ) @classmethod def orm_execute_statement( cls, session: Session, statement: dml.Insert, params: _CoreAnyExecuteParams, execution_options: OrmExecuteOptionsParameter, bind_arguments: _BindArguments, conn: Connection, ) -> _result.Result: insert_options = execution_options.get( "_sa_orm_insert_options", cls.default_insert_options ) if insert_options._dml_strategy not in ( "raw", "bulk", "orm", "auto", ): raise sa_exc.ArgumentError( "Valid strategies for ORM insert strategy " "are 'raw', 'orm', 'bulk', 'auto" ) result: _result.Result[Any] if insert_options._dml_strategy == "raw": result = conn.execute( statement, params or {}, execution_options=execution_options ) 
return result if insert_options._dml_strategy == "bulk": mapper = insert_options._subject_mapper if ( statement._post_values_clause is not None and mapper._multiple_persistence_tables ): raise sa_exc.InvalidRequestError( "bulk INSERT with a 'post values' clause " "(typically upsert) not supported for multi-table " f"mapper {mapper}" ) assert mapper is not None assert session._transaction is not None result = _bulk_insert( mapper, cast( "Iterable[Dict[str, Any]]", [params] if isinstance(params, dict) else params, ), session._transaction, isstates=False, return_defaults=insert_options._return_defaults, render_nulls=insert_options._render_nulls, use_orm_insert_stmt=statement, execution_options=execution_options, ) elif insert_options._dml_strategy == "orm": result = conn.execute( statement, params or {}, execution_options=execution_options ) else: raise AssertionError() if not bool(statement._returning): return result return cls._return_orm_returning( session, statement, params, execution_options, bind_arguments, result, ) @classmethod def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: self = cast( BulkORMInsert, super().create_for_statement(statement, compiler, **kw), ) if compiler is not None: toplevel = not compiler.stack else: toplevel = True if not toplevel: return self mapper = statement._propagate_attrs["plugin_subject"] dml_strategy = statement._annotations.get("dml_strategy", "raw") if dml_strategy == "bulk": self._setup_for_bulk_insert(compiler) elif dml_strategy == "orm": self._setup_for_orm_insert(compiler, mapper) return self @classmethod def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict): return { col.key if col is not None else k: v for col, k, v in ( (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() ) } def _setup_for_orm_insert(self, compiler, mapper): statement = orm_level_statement = cast(dml.Insert, self.statement) statement = self._setup_orm_returning( compiler, orm_level_statement, statement, 
@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
    """Compile state and execution hooks for ORM-enabled UPDATE statements.

    Registered as the "orm" / "update" compile state plugin.  Covers both
    the "orm" strategy (UPDATE with criteria, optionally synchronizing
    in-session objects) and the "bulk" strategy (UPDATE driven by
    parameter sets, executed via ``_bulk_update()``).
    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Create the compile state for an UPDATE statement.

        The instance is created without running ``__init__``; the
        ``dml_strategy`` annotation then selects the setup routine:
        ``"bulk"`` uses :meth:`_setup_for_bulk_update`, while ``"orm"`` or
        a missing annotation uses :meth:`_setup_for_orm_update`.  For any
        other value neither setup routine is run.
        """
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        if dml_strategy == "bulk":
            self._setup_for_bulk_update(statement, compiler)
        elif dml_strategy in ("orm", "unspecified"):
            self._setup_for_orm_update(statement, compiler)

        return self

    def _setup_for_orm_update(self, statement, compiler, **kw):
        """Establish an UPDATE statement for the "orm" strategy.

        Resolves the SET values against the mapper, retargets a copy of
        the statement at the mapper's local table, adds any extra
        criteria, initializes the Core ``UpdateDMLState`` and finally
        applies ORM RETURNING handling.  When ``synchronize_session`` is
        "fetch" and RETURNING is usable, the primary key columns are
        requested via ``return_defaults()``.
        """
        orm_level_statement = statement

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self._resolved_values = self._get_resolved_values(mapper, statement)

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=True,
            process_criteria_for_toplevel=True,
        )

        if statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = statement._clone()
        new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = self._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        # do this first as we need to determine if there is
        # UPDATE..FROM

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        synchronize_session = compiler._annotations.get(
            "synchronize_session", None
        )
        can_use_returning = compiler._annotations.get(
            "can_use_returning", None
        )
        if can_use_returning is not False:
            # even though pre_exec has determined basic
            # can_use_returning for the dialect, if we are to use
            # RETURNING we need to run can_use_returning() at this level
            # unconditionally because whether the statement is
            # multi-table (UPDATE..FROM) is only established by the
            # UpdateDMLState.__init__() call above
            can_use_returning = (
                synchronize_session == "fetch"
                and self.can_use_returning(
                    compiler.dialect, mapper, is_multitable=self.is_multitable
                )
            )

        if synchronize_session == "fetch" and can_use_returning:
            use_supplemental_cols = True

            # NOTE: we might want to RETURNING the actual columns to be
            # synchronized also.  however this is complicated and difficult
            # to align against the behavior of "evaluate".  Additionally,
            # in a large number (if not the majority) of cases, we have the
            # "evaluate" answer, usually a fixed value, in memory already and
            # there's no need to re-fetch the same value
            # over and over again.   so perhaps if it could be RETURNING just
            # the elements that were based on a SQL expression and not
            # a constant.   For now it doesn't quite seem worth it
            new_stmt = new_stmt.return_defaults(
                *(list(mapper.local_table.primary_key))
            )

        new_stmt = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            new_stmt,
            dml_mapper=mapper,
            use_supplemental_cols=use_supplemental_cols,
        )

        self.statement = new_stmt

    def _setup_for_bulk_update(self, statement, compiler, **kw):
        """establish an UPDATE statement within the context of bulk update.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_update_statement().

        :raises sa_exc.InvalidRequestError: the statement makes use of
         ordered values.
        """
        statement = cast(dml.Update, statement)
        an = statement._annotations

        # both annotations are looked up; only the table is used here
        emit_update_table, _ = (
            an["_emit_update_table"],
            an["_emit_update_mapper"],
        )

        statement = statement._clone()
        statement.table = emit_update_table

        UpdateDMLState.__init__(self, statement, compiler, **kw)

        if self._ordered_values:
            raise sa_exc.InvalidRequestError(
                "bulk ORM UPDATE does not support ordered_values() for "
                "custom UPDATE statements with bulk parameter sets. Use a "
                "non-bulk UPDATE statement or use values()."
            )

        if self._dict_parameters:
            # keep only the parameters that belong to the table in question
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_update_table
            }
        self.statement = statement

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Update,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute an ORM-enabled UPDATE using the selected DML strategy.

        For the "bulk" strategy the parameter set(s) are passed to
        ``_bulk_update()``, a single dictionary being wrapped in a list,
        and the outcome is passed to ``cls.orm_setup_cursor_result()``.
        Rowcount checking is enabled only when the statement has no WHERE
        criteria.  All other accepted strategies defer to the superclass.

        :raises sa_exc.ArgumentError: the strategy is not one of
         'orm', 'auto', 'bulk'.
        :raises sa_exc.InvalidRequestError: "bulk" strategy is combined
         with WHERE criteria while ``synchronize_session`` is "evaluate".
        """
        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy not in ("orm", "auto", "bulk"):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM UPDATE strategy "
                "are 'orm', 'auto', 'bulk'"
            )

        result: _result.Result[Any]

        if update_options._dml_strategy == "bulk":
            enable_check_rowcount = not statement._where_criteria

            assert update_options._synchronize_session != "fetch"

            if (
                statement._where_criteria
                and update_options._synchronize_session == "evaluate"
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk synchronize of persistent objects not supported "
                    "when using bulk update with additional WHERE "
                    "criteria right now. add synchronize_session=None "
                    "execution option to bypass synchronize of persistent "
                    "objects."
                )
            mapper = update_options._subject_mapper
            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_update(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                update_changed_only=False,
                use_orm_update_stmt=statement,
                enable_check_rowcount=enable_check_rowcount,
            )
            return cls.orm_setup_cursor_result(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                result,
            )
        else:
            return super().orm_execute_statement(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                conn,
            )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        """Return whether UPDATE..RETURNING may be used.

        ``is_delete_using`` is accepted but not consulted here.

        :raises sa_exc.CompileError: the statement is multi-table, the
         ``is_update_from`` hint was not given, and the dialect reports
         no support for RETURNING with multiple FROM elements.
        """

        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.update_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        if is_executemany:
            return dialect.update_executemany_returning

        # these workarounds are currently hypothetical for UPDATE,
        # unlike DELETE where they impact MariaDB

        if is_update_from:
            return dialect.update_returning_multifrom

        elif is_multitable and not dialect.update_returning_multifrom:
            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with UPDATE..FROM; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_update_from=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_bulk_evaluate(
        cls, session, params, result, update_options
    ):
        """Copy bulk UPDATE parameter values onto objects in the session.

        Each parameter set is matched to an identity-map entry through
        its primary key values.  Attributes with no pending change take
        the new value if they are currently loaded; attributes that are
        loaded but carry pending changes are expired instead.

        :param params: one parameter dictionary or a sequence of them;
         nothing is done when it is empty or ``None``.
        """
        if not params:
            return

        # orm_execute_statement() accepts a single parameter dictionary
        # for the bulk strategy; iterating that dictionary directly would
        # yield its keys rather than parameter sets.
        if isinstance(params, dict):
            params = [params]

        mapper = update_options._subject_mapper
        pk_keys = [prop.key for prop in mapper._identity_key_props]

        identity_map = session.identity_map

        for param in params:
            identity_key = mapper.identity_key_from_primary_key(
                (param[key] for key in pk_keys),
                update_options._identity_token,
            )
            state = identity_map.fast_get_state(identity_key)
            if not state:
                continue

            evaluated_keys = set(param).difference(pk_keys)

            dict_ = state.dict
            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = param[key]

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = evaluated_keys.intersection(dict_).difference(
                to_evaluate
            )
            if to_expire:
                state._expire_attributes(dict_, to_expire)

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        """Apply the UPDATE's SET values to in-session objects that match
        its criteria ("evaluate" synchronization).
        """
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        # the fourth element of each match is not needed for UPDATE
        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
        )

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        """Apply the UPDATE's SET values to in-session objects identified
        by fetched primary keys ("fetch" synchronization).

        Primary keys come from ``result.returned_defaults_rows`` when
        present, otherwise from ``update_options._matched_rows``.
        """
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows
        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                target_mapper, returned_defaults_rows
            )

            # append the identity token so that these rows have the same
            # shape the loop below expects from _matched_rows
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        identity_map = session.identity_map
        token_filter = update_options._identity_token

        objs = []
        for row in matched_rows:
            # each row is the primary key followed by the identity token
            primary_key, identity_token = row[0:-1], row[-1]
            if token_filter is None or identity_token == token_filter:
                identity_key = target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                if identity_key in identity_map:
                    objs.append(identity_map[identity_key])

        if not objs:
            return

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            [
                (
                    obj,
                    attributes.instance_state(obj),
                    attributes.instance_dict(obj),
                )
                for obj in objs
            ],
        )

    @classmethod
    def _apply_update_set_values_to_objects(
        cls, session, update_options, statement, matched_objects
    ):
        """apply values to objects derived from an update statement, e.g.
        UPDATE..SET <values>

        :param matched_objects: iterable of ``(obj, state, dict_)``
         tuples for the objects to be synchronized.
        """
        mapper = update_options._subject_mapper
        target_cls = mapper.class_
        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        resolved_values = cls._get_resolved_values(mapper, statement)
        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
            mapper, resolved_values
        )
        value_evaluators = {}
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                # deliberately skipped: a key without an evaluator is
                # never in to_evaluate, so the attribute is expired
                # below if it is loaded
                pass
            else:
                value_evaluators[key] = _evaluator

        evaluated_keys = list(value_evaluators.keys())
        attrib = {k for k, v in resolved_keys_as_propnames}

        states = set()
        for obj, state, dict_ in matched_objects:
            to_evaluate = state.unmodified.intersection(evaluated_keys)

            for key in to_evaluate:
                if key in dict_:
                    # only run eval for attributes that are present.
                    dict_[key] = value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)
@CompileState.plugin_for("orm", "delete")
class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
    """Compile state and execution hooks for ORM-enabled DELETE statements.

    Registered as the "orm" / "delete" compile state plugin.
    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Create the compile state for a DELETE statement.

        Retargets a copy of the statement at the mapper's local table,
        adds any extra criteria, initializes the Core ``DeleteDMLState``
        and applies ORM RETURNING handling.  When RETURNING can be used
        for "fetch" synchronization, the primary key columns are
        requested via ``return_defaults()``.
        """
        self = cls.__new__(cls)

        orm_level_statement = statement

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=True,
            process_criteria_for_toplevel=True,
        )

        new_stmt = statement._clone()
        new_stmt.table = mapper.local_table

        new_crit = cls._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # do this first as we need to determine if there is
        # DELETE..FROM
        DeleteDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        synchronize_session = compiler._annotations.get(
            "synchronize_session", None
        )
        can_use_returning = compiler._annotations.get(
            "can_use_returning", None
        )
        if can_use_returning is not False:
            # even though pre_exec has determined basic
            # can_use_returning for the dialect, if we are to use
            # RETURNING we need to run can_use_returning() at this level
            # unconditionally because is_delete_using was not known
            # at the pre_exec level
            can_use_returning = (
                synchronize_session == "fetch"
                and self.can_use_returning(
                    compiler.dialect,
                    mapper,
                    is_multitable=self.is_multitable,
                    is_delete_using=compiler._annotations.get(
                        "is_delete_using", False
                    ),
                )
            )

        if can_use_returning:
            use_supplemental_cols = True

            new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        new_stmt = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            new_stmt,
            dml_mapper=mapper,
            use_supplemental_cols=use_supplemental_cols,
        )

        self.statement = new_stmt

        return self

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Delete,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute an ORM-enabled DELETE.

        Only the 'orm' and 'auto' strategies are accepted; execution is
        then delegated to the superclass.

        :raises sa_exc.InvalidRequestError: the "bulk" strategy was
         requested.
        :raises sa_exc.ArgumentError: the strategy is not 'orm' or 'auto'.
        """
        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy == "bulk":
            raise sa_exc.InvalidRequestError(
                "Bulk ORM DELETE not supported right now. "
                "Statement may be invoked at the "
                "Core level using "
                "session.connection().execute(stmt, parameters)"
            )

        if update_options._dml_strategy not in (
            "orm",
            "auto",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
            )

        return super().orm_execute_statement(
            session, statement, params, execution_options, bind_arguments, conn
        )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        """Return whether DELETE..RETURNING may be used.

        ``is_update_from`` and ``is_executemany`` are accepted but not
        consulted here.

        :raises sa_exc.CompileError: the statement is multi-table, the
         ``is_delete_using`` hint was not given, and the dialect reports
         no support for RETURNING with multiple FROM elements.
        """

        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.delete_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        # now get into special workarounds because MariaDB supports
        # DELETE...RETURNING but not DELETE...USING...RETURNING.
        if is_delete_using:
            # is_delete_using hint was passed.   use
            # additional dialect feature (True for PG, False for MariaDB)
            return dialect.delete_returning_multifrom

        elif is_multitable and not dialect.delete_returning_multifrom:
            # is_delete_using hint was not passed, but we determined
            # at compile time that this is in fact a DELETE..USING.
            # it's too late to continue since we did not pre-SELECT.
            # raise that we need that hint up front.

            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with DELETE..USING; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_delete_using=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        """Synchronize the session after DELETE using "evaluate".

        Matched objects reported as partially expired are expired;
        all other matched objects are removed from the session as newly
        deleted, in a single call.
        """
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        to_delete = []

        for _, state, dict_, is_partially_expired in matched_objects:
            if is_partially_expired:
                state._expire(dict_, session.identity_map._modified)
            else:
                to_delete.append(state)

        if to_delete:
            session._remove_newly_deleted(to_delete)

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        """Synchronize the session after DELETE using "fetch".

        Primary keys come from ``result.returned_defaults_rows`` when
        present, otherwise from ``update_options._matched_rows``.  Every
        matching object found in the identity map is removed from the
        session as newly deleted, in a single call.
        """
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows

        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                target_mapper, returned_defaults_rows
            )

            # append the identity token so that these rows have the same
            # shape the loop below expects from _matched_rows
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        identity_map = session.identity_map

        # keyed by identity key so that a primary key appearing in more
        # than one row is removed only once, as was the case when each
        # row was removed individually
        to_delete = {}

        for row in matched_rows:
            # each row is the primary key followed by the identity token
            primary_key = row[0:-1]
            identity_token = row[-1]

            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )
            if identity_key not in to_delete and identity_key in identity_map:
                to_delete[identity_key] = attributes.instance_state(
                    identity_map[identity_key]
                )

        if to_delete:
            session._remove_newly_deleted(list(to_delete.values()))