summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-06-03 17:38:35 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-06-06 13:31:54 -0400
commit3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6 (patch)
treef3dc26609070c1a357a366592c791a3ec0655483 /lib/sqlalchemy/sql
parent14bc09203a8b5b2bc001f764ad7cce6a184975cc (diff)
downloadsqlalchemy-3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6.tar.gz
Convert bulk update/delete to new execution model
This reorganizes the BulkUD model in sqlalchemy.orm.persistence to be based on the CompileState concept and to allow plain update() / delete() to be passed to session.execute() where the ORM synchronize session logic will take place. Also gets "synchronize_session='fetch'" working with horizontal sharding. Adding a few more result.scalar_one() types of methods as scalar_one() seems like what is normally desired. Fixes: #5160 Change-Id: I8001ebdad089da34119eb459709731ba6c0ba975
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/base.py10
-rw-r--r--lib/sqlalchemy/sql/coercions.py10
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/sql/dml.py32
-rw-r--r--lib/sqlalchemy/sql/roles.py5
-rw-r--r--lib/sqlalchemy/sql/selectable.py6
-rw-r--r--lib/sqlalchemy/sql/traversals.py58
7 files changed, 101 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index f14319089..5dd3b519a 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -446,10 +446,14 @@ class CompileState(object):
plugin_name = statement._propagate_attrs.get(
"compile_state_plugin", "default"
)
- else:
- plugin_name = "default"
+ klass = cls.plugins.get(
+ (plugin_name, statement.__visit_name__), None
+ )
+ if klass is None:
+ klass = cls.plugins[("default", statement.__visit_name__)]
- klass = cls.plugins[(plugin_name, statement.__visit_name__)]
+ else:
+ klass = cls.plugins[("default", statement.__visit_name__)]
if klass is cls:
return cls(statement, compiler, **kw)
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index db43e42a6..4c6a0317a 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -755,6 +755,16 @@ class AnonymizedFromClauseImpl(StrictFromClauseImpl):
return element.alias(name=name, flat=flat)
+class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, **kw):
+ if "dml_table" in element._annotations:
+ return element._annotations["dml_table"]
+ else:
+ return element
+
+
class DMLSelectImpl(_NoTextCoercion, RoleImpl):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index f4160b552..2519438d1 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -3215,6 +3215,8 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
if toplevel:
self.isupdate = True
+ if not self.compile_state:
+ self.compile_state = compile_state
extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
@@ -3342,6 +3344,8 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
if toplevel:
self.isdelete = True
+ if not self.compile_state:
+ self.compile_state = compile_state
extra_froms = compile_state._extra_froms
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 467a764d6..a82641d77 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -19,6 +19,7 @@ from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
from .base import HasCompileState
+from .elements import BooleanClauseList
from .elements import ClauseElement
from .elements import Null
from .selectable import HasCTE
@@ -150,7 +151,6 @@ class UpdateDMLState(DMLState):
def __init__(self, statement, compiler, **kw):
self.statement = statement
-
self.isupdate = True
self._preserve_parameter_order = statement._preserve_parameter_order
if statement._ordered_values is not None:
@@ -447,7 +447,9 @@ class ValuesBase(UpdateBase):
_returning = ()
def __init__(self, table, values, prefixes):
- self.table = coercions.expect(roles.FromClauseRole, table)
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
if values is not None:
self.values.non_generative(self, values)
if prefixes:
@@ -949,6 +951,28 @@ class DMLWhereBase(object):
coercions.expect(roles.WhereHavingRole, whereclause),
)
+ def filter(self, *criteria):
+ """A synonym for the :meth:`_dml.DMLWhereBase.where` method."""
+
+ return self.where(*criteria)
+
+ @property
+ def whereclause(self):
+ """Return the completed WHERE clause for this :class:`.DMLWhereBase`
+ statement.
+
+ This assembles the current collection of WHERE criteria
+ into a single :class:`_expression.BooleanClauseList` construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
class Update(DMLWhereBase, ValuesBase):
"""Represent an Update construct.
@@ -1266,7 +1290,9 @@ class Delete(DMLWhereBase, UpdateBase):
"""
self._bind = bind
- self.table = coercions.expect(roles.FromClauseRole, table)
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
self._returning = returning
if prefixes:
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 5a55fe5f2..3d94ec9ff 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -184,10 +184,15 @@ class CompoundElementRole(SQLRole):
)
+# TODO: are we using this?
class DMLRole(StatementRole):
pass
+class DMLTableRole(FromClauseRole):
+ _role_name = "subject table for an INSERT, UPDATE or DELETE"
+
+
class DMLColumnRole(SQLRole):
_role_name = "SET/VALUES column expression or string key"
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index d6845e05f..a95fc561a 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -789,7 +789,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._reset_column_collection()
-class Join(FromClause):
+class Join(roles.DMLTableRole, FromClause):
"""represent a ``JOIN`` construct between two
:class:`_expression.FromClause`
elements.
@@ -1406,7 +1406,7 @@ class AliasedReturnsRows(NoInit, FromClause):
return self.element.bind
-class Alias(AliasedReturnsRows):
+class Alias(roles.DMLTableRole, AliasedReturnsRows):
"""Represents an table or selectable alias (AS).
Represents an alias, as typically applied to any table or
@@ -1987,7 +1987,7 @@ class FromGrouping(GroupedElement, FromClause):
self.element = state["element"]
-class TableClause(Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, FromClause):
"""Represents a minimal "table" construct.
This is a lightweight table object that has only a name, a
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 388097e45..68281f33d 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -10,6 +10,7 @@ from .. import util
from ..inspection import inspect
from ..util import collections_abc
from ..util import HasMemoized
+from ..util import py37
SKIP_TRAVERSE = util.symbol("skip_traverse")
COMPARE_FAILED = False
@@ -562,23 +563,38 @@ class _CacheKey(ExtendedInternalTraversal):
)
def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ if py37:
+ # in py37 we can assume two dictionaries created in the same
+ # insert ordering will retain that sorting
+ return (
+ attrname,
+ tuple(
+ (
+ k._gen_cache_key(anon_map, bindparams)
+ if hasattr(k, "__clause_element__")
+ else k,
+ obj[k]._gen_cache_key(anon_map, bindparams),
+ )
+ for k in obj
+ ),
+ )
+ else:
+ expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+ if expr_values:
+ # expr values can't be sorted deterministically right now,
+ # so no cache
+ anon_map[NO_CACHE] = True
+ return ()
- expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
- if expr_values:
- # expr values can't be sorted deterministically right now,
- # so no cache
- anon_map[NO_CACHE] = True
- return ()
-
- str_values = expr_values.symmetric_difference(obj)
+ str_values = expr_values.symmetric_difference(obj)
- return (
- attrname,
- tuple(
- (k, obj[k]._gen_cache_key(anon_map, bindparams))
- for k in sorted(str_values)
- ),
- )
+ return (
+ attrname,
+ tuple(
+ (k, obj[k]._gen_cache_key(anon_map, bindparams))
+ for k in sorted(str_values)
+ ),
+ )
def visit_dml_multi_values(
self, attrname, obj, parent, anon_map, bindparams
@@ -1130,6 +1146,18 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for lv, rv in zip(left, right):
if not self._compare_dml_values_or_ce(lv, rv, **kw):
return COMPARE_FAILED
+ elif isinstance(right, collections_abc.Sequence):
+ return COMPARE_FAILED
+ elif py37:
+ # dictionaries guaranteed to support insert ordering in
+ # py37 so that we can compare the keys in order. without
+ # this, we can't compare SQL expression keys because we don't
+ # know which key is which
+ for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
+ return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
else:
for lk in left:
lv = left[lk]