summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-08-30 18:13:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-08-30 19:45:04 -0400
commit575b6dded9a25fca693f0aa7f6d7c6e735490460 (patch)
tree4bc0c76ee49bdac200abe0ec73ade88564b727c0 /lib
parent406034d41a764f6fe24374d40c95e79d295f6e80 (diff)
downloadsqlalchemy-575b6dded9a25fca693f0aa7f6d7c6e735490460.tar.gz
Support extra / single inh criteria with ORM update/delete
The ORM bulk update and delete operations, historically available via the :meth:`_orm.Query.update` and :meth:`_orm.Query.delete` methods as well as via the :class:`_dml.Update` and :class:`_dml.Delete` constructs for :term:`2.0 style` execution, will now automatically accommodate for the additional WHERE criteria needed for a single-table inheritance discrminiator. Joined-table inheritance is still not directly supported. The new :func:`_orm.with_loader_criteria` construct is also supported for all mappings with bulk update/delete. Fixes: #5018 Fixes: #3903 Change-Id: Id90827cc7e2bc713d1255127f908c8e133de9295
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/interfaces.py23
-rw-r--r--lib/sqlalchemy/orm/persistence.py110
-rw-r--r--lib/sqlalchemy/orm/util.py9
-rw-r--r--lib/sqlalchemy/sql/dml.py23
4 files changed, 146 insertions, 19 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 068c85073..b1ff1a049 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -729,6 +729,8 @@ class ORMOption(ExecutableOption):
_is_compile_state = False
+ _is_criteria_option = False
+
class LoaderOption(ORMOption):
"""Describe a loader modification to an ORM statement at compilation time.
@@ -743,6 +745,27 @@ class LoaderOption(ORMOption):
"""Apply a modification to a given :class:`.CompileState`."""
+class CriteriaOption(ORMOption):
+ """Describe a WHERE criteria modification to an ORM statement at
+ compilation time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _is_compile_state = True
+ _is_criteria_option = True
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ def get_global_criteria(self, attributes):
+ """update additional entity criteria options in the given
+ attributes dictionary.
+
+ """
+
+
class UserDefinedOption(ORMOption):
"""Base class for a user-defined option that can be consumed from the
:meth:`.SessionEvents.do_orm_execute` event hook.
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 49b29a6bc..d05381c1d 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -1857,6 +1857,43 @@ class BulkUDCompileState(CompileState):
return result
@classmethod
+ def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
+ """Apply extra criteria filtering.
+
+ For all distinct single-table-inheritance mappers represented in the
+ table being updated or deleted, produce additional WHERE criteria such
+ that only the appropriate subtypes are selected from the total results.
+
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ collected from the statement.
+
+ """
+
+ return_crit = ()
+
+ adapter = ext_info._adapter if ext_info.is_aliased_class else None
+
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in global_attributes:
+ return_crit += tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if ae.include_aliases or ae.entity is ext_info
+ )
+
+ if ext_info.mapper._single_table_criterion is not None:
+ return_crit += (ext_info.mapper._single_table_criterion,)
+
+ if adapter:
+ return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
+
+ return return_crit
+
+ @classmethod
def _do_pre_synchronize_evaluate(
cls,
session,
@@ -1873,10 +1910,22 @@ class BulkUDCompileState(CompileState):
try:
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
if statement._where_criteria:
- eval_condition = evaluator_compiler.process(
- *statement._where_criteria
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(
+ global_attributes, mapper
)
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
else:
def eval_condition(obj):
@@ -1920,16 +1969,17 @@ class BulkUDCompileState(CompileState):
# TODO: detect when the where clause is a trivial primary key match.
matched_objects = [
- obj
- for (cls, pk, identity_token,), obj in session.identity_map.items()
- if issubclass(cls, target_cls)
- and eval_condition(obj)
+ state.obj()
+ for state in session.identity_map.all_states()
+ if state.mapper.isa(mapper)
+ and eval_condition(state.obj())
and (
update_options._refresh_identity_token is None
# TODO: coverage for the case where horiziontal sharding
# invokes an update() or delete() given an explicit identity
# token up front
- or identity_token == update_options._refresh_identity_token
+ or state.identity_token
+ == update_options._refresh_identity_token
)
]
return update_options + {
@@ -2003,8 +2053,10 @@ class BulkUDCompileState(CompileState):
):
mapper = update_options._subject_mapper
- select_stmt = select(
- *(mapper.primary_key + (mapper.select_identity_token,))
+ select_stmt = (
+ select(*(mapper.primary_key + (mapper.select_identity_token,)))
+ .select_from(mapper)
+ .options(*statement._with_options)
)
select_stmt._where_criteria = statement._where_criteria
@@ -2075,12 +2127,20 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
self = cls.__new__(cls)
- self.mapper = mapper = statement.table._annotations.get(
- "parentmapper", None
- )
+ ext_info = statement.table._annotations["parententity"]
+
+ self.mapper = mapper = ext_info.mapper
+
+ self.extra_criteria_entities = {}
self._resolved_values = cls._get_resolved_values(mapper, statement)
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
if not statement._preserve_parameter_order and statement._values:
self._resolved_values = dict(self._resolved_values)
@@ -2097,6 +2157,12 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
elif statement._values:
new_stmt._values = self._resolved_values
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ new_stmt = new_stmt.where(*new_crit)
+
# if we are against a lambda statement we might not be the
# topmost object that received per-execute annotations
top_level_stmt = compiler.statement
@@ -2211,11 +2277,25 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
- self.mapper = mapper = statement.table._annotations.get(
- "parentmapper", None
- )
+ ext_info = statement.table._annotations["parententity"]
+ self.mapper = mapper = ext_info.mapper
top_level_stmt = compiler.statement
+
+ self.extra_criteria_entities = {}
+
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ statement = statement.where(*new_crit)
+
if (
mapper
and top_level_stmt._annotations.get("synchronize_session", None)
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 82fad0815..271a441f0 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -23,7 +23,7 @@ from .base import object_state # noqa
from .base import state_attribute_str # noqa
from .base import state_class_str # noqa
from .base import state_str # noqa
-from .interfaces import LoaderOption
+from .interfaces import CriteriaOption
from .interfaces import MapperProperty # noqa
from .interfaces import ORMColumnsClauseRole
from .interfaces import ORMEntityColumnsClauseRole
@@ -856,7 +856,7 @@ class AliasedInsp(
return "aliased(%s)" % (self._target.__name__,)
-class LoaderCriteriaOption(LoaderOption):
+class LoaderCriteriaOption(CriteriaOption):
"""Add additional WHERE criteria to the load for all occurrences of
a particular entity.
@@ -1026,8 +1026,11 @@ class LoaderCriteriaOption(LoaderOption):
# if options to limit the criteria to immediate query only,
# use compile_state.attributes instead
+ self.get_global_criteria(compile_state.global_attributes)
+
+ def get_global_criteria(self, attributes):
for mp in self._all_mappers():
- load_criteria = compile_state.global_attributes.setdefault(
+ load_criteria = attributes.setdefault(
("additional_entity_criteria", mp), []
)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index a9bccaeff..b7151ac7b 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -12,6 +12,7 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
from sqlalchemy.types import NullType
from . import coercions
from . import roles
+from .base import _entity_namespace_key
from .base import _from_objects
from .base import _generative
from .base import ColumnCollection
@@ -983,10 +984,30 @@ class DMLWhereBase(object):
)
def filter(self, *criteria):
- """A synonym for the :meth:`_dml.DMLWhereBase.where` method."""
+ """A synonym for the :meth:`_dml.DMLWhereBase.where` method.
+
+ .. versionadded:: 1.4
+
+ """
return self.where(*criteria)
+ def _filter_by_zero(self):
+ return self.table
+
+ def filter_by(self, **kwargs):
+ r"""apply the given filtering criterion as a WHERE clause
+ to this select.
+
+ """
+ from_entity = self._filter_by_zero()
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
@property
def whereclause(self):
"""Return the completed WHERE clause for this :class:`.DMLWhereBase`