summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-06-03 17:38:35 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-06-06 13:31:54 -0400
commit3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6 (patch)
treef3dc26609070c1a357a366592c791a3ec0655483 /lib/sqlalchemy
parent14bc09203a8b5b2bc001f764ad7cce6a184975cc (diff)
downloadsqlalchemy-3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6.tar.gz
Convert bulk update/delete to new execution model
This reorganizes the BulkUD model in sqlalchemy.orm.persistence to be based on the CompileState concept and to allow plain update() / delete() to be passed to session.execute() where the ORM synchronize session logic will take place. Also gets "synchronize_session='fetch'" working with horizontal sharding. Adding a few more result.scalar_one() types of methods as scalar_one() seems like what is normally desired. Fixes: #5160 Change-Id: I8001ebdad089da34119eb459709731ba6c0ba975
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/cursor.py9
-rw-r--r--lib/sqlalchemy/engine/result.py85
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py95
-rw-r--r--lib/sqlalchemy/ext/hybrid.py5
-rw-r--r--lib/sqlalchemy/orm/context.py4
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py3
-rw-r--r--lib/sqlalchemy/orm/events.py14
-rw-r--r--lib/sqlalchemy/orm/mapper.py30
-rw-r--r--lib/sqlalchemy/orm/persistence.py498
-rw-r--r--lib/sqlalchemy/orm/query.py115
-rw-r--r--lib/sqlalchemy/orm/session.py114
-rw-r--r--lib/sqlalchemy/sql/base.py10
-rw-r--r--lib/sqlalchemy/sql/coercions.py10
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/sql/dml.py32
-rw-r--r--lib/sqlalchemy/sql/roles.py5
-rw-r--r--lib/sqlalchemy/sql/selectable.py6
-rw-r--r--lib/sqlalchemy/sql/traversals.py58
-rw-r--r--lib/sqlalchemy/testing/assertions.py4
-rw-r--r--lib/sqlalchemy/util/__init__.py1
-rw-r--r--lib/sqlalchemy/util/compat.py1
21 files changed, 667 insertions, 436 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index 1d832e4af..d03d79df7 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -1630,6 +1630,15 @@ class CursorResult(BaseCursorResult, Result):
def _raw_row_iterator(self):
return self._fetchiter_impl()
+ def merge(self, *others):
+ merged_result = super(CursorResult, self).merge(*others)
+ setup_rowcounts = not self._metadata.returns_rows
+ if setup_rowcounts:
+ merged_result.rowcount = sum(
+ result.rowcount for result in (self,) + others
+ )
+ return merged_result
+
def close(self):
"""Close this :class:`_engine.CursorResult`.
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 600229037..b29bc22d4 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -951,7 +951,7 @@ class Result(InPlaceGenerative):
"""
return self._allrows()
- def _only_one_row(self, raise_for_second_row, raise_for_none):
+ def _only_one_row(self, raise_for_second_row, raise_for_none, scalar):
onerow = self._fetchone_impl
row = onerow(hard_close=True)
@@ -1010,27 +1010,43 @@ class Result(InPlaceGenerative):
# if we checked for second row then that would have
# closed us :)
self._soft_close(hard=True)
- post_creational_filter = self._post_creational_filter
- if post_creational_filter:
- row = post_creational_filter(row)
- return row
+ if not scalar:
+ post_creational_filter = self._post_creational_filter
+ if post_creational_filter:
+ row = post_creational_filter(row)
+
+ if scalar and row:
+ return row[0]
+ else:
+ return row
def first(self):
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.first`.
+
.. comment: A warning is emitted if additional rows remain.
:return: a :class:`.Row` object if no filters are applied, or None
if no rows remain.
When filters are applied, such as :meth:`_engine.Result.mappings`
- or :meth:`._engine.Result.scalar`, different kinds of objects
+ or :meth:`._engine.Result.scalars`, different kinds of objects
may be returned.
+ .. seealso::
+
+ :meth:`_result.Result.scalar`
+
+ :meth:`_result.Result.one`
+
"""
- return self._only_one_row(False, False)
+ return self._only_one_row(False, False, False)
def one_or_none(self):
"""Return at most one result or raise an exception.
@@ -1055,15 +1071,50 @@ class Result(InPlaceGenerative):
:meth:`_result.Result.one`
"""
- return self._only_one_row(True, False)
+ return self._only_one_row(True, False, False)
+
+ def scalar_one(self):
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(True, True, True)
+
+ def scalar_one_or_none(self):
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(True, False, True)
def one(self):
- """Return exactly one result or raise an exception.
+ """Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
rows, or :class:`.MultipleResultsFound` if multiple rows
would be returned.
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar_one` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.one`.
+
.. versionadded:: 1.4
:return: The first :class:`.Row`.
@@ -1079,24 +1130,26 @@ class Result(InPlaceGenerative):
:meth:`_result.Result.one_or_none`
+ :meth:`_result.Result.scalar_one`
+
"""
- return self._only_one_row(True, True)
+ return self._only_one_row(True, True, False)
def scalar(self):
"""Fetch the first column of the first row, and close the result set.
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
After calling this method, the object is fully closed,
e.g. the :meth:`_engine.CursorResult.close`
method will have been called.
- :return: a Python scalar value , or None if no rows remain
+ :return: a Python scalar value, or None if no rows remain.
"""
- row = self.first()
- if row is not None:
- return row[0]
- else:
- return None
+ return self._only_one_row(False, False, True)
class FrozenResult(object):
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index c3ac71c10..0983807cb 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -50,58 +50,6 @@ class ShardedQuery(Query):
"""
return self.execution_options(_sa_shard_id=shard_id)
- def _execute_crud(self, stmt, mapper):
- def exec_for_shard(shard_id):
- conn = self.session.connection(
- mapper=mapper,
- shard_id=shard_id,
- clause=stmt,
- close_with_result=True,
- )
- result = conn._execute_20(
- stmt, self.load_options._params, self._execution_options
- )
- return result
-
- if self._shard_id is not None:
- return exec_for_shard(self._shard_id)
- else:
- rowcount = 0
- results = []
- # TODO: this will have to be the new object
- for shard_id in self.execute_chooser(self):
- result = exec_for_shard(shard_id)
- rowcount += result.rowcount
- results.append(result)
-
- return ShardedResult(results, rowcount)
-
-
-class ShardedResult(object):
- """A value object that represents multiple :class:`_engine.CursorResult`
- objects.
-
- This is used by the :meth:`.ShardedQuery._execute_crud` hook to return
- an object that takes the place of the single :class:`_engine.CursorResult`.
-
- Attribute include ``result_proxies``, which is a sequence of the
- actual :class:`_engine.CursorResult` objects,
- as well as ``aggregate_rowcount``
- or ``rowcount``, which is the sum of all the individual rowcount values.
-
- .. versionadded:: 1.3
- """
-
- __slots__ = ("result_proxies", "aggregate_rowcount")
-
- def __init__(self, result_proxies, aggregate_rowcount):
- self.result_proxies = result_proxies
- self.aggregate_rowcount = aggregate_rowcount
-
- @property
- def rowcount(self):
- return self.aggregate_rowcount
-
class ShardedSession(Session):
def __init__(
@@ -259,37 +207,40 @@ class ShardedSession(Session):
def execute_and_instances(orm_context):
- if orm_context.bind_arguments.get("_horizontal_shard", False):
- return None
-
params = orm_context.parameters
- load_options = orm_context.load_options
+ if orm_context.is_select:
+ load_options = active_options = orm_context.load_options
+ update_options = None
+ if params is None:
+ params = active_options._params
+
+ else:
+ load_options = None
+ update_options = active_options = orm_context.update_delete_options
+
session = orm_context.session
# orm_query = orm_context.orm_query
- if params is None:
- params = load_options._params
-
- def iter_for_shard(shard_id, load_options):
+ def iter_for_shard(shard_id, load_options, update_options):
execution_options = dict(orm_context.local_execution_options)
bind_arguments = dict(orm_context.bind_arguments)
- bind_arguments["_horizontal_shard"] = True
bind_arguments["shard_id"] = shard_id
- load_options += {"_refresh_identity_token": shard_id}
- execution_options["_sa_orm_load_options"] = load_options
+ if orm_context.is_select:
+ load_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_load_options"] = load_options
+ else:
+ update_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_update_options"] = update_options
- return session.execute(
- orm_context.statement,
- orm_context.parameters,
- execution_options,
- bind_arguments,
+ return orm_context.invoke_statement(
+ bind_arguments=bind_arguments, execution_options=execution_options
)
- if load_options._refresh_identity_token is not None:
- shard_id = load_options._refresh_identity_token
+ if active_options._refresh_identity_token is not None:
+ shard_id = active_options._refresh_identity_token
elif "_sa_shard_id" in orm_context.merged_execution_options:
shard_id = orm_context.merged_execution_options["_sa_shard_id"]
elif "shard_id" in orm_context.bind_arguments:
@@ -298,11 +249,11 @@ def execute_and_instances(orm_context):
shard_id = None
if shard_id is not None:
- return iter_for_shard(shard_id, load_options)
+ return iter_for_shard(shard_id, load_options, update_options)
else:
partial = []
for shard_id in session.execute_chooser(orm_context):
- result_ = iter_for_shard(shard_id, load_options)
+ result_ = iter_for_shard(shard_id, load_options, update_options)
partial.append(result_)
return partial[0].merge(*partial[1:])
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index 9f73b5d31..efd8d7d6b 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -777,7 +777,7 @@ things it can be used for.
from .. import util
from ..orm import attributes
from ..orm import interfaces
-
+from ..sql import elements
HYBRID_METHOD = util.symbol("HYBRID_METHOD")
"""Symbol indicating an :class:`InspectionAttr` that's
@@ -1144,6 +1144,9 @@ class ExprComparator(Comparator):
return self.hybrid.info
def _bulk_update_tuples(self, value):
+ if isinstance(value, elements.BindParameter):
+ value = value.value
+
if isinstance(self.expression, attributes.QueryableAttribute):
return self.expression._bulk_update_tuples(value)
elif self.hybrid.update_expr is not None:
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index bd4074ea1..a16db66f6 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -189,7 +189,7 @@ class ORMCompileState(CompileState):
@classmethod
def orm_pre_session_exec(
- cls, session, statement, execution_options, bind_arguments
+ cls, session, statement, params, execution_options, bind_arguments
):
load_options = execution_options.get(
"_sa_orm_load_options", QueryContext.default_load_options
@@ -216,6 +216,8 @@ class ORMCompileState(CompileState):
if load_options._autoflush:
session._autoflush()
+ return execution_options
+
@classmethod
def orm_setup_cursor_result(
cls, session, statement, execution_options, bind_arguments, result
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 6be4f0dff..027f2521b 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -420,6 +420,9 @@ class CompositeProperty(DescriptorProperty):
return CompositeProperty.CompositeBundle(self.prop, clauses)
def _bulk_update_tuples(self, value):
+ if isinstance(value, sql.elements.BindParameter):
+ value = value.value
+
if value is None:
values = [None for key in self.prop._attribute_keys]
elif isinstance(value, self.prop.composite_class):
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index be7aa272e..217aa76c7 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -1764,7 +1764,7 @@ class SessionEvents(event.Events):
lambda update_context: (
update_context.session,
update_context.query,
- update_context.context,
+ None,
update_context.result,
),
)
@@ -1782,12 +1782,13 @@ class SessionEvents(event.Events):
was called upon.
* ``values`` The "values" dictionary that was passed to
:meth:`_query.Query.update`.
- * ``context`` The :class:`.QueryContext` object, corresponding
- to the invocation of an ORM query.
* ``result`` the :class:`_engine.CursorResult`
returned as a result of the
bulk UPDATE operation.
+ .. versionchanged:: 1.4 the update_context no longer has a
+ ``QueryContext`` object associated with it.
+
.. seealso::
:meth:`.QueryEvents.before_compile_update`
@@ -1802,7 +1803,7 @@ class SessionEvents(event.Events):
lambda delete_context: (
delete_context.session,
delete_context.query,
- delete_context.context,
+ None,
delete_context.result,
),
)
@@ -1818,12 +1819,13 @@ class SessionEvents(event.Events):
* ``query`` -the :class:`_query.Query`
object that this update operation
was called upon.
- * ``context`` The :class:`.QueryContext` object, corresponding
- to the invocation of an ORM query.
* ``result`` the :class:`_engine.CursorResult`
returned as a result of the
bulk DELETE operation.
+ .. versionchanged:: 1.4 the update_context no longer has a
+ ``QueryContext`` object associated with it.
+
.. seealso::
:meth:`.QueryEvents.before_compile_delete`
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 4166e6d2a..c4cb89c03 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -2235,14 +2235,28 @@ class Mapper(
@HasMemoized.memoized_instancemethod
def __clause_element__(self):
- return self.selectable._annotate(
- {
- "entity_namespace": self,
- "parententity": self,
- "parentmapper": self,
- "compile_state_plugin": "orm",
- }
- )._set_propagate_attrs(
+
+ annotations = {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ "compile_state_plugin": "orm",
+ }
+ if self.persist_selectable is not self.local_table:
+ # joined table inheritance, with polymorphic selectable,
+ # etc.
+ annotations["dml_table"] = self.local_table._annotate(
+ {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ "compile_state_plugin": "orm",
+ }
+ )._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ return self.selectable._annotate(annotations)._set_propagate_attrs(
{"compile_state_plugin": "orm", "plugin_subject": self}
)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 163ebf22a..19d43d354 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -28,11 +28,15 @@ from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
+from ..future import select as future_select
from ..sql import coercions
from ..sql import expression
from ..sql import operators
from ..sql import roles
-from ..sql.base import _from_objects
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.dml import DeleteDMLState
+from ..sql.dml import UpdateDMLState
from ..sql.elements import BooleanClauseList
@@ -1650,243 +1654,193 @@ def _sort_states(mapper, states):
)
-class BulkUD(object):
- """Handle bulk update and deletes via a :class:`_query.Query`."""
+_EMPTY_DICT = util.immutabledict()
- def __init__(self, query):
- self.query = query.enable_eagerloads(False)
- self._validate_query_state()
- def _validate_query_state(self):
- for attr, methname, notset, op in (
- ("_limit_clause", "limit()", None, operator.is_),
- ("_offset_clause", "offset()", None, operator.is_),
- ("_order_by_clauses", "order_by()", (), operator.eq),
- ("_group_by_clauses", "group_by()", (), operator.eq),
- ("_distinct", "distinct()", False, operator.is_),
- (
- "_from_obj",
- "join(), outerjoin(), select_from(), or from_self()",
- (),
- operator.eq,
- ),
- (
- "_legacy_setup_joins",
- "join(), outerjoin(), select_from(), or from_self()",
- (),
- operator.eq,
- ),
- ):
- if not op(getattr(self.query, attr), notset):
- raise sa_exc.InvalidRequestError(
- "Can't call Query.update() or Query.delete() "
- "when %s has been called" % (methname,)
- )
-
- @property
- def session(self):
- return self.query.session
+class BulkUDCompileState(CompileState):
+ class default_update_options(Options):
+ _synchronize_session = "evaluate"
+ _autoflush = True
+ _subject_mapper = None
+ _resolved_values = _EMPTY_DICT
+ _resolved_keys_as_propnames = _EMPTY_DICT
+ _value_evaluators = _EMPTY_DICT
+ _matched_objects = None
+ _matched_rows = None
+ _refresh_identity_token = None
@classmethod
- def _factory(cls, lookup, synchronize_session, *arg):
- try:
- klass = lookup[synchronize_session]
- except KeyError as err:
- util.raise_(
- sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
- ),
- replace_context=err,
+ def orm_pre_session_exec(
+ cls, session, statement, params, execution_options, bind_arguments
+ ):
+ sync = execution_options.get("synchronize_session", None)
+ if sync is None:
+ sync = statement._execution_options.get(
+ "synchronize_session", None
)
- else:
- return klass(*arg)
-
- def exec_(self):
- self._do_before_compile()
- self._do_pre()
- self._do_pre_synchronize()
- self._do_exec()
- self._do_post_synchronize()
- self._do_post()
-
- def _execute_stmt(self, stmt):
- self.result = self.query._execute_crud(stmt, self.mapper)
- self.rowcount = self.result.rowcount
-
- def _do_before_compile(self):
- raise NotImplementedError()
- @util.preload_module("sqlalchemy.orm.context")
- def _do_pre(self):
- query_context = util.preloaded.orm_context
- query = self.query
-
- self.compile_state = (
- self.context
- ) = compile_state = query._compile_state()
-
- self.mapper = compile_state._entity_zero()
-
- if isinstance(
- compile_state._entities[0], query_context._RawColumnEntity,
- ):
- # check for special case of query(table)
- tables = set()
- for ent in compile_state._entities:
- if not isinstance(ent, query_context._RawColumnEntity,):
- tables.clear()
- break
- else:
- tables.update(_from_objects(ent.column))
+ update_options = execution_options.get(
+ "_sa_orm_update_options",
+ BulkUDCompileState.default_update_options,
+ )
- if len(tables) != 1:
- raise sa_exc.InvalidRequestError(
- "This operation requires only one Table or "
- "entity be specified as the target."
+ if sync is not None:
+ if sync not in ("evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'evaluate', 'fetch', False"
)
- else:
- self.primary_table = tables.pop()
+ update_options += {"_synchronize_session": sync}
+ bind_arguments["clause"] = statement
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
else:
- self.primary_table = compile_state._only_entity_zero(
- "This operation requires only one Table or "
- "entity be specified as the target."
- ).mapper.local_table
+ bind_arguments["mapper"] = plugin_subject.mapper
- session = query.session
+ update_options += {"_subject_mapper": plugin_subject.mapper}
- if query.load_options._autoflush:
+ if update_options._autoflush:
session._autoflush()
- def _do_pre_synchronize(self):
- pass
+ if update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
- def _do_post_synchronize(self):
- pass
+ return util.immutabledict(execution_options).union(
+ dict(_sa_orm_update_options=update_options)
+ )
+ @classmethod
+ def orm_setup_cursor_result(
+ cls, session, statement, execution_options, bind_arguments, result
+ ):
+ update_options = execution_options["_sa_orm_update_options"]
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(session, update_options)
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(session, update_options)
-class BulkEvaluate(BulkUD):
- """BulkUD which does the 'evaluate' method of session state resolution."""
+ return result
- def _additional_evaluators(self, evaluator_compiler):
- pass
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
- def _do_pre_synchronize(self):
- query = self.query
- target_cls = self.compile_state._mapper_zero().class_
+ value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
try:
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- if query._where_criteria:
+ if statement._where_criteria:
eval_condition = evaluator_compiler.process(
- *query._where_criteria
+ *statement._where_criteria
)
else:
def eval_condition(obj):
return True
- self._additional_evaluators(evaluator_compiler)
+ # TODO: something more robust for this conditional
+ if statement.__visit_name__ == "update":
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ value_evaluators = {}
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ for key, value in resolved_keys_as_propnames:
+ value_evaluators[key] = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
except evaluator.UnevaluatableError as err:
util.raise_(
sa_exc.InvalidRequestError(
'Could not evaluate current criteria in Python: "%s". '
"Specify 'fetch' or False for the "
- "synchronize_session parameter." % err
+ "synchronize_session execution option." % err
),
from_=err,
)
# TODO: detect when the where clause is a trivial primary key match
- self.matched_objects = [
+ matched_objects = [
obj
- for (
- cls,
- pk,
- identity_token,
- ), obj in query.session.identity_map.items()
- if issubclass(cls, target_cls) and eval_condition(obj)
+ for (cls, pk, identity_token,), obj in session.identity_map.items()
+ if issubclass(cls, target_cls)
+ and eval_condition(obj)
+ and identity_token == update_options._refresh_identity_token
]
-
-
-class BulkFetch(BulkUD):
- """BulkUD which does the 'fetch' method of session state resolution."""
-
- def _do_pre_synchronize(self):
- query = self.query
- session = query.session
- select_stmt = self.compile_state.statement.with_only_columns(
- self.primary_table.primary_key
- )
- self.matched_rows = session.execute(
- select_stmt, mapper=self.mapper, params=query.load_options._params
- ).fetchall()
-
-
-class BulkUpdate(BulkUD):
- """BulkUD which handles UPDATEs."""
-
- def __init__(self, query, values, update_kwargs):
- super(BulkUpdate, self).__init__(query)
- self.values = values
- self.update_kwargs = update_kwargs
+ return update_options + {
+ "_matched_objects": matched_objects,
+ "_value_evaluators": value_evaluators,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
@classmethod
- def factory(cls, query, synchronize_session, values, update_kwargs):
- return BulkUD._factory(
- {
- "evaluate": BulkUpdateEvaluate,
- "fetch": BulkUpdateFetch,
- False: BulkUpdate,
- },
- synchronize_session,
- query,
- values,
- update_kwargs,
- )
-
- def _do_before_compile(self):
- if self.query.dispatch.before_compile_update:
- for fn in self.query.dispatch.before_compile_update:
- new_query = fn(self.query, self)
- if new_query is not None:
- self.query = new_query
+ def _get_resolved_values(cls, mapper, statement):
+ if statement._multi_values:
+ return []
+ elif statement._ordered_values:
+ iterator = statement._ordered_values
+ elif statement._values:
+ iterator = statement._values.items()
+ else:
+ return []
- @property
- def _resolved_values(self):
values = []
- for k, v in (
- self.values.items()
- if hasattr(self.values, "items")
- else self.values
- ):
- if self.mapper:
- if isinstance(k, util.string_types):
- desc = sql.util._entity_namespace_key(self.mapper, k)
- values.extend(desc._bulk_update_tuples(v))
- elif isinstance(k, attributes.QueryableAttribute):
- values.extend(k._bulk_update_tuples(v))
+ if iterator:
+ for k, v in iterator:
+ if mapper:
+ if isinstance(k, util.string_types):
+ desc = sql.util._entity_namespace_key(mapper, k)
+ values.extend(desc._bulk_update_tuples(v))
+ elif isinstance(k, attributes.QueryableAttribute):
+ values.extend(k._bulk_update_tuples(v))
+ else:
+ values.append((k, v))
else:
values.append((k, v))
- else:
- values.append((k, v))
return values
- @property
- def _resolved_values_keys_as_propnames(self):
+ @classmethod
+ def _resolved_keys_as_propnames(cls, mapper, resolved_values):
values = []
- for k, v in self._resolved_values:
+ for k, v in resolved_values:
if isinstance(k, attributes.QueryableAttribute):
values.append((k.key, v))
continue
elif hasattr(k, "__clause_element__"):
k = k.__clause_element__()
- if self.mapper and isinstance(k, expression.ColumnElement):
+ if mapper and isinstance(k, expression.ColumnElement):
try:
- attr = self.mapper._columntoproperty[k]
+ attr = mapper._columntoproperty[k]
except orm_exc.UnmappedColumnError:
pass
else:
@@ -1897,87 +1851,99 @@ class BulkUpdate(BulkUD):
)
return values
- def _do_exec(self):
- values = self._resolved_values
+ @classmethod
+ def _do_pre_synchronize_fetch(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
- if not self.update_kwargs.get("preserve_parameter_order", False):
- values = dict(values)
+ if mapper:
+ primary_table = mapper.local_table
+ else:
+ primary_table = statement._raw_columns[0]
- update_stmt = sql.update(
- self.primary_table, **self.update_kwargs
- ).values(values)
+ # note this creates a Select() *without* the ORM plugin.
+ # we don't want that here.
+ select_stmt = future_select(*primary_table.primary_key)
+ select_stmt._where_criteria = statement._where_criteria
- update_stmt._where_criteria = self.compile_state._where_criteria
+ matched_rows = session.execute(
+ select_stmt, params, execution_options, bind_arguments
+ ).fetchall()
- self._execute_stmt(update_stmt)
+ if statement.__visit_name__ == "update":
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ else:
+ resolved_keys_as_propnames = _EMPTY_DICT
- def _do_post(self):
- session = self.query.session
- session.dispatch.after_bulk_update(self)
+ return update_options + {
+ "_matched_rows": matched_rows,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
-class BulkDelete(BulkUD):
- """BulkUD which handles DELETEs."""
+@CompileState.plugin_for("orm", "update")
+class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
- def __init__(self, query):
- super(BulkDelete, self).__init__(query)
+ self = cls.__new__(cls)
- @classmethod
- def factory(cls, query, synchronize_session):
- return BulkUD._factory(
- {
- "evaluate": BulkDeleteEvaluate,
- "fetch": BulkDeleteFetch,
- False: BulkDelete,
- },
- synchronize_session,
- query,
+ self.mapper = mapper = statement.table._annotations.get(
+ "parentmapper", None
)
- def _do_before_compile(self):
- if self.query.dispatch.before_compile_delete:
- for fn in self.query.dispatch.before_compile_delete:
- new_query = fn(self.query, self)
- if new_query is not None:
- self.query = new_query
+ self._resolved_values = cls._get_resolved_values(mapper, statement)
- def _do_exec(self):
- delete_stmt = sql.delete(self.primary_table,)
- delete_stmt._where_criteria = self.compile_state._where_criteria
+ if not statement._preserve_parameter_order and statement._values:
+ self._resolved_values = dict(self._resolved_values)
- self._execute_stmt(delete_stmt)
+ new_stmt = sql.Update.__new__(sql.Update)
+ new_stmt.__dict__.update(statement.__dict__)
+ new_stmt.table = mapper.local_table
- def _do_post(self):
- session = self.query.session
- session.dispatch.after_bulk_delete(self)
+ # note if the statement has _multi_values, these
+ # are passed through to the new statement, which will then raise
+ # InvalidRequestError because UPDATE doesn't support multi_values
+ # right now.
+ if statement._ordered_values:
+ new_stmt._ordered_values = self._resolved_values
+ elif statement._values:
+ new_stmt._values = self._resolved_values
+ UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
-class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
- """BulkUD which handles UPDATEs using the "evaluate"
- method of session resolution."""
+ return self
- def _additional_evaluators(self, evaluator_compiler):
- self.value_evaluators = {}
- values = self._resolved_values_keys_as_propnames
- for key, value in values:
- self.value_evaluators[key] = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, update_options):
- def _do_post_synchronize(self):
- session = self.query.session
states = set()
- evaluated_keys = list(self.value_evaluators.keys())
- for obj in self.matched_objects:
+ evaluated_keys = list(update_options._value_evaluators.keys())
+ for obj in update_options._matched_objects:
+
state, dict_ = (
attributes.instance_state(obj),
attributes.instance_dict(obj),
)
+ assert (
+ state.identity_token == update_options._refresh_identity_token
+ )
+
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
- dict_[key] = self.value_evaluators[key](obj)
+ dict_[key] = update_options._value_evaluators[key](obj)
state.manager.dispatch.refresh(state, None, to_evaluate)
@@ -1991,39 +1957,25 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
states.add(state)
session._register_altered(states)
-
-class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
- """BulkUD which handles DELETEs using the "evaluate"
- method of session resolution."""
-
- def _do_post_synchronize(self):
- self.query.session._remove_newly_deleted(
- [attributes.instance_state(obj) for obj in self.matched_objects]
- )
-
-
-class BulkUpdateFetch(BulkFetch, BulkUpdate):
- """BulkUD which handles UPDATEs using the "fetch"
- method of session resolution."""
-
- def _do_post_synchronize(self):
- session = self.query.session
- target_mapper = self.compile_state._mapper_zero()
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, update_options):
+ target_mapper = update_options._subject_mapper
states = set(
[
attributes.instance_state(session.identity_map[identity_key])
for identity_key in [
target_mapper.identity_key_from_primary_key(
- list(primary_key)
+ list(primary_key),
+ identity_token=update_options._refresh_identity_token,
)
- for primary_key in self.matched_rows
+ for primary_key in update_options._matched_rows
]
if identity_key in session.identity_map
]
)
- values = self._resolved_values_keys_as_propnames
+ values = update_options._resolved_keys_as_propnames
attrib = set(k for k, v in values)
for state in states:
to_expire = attrib.intersection(state.dict)
@@ -2032,18 +1984,38 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate):
session._register_altered(states)
-class BulkDeleteFetch(BulkFetch, BulkDelete):
- """BulkUD which handles DELETEs using the "fetch"
- method of session resolution."""
+@CompileState.plugin_for("orm", "delete")
+class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ self = cls.__new__(cls)
+
+ self.mapper = statement.table._annotations.get("parentmapper", None)
+
+ DeleteDMLState.__init__(self, statement, compiler, **kw)
+
+ return self
+
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, update_options):
+
+ session._remove_newly_deleted(
+ [
+ attributes.instance_state(obj)
+ for obj in update_options._matched_objects
+ ]
+ )
+
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, update_options):
+ target_mapper = update_options._subject_mapper
- def _do_post_synchronize(self):
- session = self.query.session
- target_mapper = self.compile_state._mapper_zero()
- for primary_key in self.matched_rows:
+ for primary_key in update_options._matched_rows:
# TODO: inline this and call remove_newly_deleted
# once
identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key)
+ list(primary_key),
+ identity_token=update_options._refresh_identity_token,
)
if identity_key in session.identity_map:
session._remove_newly_deleted(
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 5137f9b1d..284ea9d72 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -19,12 +19,12 @@ database to return iterable result sets.
"""
import itertools
+import operator
from . import attributes
from . import exc as orm_exc
from . import interfaces
from . import loading
-from . import persistence
from .base import _assertions
from .context import _column_descriptions
from .context import _legacy_determine_last_joined_entity
@@ -2825,15 +2825,6 @@ class Query(
return result
- def _execute_crud(self, stmt, mapper):
- conn = self.session.connection(
- mapper=mapper, clause=stmt, close_with_result=True
- )
-
- return conn._execute_20(
- stmt, self.load_options._params, self._execution_options
- )
-
def __str__(self):
statement = self._statement_20()
@@ -3178,9 +3169,27 @@ class Query(
"""
- delete_op = persistence.BulkDelete.factory(self, synchronize_session)
- delete_op.exec_()
- return delete_op.rowcount
+ bulk_del = BulkDelete(self,)
+ if self.dispatch.before_compile_delete:
+ for fn in self.dispatch.before_compile_delete:
+ new_query = fn(bulk_del.query, bulk_del)
+ if new_query is not None:
+ bulk_del.query = new_query
+
+ self = bulk_del.query
+
+ delete_ = sql.delete(*self._raw_columns)
+ delete_._where_criteria = self._where_criteria
+ result = self.session.execute(
+ delete_,
+ self.load_options._params,
+ execution_options={"synchronize_session": synchronize_session},
+ )
+ bulk_del.result = result
+ self.session.dispatch.after_bulk_delete(bulk_del)
+ result.close()
+
+ return result.rowcount
def update(self, values, synchronize_session="evaluate", update_args=None):
r"""Perform a bulk update query.
@@ -3313,11 +3322,27 @@ class Query(
"""
update_args = update_args or {}
- update_op = persistence.BulkUpdate.factory(
- self, synchronize_session, values, update_args
+
+ bulk_ud = BulkUpdate(self, values, update_args)
+
+ if self.dispatch.before_compile_update:
+ for fn in self.dispatch.before_compile_update:
+ new_query = fn(bulk_ud.query, bulk_ud)
+ if new_query is not None:
+ bulk_ud.query = new_query
+ self = bulk_ud.query
+
+ upd = sql.update(*self._raw_columns, **update_args).values(values)
+ upd._where_criteria = self._where_criteria
+ result = self.session.execute(
+ upd,
+ self.load_options._params,
+ execution_options={"synchronize_session": synchronize_session},
)
- update_op.exec_()
- return update_op.rowcount
+ bulk_ud.result = result
+ self.session.dispatch.after_bulk_update(bulk_ud)
+ result.close()
+ return result.rowcount
def _compile_state(self, for_statement=False, **kw):
"""Create an out-of-compiler ORMCompileState object.
@@ -3427,3 +3452,59 @@ class AliasOption(interfaces.LoaderOption):
def process_compile_state(self, compile_state):
pass
+
+
+class BulkUD(object):
+ """State used for the orm.Query version of update() / delete().
+
+ This object is now specific to Query only.
+
+ """
+
+ def __init__(self, query):
+ self.query = query.enable_eagerloads(False)
+ self._validate_query_state()
+ self.mapper = self.query._entity_from_pre_ent_zero()
+
+ def _validate_query_state(self):
+ for attr, methname, notset, op in (
+ ("_limit_clause", "limit()", None, operator.is_),
+ ("_offset_clause", "offset()", None, operator.is_),
+ ("_order_by_clauses", "order_by()", (), operator.eq),
+ ("_group_by_clauses", "group_by()", (), operator.eq),
+ ("_distinct", "distinct()", False, operator.is_),
+ (
+ "_from_obj",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ (
+ "_legacy_setup_joins",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ ):
+ if not op(getattr(self.query, attr), notset):
+ raise sa_exc.InvalidRequestError(
+ "Can't call Query.update() or Query.delete() "
+ "when %s has been called" % (methname,)
+ )
+
+ @property
+ def session(self):
+ return self.query.session
+
+
+class BulkUpdate(BulkUD):
+ """BulkUD which handles UPDATEs."""
+
+ def __init__(self, query, values, update_kwargs):
+ super(BulkUpdate, self).__init__(query)
+ self.values = values
+ self.update_kwargs = update_kwargs
+
+
+class BulkDelete(BulkUD):
+ """BulkUD which handles DELETEs."""
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index ee42419a2..5ad8bcf2f 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -33,7 +33,9 @@ from .. import future
from .. import util
from ..inspection import inspect
from ..sql import coercions
+from ..sql import dml
from ..sql import roles
+from ..sql import selectable
from ..sql import visitors
from ..sql.base import CompileState
@@ -113,16 +115,24 @@ class ORMExecuteState(util.MemoizedSlots):
"_execution_options",
"_merged_execution_options",
"bind_arguments",
+ "_compile_state_cls",
)
def __init__(
- self, session, statement, parameters, execution_options, bind_arguments
+ self,
+ session,
+ statement,
+ parameters,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
):
self.session = session
self.statement = statement
self.parameters = parameters
self._execution_options = execution_options
self.bind_arguments = bind_arguments
+ self._compile_state_cls = compile_state_cls
def invoke_statement(
self,
@@ -194,6 +204,38 @@ class ORMExecuteState(util.MemoizedSlots):
)
@property
+ def is_orm_statement(self):
+ """return True if the operation is an ORM statement.
+
+    This indicates that the select(), update(), or delete() being
+ invoked contains ORM entities as subjects. For a statement
+ that does not have ORM entities and instead refers only to
+ :class:`.Table` metadata, it is invoked as a Core SQL statement
+ and no ORM-level automation takes place.
+
+ """
+ return self._compile_state_cls is not None
+
+ @property
+ def is_select(self):
+ """return True if this is a SELECT operation."""
+ return isinstance(self.statement, selectable.Select)
+
+ @property
+ def is_update(self):
+ """return True if this is an UPDATE operation."""
+ return isinstance(self.statement, dml.Update)
+
+ @property
+ def is_delete(self):
+ """return True if this is a DELETE operation."""
+ return isinstance(self.statement, dml.Delete)
+
+ @property
+ def _is_crud(self):
+ return isinstance(self.statement, (dml.Update, dml.Delete))
+
+ @property
def execution_options(self):
"""Placeholder for execution options.
@@ -270,11 +312,31 @@ class ORMExecuteState(util.MemoizedSlots):
def load_options(self):
"""Return the load_options that will be used for this execution."""
+ if not self.is_select:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against a SELECT statement "
+ "so there are no load options."
+ )
return self._execution_options.get(
"_sa_orm_load_options", context.QueryContext.default_load_options
)
@property
+ def update_delete_options(self):
+ """Return the update_delete_options that will be used for this
+ execution."""
+
+ if not self._is_crud:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against an UPDATE or DELETE "
+ "statement so there are no update options."
+ )
+ return self._execution_options.get(
+ "_sa_orm_update_options",
+ persistence.BulkUDCompileState.default_update_options,
+ )
+
+ @property
def user_defined_options(self):
"""The sequence of :class:`.UserDefinedOptions` that have been
associated with the statement being invoked.
@@ -1455,35 +1517,37 @@ class Session(_SessionClassMethods):
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
+ else:
+ compile_state_cls = None
- compile_state_cls.orm_pre_session_exec(
- self, statement, execution_options, bind_arguments
+ if compile_state_cls is not None:
+ execution_options = compile_state_cls.orm_pre_session_exec(
+ self, statement, params, execution_options, bind_arguments
)
-
- if self.dispatch.do_orm_execute:
- skip_events = bind_arguments.pop("_sa_skip_events", False)
-
- if not skip_events:
- orm_exec_state = ORMExecuteState(
- self,
- statement,
- params,
- execution_options,
- bind_arguments,
- )
- for fn in self.dispatch.do_orm_execute:
- result = fn(orm_exec_state)
- if result:
- return result
-
else:
- compile_state_cls = None
bind_arguments.setdefault("clause", statement)
if statement._is_future:
execution_options = util.immutabledict().merge_with(
execution_options, {"future_result": True}
)
+ if self.dispatch.do_orm_execute:
+ # run this event whether or not we are in ORM mode
+ skip_events = bind_arguments.get("_sa_skip_events", False)
+ if not skip_events:
+ orm_exec_state = ORMExecuteState(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
+ )
+ for fn in self.dispatch.do_orm_execute:
+ result = fn(orm_exec_state)
+ if result:
+ return result
+
bind = self.get_bind(**bind_arguments)
conn = self._connection_for_bind(bind, close_with_result=True)
@@ -1601,8 +1665,8 @@ class Session(_SessionClassMethods):
self.__binds[insp] = bind
elif insp.is_mapper:
self.__binds[insp.class_] = bind
- for selectable in insp._all_tables:
- self.__binds[selectable] = bind
+ for _selectable in insp._all_tables:
+ self.__binds[_selectable] = bind
else:
raise sa_exc.ArgumentError(
"Not an acceptable bind target: %s" % key
@@ -1664,7 +1728,9 @@ class Session(_SessionClassMethods):
"""
self._add_bind(table, bind)
- def get_bind(self, mapper=None, clause=None, bind=None):
+ def get_bind(
+ self, mapper=None, clause=None, bind=None, _sa_skip_events=None
+ ):
"""Return a "bind" to which this :class:`.Session` is bound.
The "bind" is usually an instance of :class:`_engine.Engine`,
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index f14319089..5dd3b519a 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -446,10 +446,14 @@ class CompileState(object):
plugin_name = statement._propagate_attrs.get(
"compile_state_plugin", "default"
)
- else:
- plugin_name = "default"
+ klass = cls.plugins.get(
+ (plugin_name, statement.__visit_name__), None
+ )
+ if klass is None:
+ klass = cls.plugins[("default", statement.__visit_name__)]
- klass = cls.plugins[(plugin_name, statement.__visit_name__)]
+ else:
+ klass = cls.plugins[("default", statement.__visit_name__)]
if klass is cls:
return cls(statement, compiler, **kw)
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index db43e42a6..4c6a0317a 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -755,6 +755,16 @@ class AnonymizedFromClauseImpl(StrictFromClauseImpl):
return element.alias(name=name, flat=flat)
+class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, **kw):
+ if "dml_table" in element._annotations:
+ return element._annotations["dml_table"]
+ else:
+ return element
+
+
class DMLSelectImpl(_NoTextCoercion, RoleImpl):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index f4160b552..2519438d1 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -3215,6 +3215,8 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
if toplevel:
self.isupdate = True
+ if not self.compile_state:
+ self.compile_state = compile_state
extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
@@ -3342,6 +3344,8 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
if toplevel:
self.isdelete = True
+ if not self.compile_state:
+ self.compile_state = compile_state
extra_froms = compile_state._extra_froms
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 467a764d6..a82641d77 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -19,6 +19,7 @@ from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
from .base import HasCompileState
+from .elements import BooleanClauseList
from .elements import ClauseElement
from .elements import Null
from .selectable import HasCTE
@@ -150,7 +151,6 @@ class UpdateDMLState(DMLState):
def __init__(self, statement, compiler, **kw):
self.statement = statement
-
self.isupdate = True
self._preserve_parameter_order = statement._preserve_parameter_order
if statement._ordered_values is not None:
@@ -447,7 +447,9 @@ class ValuesBase(UpdateBase):
_returning = ()
def __init__(self, table, values, prefixes):
- self.table = coercions.expect(roles.FromClauseRole, table)
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
if values is not None:
self.values.non_generative(self, values)
if prefixes:
@@ -949,6 +951,28 @@ class DMLWhereBase(object):
coercions.expect(roles.WhereHavingRole, whereclause),
)
+ def filter(self, *criteria):
+ """A synonym for the :meth:`_dml.DMLWhereBase.where` method."""
+
+ return self.where(*criteria)
+
+ @property
+ def whereclause(self):
+ """Return the completed WHERE clause for this :class:`.DMLWhereBase`
+ statement.
+
+ This assembles the current collection of WHERE criteria
+ into a single :class:`_expression.BooleanClauseList` construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
class Update(DMLWhereBase, ValuesBase):
"""Represent an Update construct.
@@ -1266,7 +1290,9 @@ class Delete(DMLWhereBase, UpdateBase):
"""
self._bind = bind
- self.table = coercions.expect(roles.FromClauseRole, table)
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
self._returning = returning
if prefixes:
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 5a55fe5f2..3d94ec9ff 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -184,10 +184,15 @@ class CompoundElementRole(SQLRole):
)
+# TODO: are we using this?
class DMLRole(StatementRole):
pass
+class DMLTableRole(FromClauseRole):
+ _role_name = "subject table for an INSERT, UPDATE or DELETE"
+
+
class DMLColumnRole(SQLRole):
_role_name = "SET/VALUES column expression or string key"
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index d6845e05f..a95fc561a 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -789,7 +789,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._reset_column_collection()
-class Join(FromClause):
+class Join(roles.DMLTableRole, FromClause):
"""represent a ``JOIN`` construct between two
:class:`_expression.FromClause`
elements.
@@ -1406,7 +1406,7 @@ class AliasedReturnsRows(NoInit, FromClause):
return self.element.bind
-class Alias(AliasedReturnsRows):
+class Alias(roles.DMLTableRole, AliasedReturnsRows):
"""Represents an table or selectable alias (AS).
Represents an alias, as typically applied to any table or
@@ -1987,7 +1987,7 @@ class FromGrouping(GroupedElement, FromClause):
self.element = state["element"]
-class TableClause(Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, FromClause):
"""Represents a minimal "table" construct.
This is a lightweight table object that has only a name, a
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 388097e45..68281f33d 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -10,6 +10,7 @@ from .. import util
from ..inspection import inspect
from ..util import collections_abc
from ..util import HasMemoized
+from ..util import py37
SKIP_TRAVERSE = util.symbol("skip_traverse")
COMPARE_FAILED = False
@@ -562,23 +563,38 @@ class _CacheKey(ExtendedInternalTraversal):
)
def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ if py37:
+ # in py37 we can assume two dictionaries created in the same
+ # insert ordering will retain that sorting
+ return (
+ attrname,
+ tuple(
+ (
+ k._gen_cache_key(anon_map, bindparams)
+ if hasattr(k, "__clause_element__")
+ else k,
+ obj[k]._gen_cache_key(anon_map, bindparams),
+ )
+ for k in obj
+ ),
+ )
+ else:
+ expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+ if expr_values:
+ # expr values can't be sorted deterministically right now,
+ # so no cache
+ anon_map[NO_CACHE] = True
+ return ()
- expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
- if expr_values:
- # expr values can't be sorted deterministically right now,
- # so no cache
- anon_map[NO_CACHE] = True
- return ()
-
- str_values = expr_values.symmetric_difference(obj)
+ str_values = expr_values.symmetric_difference(obj)
- return (
- attrname,
- tuple(
- (k, obj[k]._gen_cache_key(anon_map, bindparams))
- for k in sorted(str_values)
- ),
- )
+ return (
+ attrname,
+ tuple(
+ (k, obj[k]._gen_cache_key(anon_map, bindparams))
+ for k in sorted(str_values)
+ ),
+ )
def visit_dml_multi_values(
self, attrname, obj, parent, anon_map, bindparams
@@ -1130,6 +1146,18 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for lv, rv in zip(left, right):
if not self._compare_dml_values_or_ce(lv, rv, **kw):
return COMPARE_FAILED
+ elif isinstance(right, collections_abc.Sequence):
+ return COMPARE_FAILED
+ elif py37:
+ # dictionaries guaranteed to support insert ordering in
+ # py37 so that we can compare the keys in order. without
+ # this, we can't compare SQL expression keys because we don't
+ # know which key is which
+ for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
+ return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
else:
for lk in left:
lv = left[lk]
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 0ea9f067e..54da06a3d 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -403,10 +403,6 @@ class AssertsCompiledSQL(object):
LABEL_STYLE_TABLENAME_PLUS_COL
)
clause = compile_state.statement
- elif isinstance(clause, orm.persistence.BulkUD):
- with mock.patch.object(clause, "_execute_stmt") as stmt_mock:
- clause.exec_()
- clause = stmt_mock.mock_calls[0][1][0]
if compile_kwargs:
kw["compile_kwargs"] = compile_kwargs
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 55a6cdcf9..273570357 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -65,6 +65,7 @@ from .compat import pickle # noqa
from .compat import print_ # noqa
from .compat import py2k # noqa
from .compat import py36 # noqa
+from .compat import py37 # noqa
from .compat import py3k # noqa
from .compat import quote_plus # noqa
from .compat import raise_ # noqa
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index 247dbc13c..5c46395f9 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -15,6 +15,7 @@ import platform
import sys
+py37 = sys.version_info >= (3, 7)
py36 = sys.version_info >= (3, 6)
py3k = sys.version_info >= (3, 0)
py2k = sys.version_info < (3, 0)