summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/bulk_persistence.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/bulk_persistence.py')
-rw-r--r--lib/sqlalchemy/orm/bulk_persistence.py1459
1 files changed, 1141 insertions, 318 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
index 225292d17..3ed34a57a 100644
--- a/lib/sqlalchemy/orm/bulk_persistence.py
+++ b/lib/sqlalchemy/orm/bulk_persistence.py
@@ -15,24 +15,32 @@ specifically outside of the flush() process.
from __future__ import annotations
from typing import Any
+from typing import cast
from typing import Dict
from typing import Iterable
+from typing import Optional
+from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import attributes
+from . import context
from . import evaluator
from . import exc as orm_exc
+from . import loading
from . import persistence
from .base import NO_VALUE
from .context import AbstractORMCompileState
+from .context import FromStatement
+from .context import ORMFromStatementCompileState
+from .context import QueryContext
from .. import exc as sa_exc
-from .. import sql
from .. import util
from ..engine import Dialect
from ..engine import result as _result
from ..sql import coercions
+from ..sql import dml
from ..sql import expression
from ..sql import roles
from ..sql import select
@@ -48,16 +56,24 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from .mapper import Mapper
+ from .session import _BindArguments
from .session import ORMExecuteState
+ from .session import Session
from .session import SessionTransaction
from .state import InstanceState
+ from ..engine import Connection
+ from ..engine import cursor
+ from ..engine.interfaces import _CoreAnyExecuteParams
+ from ..engine.interfaces import _ExecuteOptionsParameter
_O = TypeVar("_O", bound=object)
-_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
+_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"]
+_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"]
+@overload
def _bulk_insert(
mapper: Mapper[_O],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
@@ -65,7 +81,36 @@ def _bulk_insert(
isstates: bool,
return_defaults: bool,
render_nulls: bool,
+ use_orm_insert_stmt: Literal[None] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
+) -> cursor.CursorResult[Any]:
+ ...
+
+
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+) -> Optional[cursor.CursorResult[Any]]:
base_mapper = mapper.base_mapper
if session_transaction.session.connection_callable:
@@ -81,13 +126,27 @@ def _bulk_insert(
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
connection = session_transaction.connection(base_mapper)
+
+ return_result: Optional[cursor.CursorResult[Any]] = None
+
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
+ is_joined_inh_supertable = super_mapper is not mapper
+ bookkeeping = (
+ is_joined_inh_supertable
+ or return_defaults
+ or (
+ use_orm_insert_stmt is not None
+ and bool(use_orm_insert_stmt._returning)
+ )
+ )
+
records = (
(
None,
@@ -112,18 +171,25 @@ def _bulk_insert(
table,
((None, mapping, mapper, connection) for mapping in mappings),
bulk=True,
- return_defaults=return_defaults,
+ return_defaults=bookkeeping,
render_nulls=render_nulls,
)
)
- persistence._emit_insert_statements(
+ result = persistence._emit_insert_statements(
base_mapper,
None,
super_mapper,
table,
records,
- bookkeeping=return_defaults,
+ bookkeeping=bookkeeping,
+ use_orm_insert_stmt=use_orm_insert_stmt,
+ execution_options=execution_options,
)
+ if use_orm_insert_stmt is not None:
+ if not use_orm_insert_stmt._returning or return_result is None:
+ return_result = result
+ elif result.returns_rows:
+ return_result = return_result.splice_horizontally(result)
if return_defaults and isstates:
identity_cls = mapper._identity_class
@@ -134,14 +200,43 @@ def _bulk_insert(
tuple([dict_[key] for key in identity_props]),
)
+ if use_orm_insert_stmt is not None:
+ assert return_result is not None
+ return return_result
+
+@overload
def _bulk_update(
mapper: Mapper[Any],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
session_transaction: SessionTransaction,
isstates: bool,
update_changed_only: bool,
+ use_orm_update_stmt: Literal[None] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = ...,
+) -> _result.Result[Any]:
+ ...
+
+
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = None,
+) -> Optional[_result.Result[Any]]:
base_mapper = mapper.base_mapper
search_keys = mapper._primary_key_propkeys
@@ -161,7 +256,8 @@ def _bulk_update(
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
if session_transaction.session.connection_callable:
raise NotImplementedError(
@@ -172,7 +268,7 @@ def _bulk_update(
connection = session_transaction.connection(base_mapper)
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
records = persistence._collect_update_commands(
@@ -193,8 +289,8 @@ def _bulk_update(
for mapping in mappings
),
bulk=True,
+ use_orm_update_stmt=use_orm_update_stmt,
)
-
persistence._emit_update_statements(
base_mapper,
None,
@@ -202,10 +298,125 @@ def _bulk_update(
table,
records,
bookkeeping=False,
+ use_orm_update_stmt=use_orm_update_stmt,
)
+ if use_orm_update_stmt is not None:
+ return _result.null_result()
+
+
+def _expand_composites(mapper, mappings):
+ composite_attrs = mapper.composites
+ if not composite_attrs:
+ return
+
+ composite_keys = set(composite_attrs.keys())
+ populators = {
+ key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
+ for key in composite_keys
+ }
+ for mapping in mappings:
+ for key in composite_keys.intersection(mapping):
+ populators[key](mapping)
+
class ORMDMLState(AbstractORMCompileState):
+ is_dml_returning = True
+ from_statement_ctx: Optional[ORMFromStatementCompileState] = None
+
+ @classmethod
+ def _get_orm_crud_kv_pairs(
+ cls, mapper, statement, kv_iterator, needs_to_be_cacheable
+ ):
+
+ core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+ for k, v in kv_iterator:
+ k = coercions.expect(roles.DMLColumnRole, k)
+
+ if isinstance(k, str):
+ desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+ if desc is NO_VALUE:
+ yield (
+ coercions.expect(roles.DMLColumnRole, k),
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ )
+ if needs_to_be_cacheable
+ else v,
+ )
+ else:
+ yield from core_get_crud_kv_pairs(
+ statement,
+ desc._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["proxy_key"]
+ )
+ yield from core_get_crud_kv_pairs(
+ statement,
+ attr._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ else:
+ yield (
+ k,
+ v
+ if not needs_to_be_cacheable
+ else coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+
+ @classmethod
+ def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_multi_crud_kv_pairs(
+ statement, kv_iterator
+ )
+
+ return [
+ dict(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper, statement, value_dict.items(), False
+ )
+ )
+ for value_dict in kv_iterator
+ ]
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
+ assert (
+ needs_to_be_cacheable
+ ), "no test coverage for needs_to_be_cacheable=False"
+
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_crud_kv_pairs(
+ statement, kv_iterator, needs_to_be_cacheable
+ )
+
+ return list(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper,
+ statement,
+ kv_iterator,
+ needs_to_be_cacheable,
+ )
+ )
+
@classmethod
def get_entity_description(cls, statement):
ext_info = statement.table._annotations["parententity"]
@@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState):
]
]
+ def _setup_orm_returning(
+ self,
+ compiler,
+ orm_level_statement,
+ dml_level_statement,
+ use_supplemental_cols=True,
+ dml_mapper=None,
+ ):
+ """establish ORM column handlers for an INSERT, UPDATE, or DELETE
+ which uses explicit returning().
+
+ called within compilation level create_for_statement.
+
+ The _return_orm_returning() method then receives the Result
+ after the statement was executed, and applies ORM loading to the
+ state that we first established here.
+
+ """
+
+ if orm_level_statement._returning:
+
+ fs = FromStatement(
+ orm_level_statement._returning, dml_level_statement
+ )
+ fs = fs.options(*orm_level_statement._with_options)
+ self.select_statement = fs
+ self.from_statement_ctx = (
+ fsc
+ ) = ORMFromStatementCompileState.create_for_statement(fs, compiler)
+ fsc.setup_dml_returning_compile_state(dml_mapper)
+
+ dml_level_statement = dml_level_statement._generate()
+ dml_level_statement._returning = ()
+
+ cols_to_return = [c for c in fsc.primary_columns if c is not None]
+
+ # since we are splicing result sets together, make sure there
+ # are columns of some kind returned in each result set
+ if not cols_to_return:
+ cols_to_return.extend(dml_mapper.primary_key)
+
+ if use_supplemental_cols:
+ dml_level_statement = dml_level_statement.return_defaults(
+ supplemental_cols=cols_to_return
+ )
+ else:
+ dml_level_statement = dml_level_statement.returning(
+ *cols_to_return
+ )
+
+ return dml_level_statement
+
+ @classmethod
+ def _return_orm_returning(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+
+ execution_context = result.context
+ compile_state = execution_context.compiled.compile_state
+
+ if compile_state.from_statement_ctx:
+ load_options = execution_options.get(
+ "_sa_orm_load_options", QueryContext.default_load_options
+ )
+ querycontext = QueryContext(
+ compile_state.from_statement_ctx,
+ compile_state.select_statement,
+ params,
+ session,
+ load_options,
+ execution_options,
+ bind_arguments,
+ )
+ return loading.instances(result, querycontext)
+ else:
+ return result
+
class BulkUDCompileState(ORMDMLState):
class default_update_options(Options):
- _synchronize_session: _SynchronizeSessionArgument = "evaluate"
- _is_delete_using = False
- _is_update_from = False
- _autoflush = True
- _subject_mapper = None
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _synchronize_session: _SynchronizeSessionArgument = "auto"
+ _can_use_returning: bool = False
+ _is_delete_using: bool = False
+ _is_update_from: bool = False
+ _autoflush: bool = True
+ _subject_mapper: Optional[Mapper[Any]] = None
_resolved_values = EMPTY_DICT
- _resolved_keys_as_propnames = EMPTY_DICT
- _value_evaluators = EMPTY_DICT
- _matched_objects = None
+ _eval_condition = None
_matched_rows = None
_refresh_identity_token = None
@@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState):
execution_options,
) = BulkUDCompileState.default_update_options.from_execution_options(
"_sa_orm_update_options",
- {"synchronize_session", "is_delete_using", "is_update_from"},
+ {
+ "synchronize_session",
+ "is_delete_using",
+ "is_update_from",
+ "dml_strategy",
+ },
execution_options,
statement._execution_options,
)
- sync = update_options._synchronize_session
- if sync is not None:
- if sync not in ("evaluate", "fetch", False):
- raise sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are 'evaluate', 'fetch', False"
- )
-
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState):
update_options += {"_subject_mapper": plugin_subject.mapper}
+ if not isinstance(params, list):
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "orm"}
+ elif update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+                    'Can\'t use "bulk" ORM update/delete strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "bulk"}
+ elif update_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+                    'Can\'t use "orm" ORM update/delete strategy with a '
+ "separate parameter list"
+ )
+
+ sync = update_options._synchronize_session
+ if sync is not None:
+ if sync not in ("auto", "evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'auto', 'evaluate', 'fetch', False"
+ )
+ if update_options._dml_strategy == "bulk" and sync == "fetch":
+ raise sa_exc.InvalidRequestError(
+ "The 'fetch' synchronization strategy is not available "
+ "for 'bulk' ORM updates (i.e. multiple parameter sets)"
+ )
+
if update_options._autoflush:
session._autoflush()
+ if update_options._dml_strategy == "orm":
+
+ if update_options._synchronize_session == "auto":
+ update_options = cls._do_pre_synchronize_auto(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "auto":
+ update_options += {"_synchronize_session": "evaluate"}
+
+ # indicators from the "pre exec" step that are then
+ # added to the DML statement, which will also be part of the cache
+ # key. The compile level create_for_statement() method will then
+ # consume these at compiler time.
statement = statement._annotate(
{
"synchronize_session": update_options._synchronize_session,
"is_delete_using": update_options._is_delete_using,
"is_update_from": update_options._is_update_from,
+ "dml_strategy": update_options._dml_strategy,
+ "can_use_returning": update_options._can_use_returning,
}
)
- # this stage of the execution is called before the do_orm_execute event
- # hook. meaning for an extension like horizontal sharding, this step
- # happens before the extension splits out into multiple backends and
- # runs only once. if we do pre_sync_fetch, we execute a SELECT
- # statement, which the horizontal sharding extension splits amongst the
- # shards and combines the results together.
-
- if update_options._synchronize_session == "evaluate":
- update_options = cls._do_pre_synchronize_evaluate(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._synchronize_session == "fetch":
- update_options = cls._do_pre_synchronize_fetch(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
-
return (
statement,
util.immutabledict(execution_options).union(
@@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState):
# individual ones we return here.
update_options = execution_options["_sa_orm_update_options"]
- if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_evaluate(session, result, update_options)
- elif update_options._synchronize_session == "fetch":
- cls._do_post_synchronize_fetch(session, result, update_options)
+ if update_options._dml_strategy == "orm":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(
+ session, statement, result, update_options
+ )
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(
+ session, statement, result, update_options
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_bulk_evaluate(
+ session, params, result, update_options
+ )
+ return result
- return result
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
@classmethod
def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
@@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState):
primary_key_convert = [
lookup[bpk] for bpk in mapper.base_mapper.primary_key
]
-
return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
@classmethod
- def _do_pre_synchronize_evaluate(
+ def _get_matched_objects_on_criteria(cls, update_options, states):
+ mapper = update_options._subject_mapper
+ eval_condition = update_options._eval_condition
+
+ raw_data = [
+ (state.obj(), state, state.dict)
+ for state in states
+ if state.mapper.isa(mapper) and not state.expired
+ ]
+
+ identity_token = update_options._refresh_identity_token
+ if identity_token is not None:
+ raw_data = [
+ (obj, state, dict_)
+ for obj, state, dict_ in raw_data
+ if state.identity_token == identity_token
+ ]
+
+ result = []
+ for obj, state, dict_ in raw_data:
+ evaled_condition = eval_condition(obj)
+
+            # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
+ # evaluates as True for all comparisons
+ if (
+ evaled_condition is True
+ or evaled_condition is evaluator._EXPIRED_OBJECT
+ ):
+ result.append(
+ (
+ obj,
+ state,
+ dict_,
+ evaled_condition is evaluator._EXPIRED_OBJECT,
+ )
+ )
+ return result
+
+ @classmethod
+ def _eval_condition_from_statement(cls, update_options, statement):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
+ if statement._where_criteria:
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
+ else:
+
+ def eval_condition(obj):
+ return True
+
+ return eval_condition
+
+ @classmethod
+ def _do_pre_synchronize_auto(
cls,
session,
statement,
@@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState):
bind_arguments,
update_options,
):
- mapper = update_options._subject_mapper
- target_cls = mapper.class_
+ """setup auto sync strategy
+
+
+ "auto" checks if we can use "evaluate" first, then falls back
+ to "fetch"
+
+ evaluate is vastly more efficient for the common case
+ where session is empty, only has a few objects, and the UPDATE
+ statement can potentially match thousands/millions of rows.
- value_evaluators = resolved_keys_as_propnames = EMPTY_DICT
+        OTOH, more complex criteria that fail to work with "evaluate"
+        will hopefully correlate with fewer net rows.
+
+ """
try:
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- crit = ()
- if statement._where_criteria:
- crit += statement._where_criteria
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
- global_attributes = {}
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(global_attributes)
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ return update_options + {
+ "_eval_condition": eval_condition,
+ "_synchronize_session": "evaluate",
+ }
- if global_attributes:
- crit += cls._adjust_for_extra_criteria(
- global_attributes, mapper
- )
+ update_options += {"_synchronize_session": "fetch"}
+ return cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
- if crit:
- eval_condition = evaluator_compiler.process(*crit)
- else:
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
- def eval_condition(obj):
- return True
+ try:
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
except evaluator.UnevaluatableError as err:
raise sa_exc.InvalidRequestError(
@@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState):
"synchronize_session execution option." % err
) from err
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- value_evaluators = {}
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- # TODO: detect when the where clause is a trivial primary key match.
- matched_objects = [
- state.obj()
- for state in session.identity_map.all_states()
- if state.mapper.isa(mapper)
- and not state.expired
- and eval_condition(state.obj())
- and (
- update_options._refresh_identity_token is None
- # TODO: coverage for the case where horizontal sharding
- # invokes an update() or delete() given an explicit identity
- # token up front
- or state.identity_token
- == update_options._refresh_identity_token
- )
- ]
return update_options + {
- "_matched_objects": matched_objects,
- "_value_evaluators": value_evaluators,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_eval_condition": eval_condition,
}
@classmethod
@@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState):
def _resolved_keys_as_propnames(cls, mapper, resolved_values):
values = []
for k, v in resolved_values:
- if isinstance(k, attributes.QueryableAttribute):
- values.append((k.key, v))
- continue
- elif hasattr(k, "__clause_element__"):
- k = k.__clause_element__()
-
if mapper and isinstance(k, expression.ColumnElement):
try:
attr = mapper._columntoproperty[k]
@@ -599,7 +992,8 @@ class BulkUDCompileState(ORMDMLState):
values.append((attr.key, v))
else:
raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k
+ "Attribute name not found, can't be "
+ "synchronized back to objects: %r" % k
)
return values
@@ -622,14 +1016,43 @@ class BulkUDCompileState(ORMDMLState):
)
select_stmt._where_criteria = statement._where_criteria
+ # conditionally run the SELECT statement for pre-fetch, testing the
+ # "bind" for if we can use RETURNING or not using the do_orm_execute
+ # event. If RETURNING is available, the do_orm_execute event
+ # will cancel the SELECT from being actually run.
+ #
+ # The way this is organized seems strange, why don't we just
+        # call can_use_returning() before invoking the statement and get
+        # the answer? Why does this go through the whole execute phase using an
+ # event? Answer: because we are integrating with extensions such
+        # as the horizontal sharding extension that "multiplexes" an individual
+ # statement run through multiple engines, and it uses
+ # do_orm_execute() to do that.
+
+ can_use_returning = None
+
def skip_for_returning(orm_context: ORMExecuteState) -> Any:
bind = orm_context.session.get_bind(**orm_context.bind_arguments)
- if cls.can_use_returning(
+ nonlocal can_use_returning
+
+ per_bind_result = cls.can_use_returning(
bind.dialect,
mapper,
is_update_from=update_options._is_update_from,
is_delete_using=update_options._is_delete_using,
- ):
+ )
+
+ if can_use_returning is not None:
+ if can_use_returning != per_bind_result:
+ raise sa_exc.InvalidRequestError(
+ "For synchronize_session='fetch', can't mix multiple "
+ "backends where some support RETURNING and others "
+ "don't"
+ )
+ else:
+ can_use_returning = per_bind_result
+
+ if per_bind_result:
return _result.null_result()
else:
return None
@@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState):
)
matched_rows = result.fetchall()
- value_evaluators = EMPTY_DICT
-
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- target_cls = mapper.class_
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
-
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- value_evaluators = {}
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- else:
- resolved_keys_as_propnames = EMPTY_DICT
-
return update_options + {
- "_value_evaluators": value_evaluators,
"_matched_rows": matched_rows,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_can_use_returning": can_use_returning,
}
@CompileState.plugin_for("orm", "insert")
-class ORMInsert(ORMDMLState, InsertDMLState):
+class BulkORMInsert(ORMDMLState, InsertDMLState):
+ class default_insert_options(Options):
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _render_nulls: bool = False
+ _return_defaults: bool = False
+ _subject_mapper: Optional[Mapper[Any]] = None
+
+ select_statement: Optional[FromStatement] = None
+
@classmethod
def orm_pre_session_exec(
cls,
@@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState):
bind_arguments,
is_reentrant_invoke,
):
+
+ (
+ insert_options,
+ execution_options,
+ ) = BulkORMInsert.default_insert_options.from_execution_options(
+ "_sa_orm_insert_options",
+ {"dml_strategy"},
+ execution_options,
+ statement._execution_options,
+ )
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState):
else:
bind_arguments["mapper"] = plugin_subject.mapper
+ insert_options += {"_subject_mapper": plugin_subject.mapper}
+
+ if not params:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "orm"}
+ elif insert_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "bulk" ORM insert strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "bulk"}
+ elif insert_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "orm" ORM insert strategy with a '
+ "separate parameter list"
+ )
+
+ if insert_options._dml_strategy != "raw":
+ # for ORM object loading, like ORMContext, we have to disable
+ # result set adapt_to_context, because we will be generating a
+ # new statement with specific columns that's cached inside of
+ # an ORMFromStatementCompileState, which we will re-use for
+ # each result.
+ if not execution_options:
+ execution_options = context._orm_load_exec_options
+ else:
+ execution_options = execution_options.union(
+ context._orm_load_exec_options
+ )
+
+ statement = statement._annotate(
+ {"dml_strategy": insert_options._dml_strategy}
+ )
+
return (
statement,
- util.immutabledict(execution_options),
+ util.immutabledict(execution_options).union(
+ {"_sa_orm_insert_options": insert_options}
+ ),
)
@classmethod
- def orm_setup_cursor_result(
+ def orm_execute_statement(
cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- ):
- return result
+ session: Session,
+ statement: dml.Insert,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ insert_options = execution_options.get(
+ "_sa_orm_insert_options", cls.default_insert_options
+ )
+
+ if insert_options._dml_strategy not in (
+ "raw",
+ "bulk",
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM insert strategy "
+                "are 'raw', 'orm', 'bulk', 'auto'"
+ )
+
+ result: _result.Result[Any]
+
+ if insert_options._dml_strategy == "raw":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return result
+
+ if insert_options._dml_strategy == "bulk":
+ mapper = insert_options._subject_mapper
+
+ if (
+ statement._post_values_clause is not None
+ and mapper._multiple_persistence_tables
+ ):
+ raise sa_exc.InvalidRequestError(
+ "bulk INSERT with a 'post values' clause "
+ "(typically upsert) not supported for multi-table "
+ f"mapper {mapper}"
+ )
+
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_insert(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ return_defaults=insert_options._return_defaults,
+ render_nulls=insert_options._render_nulls,
+ use_orm_insert_stmt=statement,
+ execution_options=execution_options,
+ )
+ elif insert_options._dml_strategy == "orm":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ else:
+ raise AssertionError()
+
+ if not bool(statement._returning):
+ return result
+
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
+
+ self = cast(
+ BulkORMInsert,
+ super().create_for_statement(statement, compiler, **kw),
+ )
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ else:
+ toplevel = True
+ if not toplevel:
+ return self
+
+ mapper = statement._propagate_attrs["plugin_subject"]
+ dml_strategy = statement._annotations.get("dml_strategy", "raw")
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_insert(compiler)
+ elif dml_strategy == "orm":
+ self._setup_for_orm_insert(compiler, mapper)
+
+ return self
+
+ @classmethod
+ def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
+ return {
+ col.key if col is not None else k: v
+ for col, k, v in (
+ (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
+ )
+ }
+
+ def _setup_for_orm_insert(self, compiler, mapper):
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=False,
+ )
+ self.statement = statement
+
+ def _setup_for_bulk_insert(self, compiler):
+ """establish an INSERT statement within the context of
+ bulk insert.
+
+ This method will be within the "conn.execute()" call that is invoked
+ by persistence._emit_insert_statement().
+
+ """
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+ an = statement._annotations
+
+ emit_insert_table, emit_insert_mapper = (
+ an["_emit_insert_table"],
+ an["_emit_insert_mapper"],
+ )
+
+ statement = statement._clone()
+
+ statement.table = emit_insert_table
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_insert_table
+ }
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=True,
+ dml_mapper=emit_insert_mapper,
+ )
+
+ self.statement = statement
@CompileState.plugin_for("orm", "update")
@@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
self = cls.__new__(cls)
+ dml_strategy = statement._annotations.get(
+ "dml_strategy", "unspecified"
+ )
+
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_update(statement, compiler)
+ elif dml_strategy in ("orm", "unspecified"):
+ self._setup_for_orm_update(statement, compiler)
+
+ return self
+
+ def _setup_for_orm_update(self, statement, compiler, **kw):
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
self.extra_criteria_entities = {}
- self._resolved_values = cls._get_resolved_values(mapper, statement)
+ self._resolved_values = self._get_resolved_values(mapper, statement)
extra_criteria_attributes = {}
@@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
if statement._values:
self._resolved_values = dict(self._resolved_values)
- new_stmt = sql.Update.__new__(sql.Update)
- new_stmt.__dict__.update(statement.__dict__)
+ new_stmt = statement._clone()
new_stmt.table = mapper.local_table
# note if the statement has _multi_values, these
@@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
elif statement._values:
new_stmt._values = self._resolved_values
- new_crit = cls._adjust_for_extra_criteria(
+ new_crit = self._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
@@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
- if compiler._annotations.get(
+ use_supplemental_cols = False
+
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect, mapper, is_multitable=self.is_multitable
- ):
- if new_stmt._returning:
- raise sa_exc.InvalidRequestError(
- "Can't use synchronize_session='fetch' "
- "with explicit returning()"
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally, because the multi-table aspect of the statement
+ # was not known at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect, mapper, is_multitable=self.is_multitable
)
- self.statement = self.statement.returning(
- *mapper.local_table.primary_key
)
- return self
+ if synchronize_session == "fetch" and can_use_returning:
+ use_supplemental_cols = True
+
+ # NOTE: we might want to RETURNING the actual columns to be
+ # synchronized also. however this is complicated and difficult
+ # to align against the behavior of "evaluate". Additionally,
+ # in a large number (if not the majority) of cases, we have the
+ # "evaluate" answer, usually a fixed value, in memory already and
+ # there's no need to re-fetch the same value
+ # over and over again. so perhaps if it could be RETURNING just
+ # the elements that were based on a SQL expression and not
+ # a constant. For now it doesn't quite seem worth it
+ new_stmt = new_stmt.return_defaults(
+ *(list(mapper.local_table.primary_key))
+ )
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
+
+ def _setup_for_bulk_update(self, statement, compiler, **kw):
+ """establish an UPDATE statement within the context of
+ bulk update.
+
+ This method will be called within the "conn.execute()" call that is invoked
+ by persistence._emit_update_statement().
+
+ """
+ statement = cast(dml.Update, statement)
+ an = statement._annotations
+
+ emit_update_table, _ = (
+ an["_emit_update_table"],
+ an["_emit_update_mapper"],
+ )
+
+ statement = statement._clone()
+ statement.table = emit_update_table
+
+ UpdateDMLState.__init__(self, statement, compiler, **kw)
+
+ if self._ordered_values:
+ raise sa_exc.InvalidRequestError(
+ "bulk ORM UPDATE does not support ordered_values() for "
+ "custom UPDATE statements with bulk parameter sets. Use a "
+ "non-bulk UPDATE statement or use values()."
+ )
+
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_update_table
+ }
+ self.statement = statement
+
+ @classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Update,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy not in ("orm", "auto", "bulk"):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM UPDATE strategy "
+ "are 'orm', 'auto', 'bulk'"
+ )
+
+ result: _result.Result[Any]
+
+ if update_options._dml_strategy == "bulk":
+ if statement._where_criteria:
+ raise sa_exc.InvalidRequestError(
+ "WHERE clause with bulk ORM UPDATE not "
+ "supported right now. Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+ mapper = update_options._subject_mapper
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_update(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ update_changed_only=False,
+ use_orm_update_stmt=statement,
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+ else:
+ return super().orm_execute_statement(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ )
@classmethod
def can_use_returning(
@@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
return True
@classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator):
- plugin_subject = statement._propagate_attrs["plugin_subject"]
-
- core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
-
- if not plugin_subject or not plugin_subject.mapper:
- return core_get_crud_kv_pairs(statement, kv_iterator)
-
- mapper = plugin_subject.mapper
-
- values = []
-
- for k, v in kv_iterator:
- k = coercions.expect(roles.DMLColumnRole, k)
+ def _do_post_synchronize_bulk_evaluate(
+ cls, session, params, result, update_options
+ ):
+ if not params:
+ return
- if isinstance(k, str):
- desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
- if desc is NO_VALUE:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- else:
- values.extend(
- core_get_crud_kv_pairs(
- statement, desc._bulk_update_tuples(v)
- )
- )
- elif "entity_namespace" in k._annotations:
- k_anno = k._annotations
- attr = _entity_namespace_key(
- k_anno["entity_namespace"], k_anno["proxy_key"]
- )
- values.extend(
- core_get_crud_kv_pairs(
- statement, attr._bulk_update_tuples(v)
- )
- )
- else:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- return values
+ mapper = update_options._subject_mapper
+ pk_keys = [prop.key for prop in mapper._identity_key_props]
- @classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
+ identity_map = session.identity_map
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
- for obj in update_options._matched_objects:
-
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
+ for param in params:
+ identity_key = mapper.identity_key_from_primary_key(
+ (param[key] for key in pk_keys),
+ update_options._refresh_identity_token,
)
-
- # the evaluated states were gathered across all identity tokens.
- # however the post_sync events are called per identity token,
- # so filter.
- if (
- update_options._refresh_identity_token is not None
- and state.identity_token
- != update_options._refresh_identity_token
- ):
+ state = identity_map.fast_get_state(identity_key)
+ if not state:
continue
+ evaluated_keys = set(param).difference(pk_keys)
+
+ dict_ = state.dict
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ dict_[key] = param[key]
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush was set to False.
+ to_expire = evaluated_keys.intersection(dict_).difference(
+ to_evaluate
+ )
if to_expire:
state._expire_attributes(dict_, to_expire)
- states.add(state)
- session._register_altered(states)
+ @classmethod
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
+ )
+
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
+ )
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
-
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows
@@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
if identity_key in session.identity_map
]
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
+ if not objs:
+ return
- for obj in objs:
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [
+ (
+ obj,
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+ for obj in objs
+ ],
+ )
+
+ @classmethod
+ def _apply_update_set_values_to_objects(
+ cls, session, update_options, statement, matched_objects
+ ):
+ """apply values to objects derived from an update statement, e.g.
+ UPDATE..SET <values>
+
+ """
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ value_evaluators = {}
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ evaluated_keys = list(value_evaluators.keys())
+ attrib = set(k for k, v in resolved_keys_as_propnames)
+
+ states = set()
+ for obj, state, dict_ in matched_objects:
to_evaluate = state.unmodified.intersection(evaluated_keys)
+
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ # only run eval for attributes that are present.
+ dict_[key] = value_evaluators[key](obj)
+
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush was set to False.
to_expire = attrib.intersection(dict_).difference(to_evaluate)
if to_expire:
state._expire_attributes(dict_, to_expire)
@@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
@@ -1002,31 +1743,97 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
if opt._is_criteria_option:
opt.get_global_criteria(extra_criteria_attributes)
+ new_stmt = statement._clone()
+ new_stmt.table = mapper.local_table
+
new_crit = cls._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
- statement = statement.where(*new_crit)
+ new_stmt = new_stmt.where(*new_crit)
# do this first as we need to determine if there is
# DELETE..FROM
- DeleteDMLState.__init__(self, statement, compiler, **kw)
+ DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
+
+ use_supplemental_cols = False
- if compiler._annotations.get(
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect,
- mapper,
- is_multitable=self.is_multitable,
- is_delete_using=compiler._annotations.get(
- "is_delete_using", False
- ),
- ):
- self.statement = statement.returning(*statement.table.primary_key)
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally because is_delete_using was not known
+ # at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect,
+ mapper,
+ is_multitable=self.is_multitable,
+ is_delete_using=compiler._annotations.get(
+ "is_delete_using", False
+ ),
+ )
+ )
+
+ if can_use_returning:
+ use_supplemental_cols = True
+
+ new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
return self
@classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Delete,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ "Bulk ORM DELETE not supported right now. "
+ "Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+
+ if update_options._dml_strategy not in (
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
+ )
+
+ return super().orm_execute_statement(
+ session, statement, params, execution_options, bind_arguments, conn
+ )
+
+ @classmethod
def can_use_returning(
cls,
dialect: Dialect,
@@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
return True
@classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
- session._remove_newly_deleted(
- [
- attributes.instance_state(obj)
- for obj in update_options._matched_objects
- ]
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
)
+ to_delete = []
+
+ for _, state, dict_, is_partially_expired in matched_objects:
+ if is_partially_expired:
+ state._expire(dict_, session.identity_map._modified)
+ else:
+ to_delete.append(state)
+
+ if to_delete:
+ session._remove_newly_deleted(to_delete)
+
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows