summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-02-23 13:37:18 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2020-03-06 11:01:51 -0500
commit851fb8f5a661c66ee76308181118369c8c4df9e0 (patch)
treeb6c786e78e090752f5c0922d1f09d277ab94e365 /lib/sqlalchemy
parentd72bda5ed23a46bcbf31d40684200dcb79012a33 (diff)
downloadsqlalchemy-851fb8f5a661c66ee76308181118369c8c4df9e0.tar.gz
Decouple compiler state from DML objects; make cacheable
Targeting select / insert / update / delete, the goal is to minimize overhead of construction and generative methods so that only the raw arguments passed are handled. An interim stage that converts the raw state into more compiler-ready state is added, which is analogous to the ORM QueryContext which will also be rolled in to be a similar concept, as is currently being prototyped in I19e05b3424b07114cce6c439b05198ac47f7ac10. the ORM update/delete BulkUD concept is also going to be rolled onto this idea. So while the compiler-ready state object, here called DMLState, looks a little thin, it's the base of a bigger pattern that will allow for ORM functionality to embed itself directly into the compiler, execution context, and result set objects. This change targets the DML objects, primarily focused on the values() method which is the most complex process. The work done by values() is minimized as much as possible while still being able to create a cache key. Additional computation is then offloaded to a new object ValuesState that is handled by the compiler. Architecturally, a big change here is that insert.values() and update.values() will generate BindParameter objects for the values now, which are then carefully received by crud.py so that they generate the expected names. This is so that the values() portion of these constructs is cacheable. for the "multi-values" version of Insert, this is all skipped and the plan right now is that a multi-values insert is not worth caching (can always be revisited). Using the coercions system in values() also gets us nicer validation for free, we can remove the NotAClauseElement thing from schema, and we also now require scalar_subquery() is called for an insert/update that uses a SELECT as a column value, 1.x deprecation path is added. The traversal system is then applied to the DML objects including tests so that they have traversal, cloning, and cache key support. cloning is not a use case for DML however having it present allows better validation of the structure within the tests. Special per-dialect DML is explicitly not cacheable at the moment, more as a proof of concept that third party DML constructs can exist as gracefully not-cacheable rather than producing an incomplete cache key. A few selected performance improvements have been added as well, simplifying the immutabledict.union() method and adding a new SQLCompiler function that can generate delimeter-separated clauses like WHERE and ORDER BY without having to build a ClauseList object at all. The use of ClauseList will be removed from Select in an upcoming commit. Overall, ClaustList is unnecessary for internal use and only adds overhead to statement construction and will likely be removed as much as possible except for explcit use of conjunctions like and_() and or_(). Change-Id: I408e0b8be91fddd77cf279da97f55020871f75a9
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py24
-rw-r--r--lib/sqlalchemy/engine/default.py6
-rw-r--r--lib/sqlalchemy/sql/base.py47
-rw-r--r--lib/sqlalchemy/sql/coercions.py13
-rw-r--r--lib/sqlalchemy/sql/compiler.py87
-rw-r--r--lib/sqlalchemy/sql/crud.py181
-rw-r--r--lib/sqlalchemy/sql/dml.py537
-rw-r--r--lib/sqlalchemy/sql/elements.py14
-rw-r--r--lib/sqlalchemy/sql/schema.py25
-rw-r--r--lib/sqlalchemy/sql/selectable.py2
-rw-r--r--lib/sqlalchemy/sql/traversals.py218
-rw-r--r--lib/sqlalchemy/sql/visitors.py21
-rw-r--r--lib/sqlalchemy/util/_collections.py15
13 files changed, 820 insertions, 370 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index a3855cc2c..955a0f23b 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -1450,31 +1450,17 @@ class MSExecutionContext(default.DefaultExecutionContext):
insert_has_sequence = seq_column is not None
if insert_has_sequence:
+ compile_state = self.compiled.compile_state
self._enable_identity_insert = (
seq_column.key in self.compiled_parameters[0]
) or (
- self.compiled.statement.parameters
+ compile_state._dict_parameters
and (
- (
- self.compiled.statement._has_multi_parameters
- and (
- seq_column.key
- in self.compiled.statement.parameters[0]
- or seq_column
- in self.compiled.statement.parameters[0]
- )
- )
- or (
- not self.compiled.statement._has_multi_parameters
- and (
- seq_column.key
- in self.compiled.statement.parameters
- or seq_column
- in self.compiled.statement.parameters
- )
- )
+ seq_column.key in compile_state._dict_parameters
+ or seq_column in compile_state._dict_parameters
)
)
+
else:
self._enable_identity_insert = False
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 8775a8813..b151b6e48 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1459,10 +1459,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
"get_current_parameters() can only be invoked in the "
"context of a Python side column default function"
)
+
+ compile_state = self.compiled.compile_state
if (
isolate_multiinsert_groups
and self.isinsert
- and self.compiled.statement._has_multi_parameters
+ and compile_state._has_multi_parameters
):
if column._is_multiparam_column:
index = column.index + 1
@@ -1470,7 +1472,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
else:
d = {column.key: parameters[column.key]}
index = 0
- keys = self.compiled.statement.parameters[0].keys()
+ keys = compile_state._dict_parameters.keys()
d.update(
(key, parameters["%s_m%d" % (key, index)]) for key in keys
)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 2d336360f..89839ea28 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -16,6 +16,7 @@ import re
from .traversals import HasCacheKey # noqa
from .visitors import ClauseVisitor
+from .visitors import InternalTraversal
from .. import exc
from .. import util
@@ -221,6 +222,10 @@ class DialectKWArgs(object):
"""
+ _dialect_kwargs_traverse_internals = [
+ ("dialect_options", InternalTraversal.dp_dialect_options)
+ ]
+
@classmethod
def argument_for(cls, dialect_name, argument_name, default):
"""Add a new kind of dialect-specific keyword argument for this class.
@@ -386,6 +391,39 @@ class DialectKWArgs(object):
construct_arg_dictionary[arg_name] = kwargs[k]
+class CompileState(object):
+ """Produces additional object state necessary for a statement to be
+ compiled.
+
+ the :class:`.CompileState` class is at the base of classes that assemble
+ state for a particular statement object that is then used by the
+ compiler. This process is essentially an extension of the process that
+ the SQLCompiler.visit_XYZ() method takes, however there is an emphasis
+ on converting raw user intent into more organized structures rather than
+ producing string output. The top-level :class:`.CompileState` for the
+ statement being executed is also accessible when the execution context
+ works with invoking the statement and collecting results.
+
+ The production of :class:`.CompileState` is specific to the compiler, such
+ as within the :meth:`.SQLCompiler.visit_insert`,
+ :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also
+ responsible for associating the :class:`.CompileState` with the
+ :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement,
+ i.e. the outermost SQL statement that's actually being executed.
+ There can be other :class:`.CompileState` objects that are not the
+ toplevel, such as when a SELECT subquery or CTE-nested
+ INSERT/UPDATE/DELETE is generated.
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ("statement",)
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+
class Generative(object):
"""Provide a method-chaining pattern in conjunction with the
@_generative decorator."""
@@ -396,6 +434,12 @@ class Generative(object):
return s
+class HasCompileState(Generative):
+ """A class that has a :class:`.CompileState` associated with it."""
+
+ _compile_state_cls = CompileState
+
+
class Executable(Generative):
"""Mark a ClauseElement as supporting execution.
@@ -627,6 +671,9 @@ class ColumnCollection(object):
def keys(self):
return [k for (k, col) in self._collection]
+ def __bool__(self):
+ return bool(self._collection)
+
def __len__(self):
return len(self._collection)
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index fc841bb4b..679d9c6e9 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -55,7 +55,10 @@ def expect(role, element, **kw):
# elaborate logic up front if possible
impl = _impl_lookup[role]
- if not isinstance(element, (elements.ClauseElement, schema.SchemaItem)):
+ if not isinstance(
+ element,
+ (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue),
+ ):
resolved = impl._resolve_for_clause_element(element, **kw)
else:
resolved = element
@@ -194,7 +197,9 @@ class _ColumnCoercions(object):
def _implicit_coercions(
self, original_element, resolved, argname=None, **kw
):
- if resolved._is_select_statement:
+ if not resolved.is_clause_element:
+ self._raise_for_expected(original_element, argname, resolved)
+ elif resolved._is_select_statement:
self._warn_for_scalar_subquery_coercion()
return resolved.scalar_subquery()
elif resolved._is_from_clause and isinstance(
@@ -290,14 +295,14 @@ class ExpressionElementImpl(
_ColumnCoercions, RoleImpl, roles.ExpressionElementRole
):
def _literal_coercion(
- self, element, name=None, type_=None, argname=None, **kw
+ self, element, name=None, type_=None, argname=None, is_crud=False, **kw
):
if element is None:
return elements.Null()
else:
try:
return elements.BindParameter(
- name, element, type_, unique=True
+ name, element, type_, unique=True, _is_crud=is_crud
)
except exc.ArgumentError as err:
self._raise_for_expected(element, err=err)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 424282951..3ebcf24b0 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -653,6 +653,20 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
+ compile_state = None
+ """Optional :class:`.CompileState` object that maintains additional
+ state used by the compiler.
+
+ Major executable objects such as :class:`.Insert`, :class:`.Update`,
+ :class:`.Delete`, :class:`.Select` will generate this state when compiled
+ in order to calculate additional information about the object. For the
+ top level object that is to be executed, the state can be stored here where
+ it can also have applicability towards result set processing.
+
+ .. versionadded:: 1.4
+
+ """
+
def __init__(
self,
dialect,
@@ -1292,6 +1306,13 @@ class SQLCompiler(Compiled):
else:
return "0"
+ def _generate_delimited_list(self, elements, separator, **kw):
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in elements)
+ if s
+ )
+
def visit_clauselist(self, clauselist, **kw):
sep = clauselist.operator
if sep is None:
@@ -1299,13 +1320,7 @@ class SQLCompiler(Compiled):
else:
sep = OPERATORS[clauselist.operator]
- text = sep.join(
- s
- for s in (
- c._compiler_dispatch(self, **kw) for c in clauselist.clauses
- )
- if s
- )
+ text = self._generate_delimited_list(clauselist.clauses, sep, **kw)
if clauselist._tuple_values and self.dialect.tuple_in_values:
text = "VALUES " + text
return text
@@ -2810,8 +2825,18 @@ class SQLCompiler(Compiled):
return dialect_hints, table_text
def visit_insert(self, insert_stmt, **kw):
+
+ compile_state = insert_stmt._compile_state_cls(
+ insert_stmt, self, isinsert=True, **kw
+ )
+ insert_stmt = compile_state.statement
+
toplevel = not self.stack
+ if toplevel:
+ self.isinsert = True
+ self.compile_state = compile_state
+
self.stack.append(
{
"correlate_froms": set(),
@@ -2820,8 +2845,8 @@ class SQLCompiler(Compiled):
}
)
- crud_params = crud._setup_crud_params(
- self, insert_stmt, crud.ISINSERT, **kw
+ crud_params = crud._get_crud_params(
+ self, insert_stmt, compile_state, **kw
)
if (
@@ -2835,7 +2860,7 @@ class SQLCompiler(Compiled):
"inserts." % self.dialect.name
)
- if insert_stmt._has_multi_parameters:
+ if compile_state._has_multi_parameters:
if not self.dialect.supports_multivalues_insert:
raise exc.CompileError(
"The '%s' dialect with current database "
@@ -2888,7 +2913,7 @@ class SQLCompiler(Compiled):
text += " %s" % select_text
elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
- elif insert_stmt._has_multi_parameters:
+ elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)" % (", ".join(c[1] for c in crud_param_set))
@@ -2947,9 +2972,16 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
+ compile_state = update_stmt._compile_state_cls(
+ update_stmt, self, isupdate=True, **kw
+ )
+ update_stmt = compile_state.statement
+
toplevel = not self.stack
+ if toplevel:
+ self.isupdate = True
- extra_froms = update_stmt._extra_froms
+ extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
if is_multitable:
@@ -2981,8 +3013,8 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)
- crud_params = crud._setup_crud_params(
- self, update_stmt, crud.ISUPDATE, **kw
+ crud_params = crud._get_crud_params(
+ self, update_stmt, compile_state, **kw
)
if update_stmt._hints:
@@ -3022,8 +3054,10 @@ class SQLCompiler(Compiled):
if extra_from_text:
text += " " + extra_from_text
- if update_stmt._whereclause is not None:
- t = self.process(update_stmt._whereclause, **kw)
+ if update_stmt._where_criteria:
+ t = self._generate_delimited_list(
+ update_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ )
if t:
text += " WHERE " + t
@@ -3045,10 +3079,6 @@ class SQLCompiler(Compiled):
return text
- @util.memoized_property
- def _key_getters_for_crud_column(self):
- return crud._key_getters_for_crud_column(self, self.statement)
-
def delete_extra_from_clause(
self, update_stmt, from_table, extra_froms, from_hints, **kw
):
@@ -3069,11 +3099,16 @@ class SQLCompiler(Compiled):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, **kw):
- toplevel = not self.stack
+ compile_state = delete_stmt._compile_state_cls(
+ delete_stmt, self, isdelete=True, **kw
+ )
+ delete_stmt = compile_state.statement
- crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
+ toplevel = not self.stack
+ if toplevel:
+ self.isdelete = True
- extra_froms = delete_stmt._extra_froms
+ extra_froms = compile_state._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
self.stack.append(
@@ -3122,8 +3157,10 @@ class SQLCompiler(Compiled):
if extra_from_text:
text += " " + extra_from_text
- if delete_stmt._whereclause is not None:
- t = delete_stmt._whereclause._compiler_dispatch(self, **kw)
+ if delete_stmt._where_criteria:
+ t = self._generate_delimited_list(
+ delete_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ )
if t:
text += " WHERE " + t
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index e474952ce..2827a5817 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -16,6 +16,7 @@ from . import coercions
from . import dml
from . import elements
from . import roles
+from .elements import ClauseElement
from .. import exc
from .. import util
@@ -33,45 +34,8 @@ values present.
""",
)
-ISINSERT = util.symbol("ISINSERT")
-ISUPDATE = util.symbol("ISUPDATE")
-ISDELETE = util.symbol("ISDELETE")
-
-def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
- restore_isinsert = compiler.isinsert
- restore_isupdate = compiler.isupdate
- restore_isdelete = compiler.isdelete
-
- should_restore = (
- (restore_isinsert or restore_isupdate or restore_isdelete)
- or len(compiler.stack) > 1
- or "visiting_cte" in kw
- )
-
- if local_stmt_type is ISINSERT:
- compiler.isupdate = False
- compiler.isinsert = True
- elif local_stmt_type is ISUPDATE:
- compiler.isupdate = True
- compiler.isinsert = False
- elif local_stmt_type is ISDELETE:
- if not should_restore:
- compiler.isdelete = True
- else:
- assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
-
- try:
- if local_stmt_type in (ISINSERT, ISUPDATE):
- return _get_crud_params(compiler, stmt, **kw)
- finally:
- if should_restore:
- compiler.isinsert = restore_isinsert
- compiler.isupdate = restore_isupdate
- compiler.isdelete = restore_isdelete
-
-
-def _get_crud_params(compiler, stmt, **kw):
+def _get_crud_params(compiler, stmt, compile_state, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -87,27 +51,29 @@ def _get_crud_params(compiler, stmt, **kw):
compiler.update_prefetch = []
compiler.returning = []
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ (
+ _column_as_key,
+ _getattr_col_key,
+ _col_bind_name,
+ ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+
+ compiler._key_getters_for_crud_column = getters
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
- if compiler.column_keys is None and stmt.parameters is None:
+ if compiler.column_keys is None and compile_state._no_parameters:
return [
(c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
- if stmt._has_multi_parameters:
- stmt_parameters = stmt.parameters[0]
+ if compile_state._has_multi_parameters:
+ stmt_parameters = compile_state._multi_parameters[0]
else:
- stmt_parameters = stmt.parameters
-
- # getters - these are normally just column.key,
- # but in the case of mysql multi-table update, the rules for
- # .key must conditionally take tablename into account
- (
- _column_as_key,
- _getattr_col_key,
- _col_bind_name,
- ) = _key_getters_for_crud_column(compiler, stmt)
+ stmt_parameters = compile_state._dict_parameters
# if we have statement parameters - set defaults in the
# compiled params
@@ -132,10 +98,15 @@ def _get_crud_params(compiler, stmt, **kw):
# special logic that only occurs for multi-table UPDATE
# statements
- if compiler.isupdate and stmt._extra_froms and stmt_parameters:
+ if (
+ compile_state.isupdate
+ and compile_state._extra_froms
+ and stmt_parameters
+ ):
_get_multitable_params(
compiler,
stmt,
+ compile_state,
stmt_parameters,
check_columns,
_col_bind_name,
@@ -144,10 +115,11 @@ def _get_crud_params(compiler, stmt, **kw):
kw,
)
- if compiler.isinsert and stmt.select_names:
+ if compile_state.isinsert and stmt._select_names:
_scan_insert_from_select_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -160,6 +132,7 @@ def _get_crud_params(compiler, stmt, **kw):
_scan_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -181,8 +154,10 @@ def _get_crud_params(compiler, stmt, **kw):
% (", ".join("%s" % (c,) for c in check))
)
- if stmt._has_multi_parameters:
- values = _extend_values_for_multiparams(compiler, stmt, values, kw)
+ if compile_state._has_multi_parameters:
+ values = _extend_values_for_multiparams(
+ compiler, stmt, compile_state, values, kw
+ )
return values
@@ -201,15 +176,46 @@ def _create_bind_param(
return bindparam
-def _key_getters_for_crud_column(compiler, stmt):
- if compiler.isupdate and stmt._extra_froms:
+def _handle_values_anonymous_param(compiler, col, value, name, **kw):
+ # the insert() and update() constructs as of 1.4 will now produce anonymous
+ # bindparam() objects in the values() collections up front when given plain
+ # literal values. This is so that cache key behaviors, which need to
+ # produce bound parameters in deterministic order without invoking any
+ # compilation here, can be applied to these constructs when they include
+ # values() (but not yet multi-values, which are not included in caching
+ # right now).
+ #
+ # in order to produce the desired "crud" style name for these parameters,
+ # which will also be targetable in engine/default.py through the usual
+ # conventions, apply our desired name to these unique parameters by
+ # populating the compiler truncated names cache with the desired name,
+ # rather than having
+ # compiler.visit_bindparam()->compiler._truncated_identifier make up a
+ # name. Saves on call counts also.
+ if value.unique and isinstance(value.key, elements._truncated_label):
+ compiler.truncated_names[("bindparam", value.key)] = name
+
+ if value.type._isnull:
+ # either unique parameter, or other bound parameters that were
+ # passed in directly
+ # clone using base ClauseElement to retain unique key
+ value = ClauseElement._clone(value)
+
+ # set type to that of the column unconditionally
+ value.type = col.type
+
+ return value._compiler_dispatch(compiler, **kw)
+
+
+def _key_getters_for_crud_column(compiler, stmt, compile_state):
+ if compile_state.isupdate and compile_state._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
# the "main" table of the statement remains unqualified,
# allowing the most compatibility with a non-multi-table
# statement.
- _et = set(stmt._extra_froms)
+ _et = set(compile_state._extra_froms)
c_key_role = functools.partial(
coercions.expect_as_key, roles.DMLColumnRole
@@ -246,6 +252,7 @@ def _key_getters_for_crud_column(compiler, stmt):
def _scan_insert_from_select_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -260,9 +267,9 @@ def _scan_insert_from_select_cols(
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt)
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
- cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
compiler._insert_from_select = stmt.select
@@ -294,6 +301,7 @@ def _scan_insert_from_select_cols(
def _scan_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -308,11 +316,11 @@ def _scan_cols(
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt)
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
- if stmt._parameter_ordering:
+ if compile_state._parameter_ordering:
parameter_ordering = [
- _column_as_key(key) for key in stmt._parameter_ordering
+ _column_as_key(key) for key in compile_state._parameter_ordering
]
ordered_keys = set(parameter_ordering)
cols = [stmt.table.c[key] for key in parameter_ordering] + [
@@ -329,6 +337,7 @@ def _scan_cols(
_append_param_parameter(
compiler,
stmt,
+ compile_state,
c,
col_key,
parameters,
@@ -339,7 +348,7 @@ def _scan_cols(
kw,
)
- elif compiler.isinsert:
+ elif compile_state.isinsert:
if (
c.primary_key
and need_pks
@@ -377,7 +386,7 @@ def _scan_cols(
):
_warn_pk_with_no_anticipated_value(c)
- elif compiler.isupdate:
+ elif compile_state.isupdate:
_append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw
)
@@ -386,6 +395,7 @@ def _scan_cols(
def _append_param_parameter(
compiler,
stmt,
+ compile_state,
c,
col_key,
parameters,
@@ -395,7 +405,9 @@ def _append_param_parameter(
values,
kw,
):
+
value = parameters.pop(col_key)
+
if coercions._is_literal(value):
value = _create_bind_param(
compiler,
@@ -403,15 +415,21 @@ def _append_param_parameter(
value,
required=value is REQUIRED,
name=_col_bind_name(c)
- if not stmt._has_multi_parameters
+ if not compile_state._has_multi_parameters
+ else "%s_m0" % _col_bind_name(c),
+ **kw
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler,
+ c,
+ value,
+ name=_col_bind_name(c)
+ if not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
- if isinstance(value, elements.BindParameter) and value.type._isnull:
- value = value._clone()
- value.type = c.type
-
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
@@ -644,6 +662,7 @@ def _append_param_update(
def _get_multitable_params(
compiler,
stmt,
+ compile_state,
stmt_parameters,
check_columns,
_col_bind_name,
@@ -656,7 +675,7 @@ def _get_multitable_params(
for c, param in stmt_parameters.items()
)
affected_tables = set()
- for t in stmt._extra_froms:
+ for t in compile_state._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
@@ -669,6 +688,11 @@ def _get_multitable_params(
value,
required=value is REQUIRED,
name=_col_bind_name(c),
+ **kw # TODO: no test coverage for literal binds here
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler, c, value, name=_col_bind_name(c), **kw
)
else:
compiler.postfetch.append(c)
@@ -704,11 +728,11 @@ def _get_multitable_params(
compiler.postfetch.append(c)
-def _extend_values_for_multiparams(compiler, stmt, values, kw):
+def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
values_0 = values
values = [values]
- for i, row in enumerate(stmt.parameters[1:]):
+ for i, row in enumerate(compile_state._multi_parameters[1:]):
extension = []
for (col, param) in values_0:
if col in row or col.key in row:
@@ -757,12 +781,13 @@ def _get_stmt_parameters_params(
values.append((k, v))
-def _get_returning_modifiers(compiler, stmt):
+def _get_returning_modifiers(compiler, stmt, compile_state):
+
need_pks = (
- compiler.isinsert
+ compile_state.isinsert
and not compiler.inline
and not stmt._returning
- and not stmt._has_multi_parameters
+ and not compile_state._has_multi_parameters
)
implicit_returning = (
@@ -771,9 +796,9 @@ def _get_returning_modifiers(compiler, stmt):
and stmt.table.implicit_returning
)
- if compiler.isinsert:
+ if compile_state.isinsert:
implicit_return_defaults = implicit_returning and stmt._return_defaults
- elif compiler.isupdate:
+ elif compile_state.isupdate:
implicit_return_defaults = (
compiler.dialect.implicit_returning
and stmt.table.implicit_returning
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 097c513b4..171a2cc2c 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -8,25 +8,162 @@
Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
"""
-
+from sqlalchemy.types import NullType
from . import coercions
from . import roles
from .base import _from_objects
from .base import _generative
+from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
-from .elements import and_
+from .base import HasCompileState
from .elements import ClauseElement
from .elements import Null
from .selectable import HasCTE
from .selectable import HasPrefixes
+from .visitors import InternalTraversal
from .. import exc
from .. import util
+from ..util import collections_abc
+
+
+class DMLState(CompileState):
+ _no_parameters = True
+ _dict_parameters = None
+ _multi_parameters = None
+ _parameter_ordering = None
+ _has_multi_parameters = False
+ isupdate = False
+ isdelete = False
+ isinsert = False
+
+ def __init__(
+ self,
+ statement,
+ compiler,
+ isinsert=False,
+ isupdate=False,
+ isdelete=False,
+ **kw
+ ):
+ self.statement = statement
+
+ if isupdate:
+ self.isupdate = True
+ self._preserve_parameter_order = (
+ statement._preserve_parameter_order
+ )
+ if statement._ordered_values is not None:
+ self._process_ordered_values(statement)
+ elif statement._values is not None:
+ self._process_values(statement)
+ elif statement._multi_values:
+ self._process_multi_values(statement)
+ self._extra_froms = self._make_extra_froms(statement)
+ elif isinsert:
+ self.isinsert = True
+ if statement._select_names:
+ self._process_select_values(statement)
+ if statement._values is not None:
+ self._process_values(statement)
+ if statement._multi_values:
+ self._process_multi_values(statement)
+ elif isdelete:
+ self.isdelete = True
+ self._extra_froms = self._make_extra_froms(statement)
+ else:
+ assert False, "one of isinsert, isupdate, or isdelete must be set"
+
+ def _make_extra_froms(self, statement):
+ froms = []
+ seen = {statement.table}
+
+ for crit in statement._where_criteria:
+ for item in _from_objects(crit):
+ if not seen.intersection(item._cloned_set):
+ froms.append(item)
+ seen.update(item._cloned_set)
+
+ return froms
+
+ def _process_multi_values(self, statement):
+ if not statement._supports_multi_parameters:
+ raise exc.InvalidRequestError(
+ "%s construct does not support "
+ "multiple parameter sets." % statement.__visit_name__.upper()
+ )
+
+ for parameters in statement._multi_values:
+ multi_parameters = [
+ {
+ c.key: value
+ for c, value in zip(statement.table.c, parameter_set)
+ }
+ if isinstance(parameter_set, collections_abc.Sequence)
+ else parameter_set
+ for parameter_set in parameters
+ ]
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._has_multi_parameters = True
+ self._multi_parameters = multi_parameters
+ self._dict_parameters = self._multi_parameters[0]
+ elif not self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ self._multi_parameters.extend(multi_parameters)
+
+ def _process_values(self, statement):
+ if self._no_parameters:
+ self._has_multi_parameters = False
+ self._dict_parameters = statement._values
+ self._no_parameters = False
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+
+ def _process_ordered_values(self, statement):
+ parameters = statement._ordered_values
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = dict(parameters)
+ self._parameter_ordering = [key for key, value in parameters]
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ raise exc.InvalidRequestError(
+ "Can only invoke ordered_values() once, and not mixed "
+ "with any other values() call"
+ )
+
+ def _process_select_values(self, statement):
+ parameters = {
+ coercions.expect(roles.DMLColumnRole, name, as_key=True): Null()
+ for name in statement._select_names
+ }
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = parameters
+ else:
+ # this condition normally not reachable as the Insert
+ # does not allow this construction to occur
+ assert False, "This statement already has parameters"
+
+ def _cant_mix_formats_error(self):
+ raise exc.InvalidRequestError(
+ "Can't mix single and multiple VALUES "
+ "formats in one INSERT statement; one style appends to a "
+ "list while the other replaces values, so the intent is "
+ "ambiguous."
+ )
class UpdateBase(
roles.DMLRole,
HasCTE,
+ HasCompileState,
DialectKWArgs,
HasPrefixes,
Executable,
@@ -42,10 +179,10 @@ class UpdateBase(
{"autocommit": True}
)
_hints = util.immutabledict()
- _parameter_ordering = None
- _prefixes = ()
named_with_column = False
+ _compile_state_cls = DMLState
+
@classmethod
def _constructor_20_deprecations(cls, fn_name, clsname, names):
@@ -112,43 +249,6 @@ class UpdateBase(
col._make_proxy(fromclause) for col in self._returning
)
- def _process_colparams(self, parameters, preserve_parameter_order=False):
- def process_single(p):
- if isinstance(p, (list, tuple)):
- return dict((c.key, pval) for c, pval in zip(self.table.c, p))
- else:
- return p
-
- if (
- preserve_parameter_order or self._preserve_parameter_order
- ) and parameters is not None:
- if not isinstance(parameters, list) or (
- parameters and not isinstance(parameters[0], tuple)
- ):
- raise ValueError(
- "When preserve_parameter_order is True, "
- "values() only accepts a list of 2-tuples"
- )
- self._parameter_ordering = [key for key, value in parameters]
-
- return dict(parameters), False
-
- if (
- isinstance(parameters, (list, tuple))
- and parameters
- and isinstance(parameters[0], (list, tuple, dict))
- ):
-
- if not self._supports_multi_parameters:
- raise exc.InvalidRequestError(
- "This construct does not support "
- "multiple parameter sets."
- )
-
- return [process_single(p) for p in parameters], True
- else:
- return process_single(parameters), False
-
def params(self, *arg, **kw):
"""Set the parameters for the statement.
@@ -163,6 +263,29 @@ class UpdateBase(
" stmt.values(**parameters)."
)
+ @_generative
+ def with_dialect_options(self, **opt):
+ """Add dialect options to this INSERT/UPDATE/DELETE object.
+
+ e.g.::
+
+ upd = table.update().dialect_options(mysql_limit=10)
+
+ .. versionadded: 1.4 - this method supersedes the dialect options
+ associated with the constructor.
+
+
+ """
+ self._validate_dialect_kwargs(opt)
+
+ def _validate_dialect_kwargs_deprecated(self, dialect_kw):
+ util.warn_deprecated_20(
+ "Passing dialect keyword arguments directly to the "
+ "constructor is deprecated and will be removed in SQLAlchemy "
+ "2.0. Please use the ``with_dialect_options()`` method."
+ )
+ self._validate_dialect_kwargs(dialect_kw)
+
def bind(self):
"""Return a 'bind' linked to this :class:`.UpdateBase`
or a :class:`.Table` associated with it.
@@ -266,9 +389,6 @@ class UpdateBase(
self._hints = self._hints.union({(selectable, dialect_name): text})
- def _copy_internals(self, **kw):
- raise NotImplementedError()
-
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to
@@ -277,16 +397,21 @@ class ValuesBase(UpdateBase):
__visit_name__ = "values_base"
_supports_multi_parameters = False
- _has_multi_parameters = False
_preserve_parameter_order = False
select = None
_post_values_clause = None
+ _values = None
+ _multi_values = ()
+ _ordered_values = None
+ _select_names = None
+
+ _returning = ()
+
def __init__(self, table, values, prefixes):
self.table = coercions.expect(roles.FromClauseRole, table)
- self.parameters, self._has_multi_parameters = self._process_colparams(
- values
- )
+ if values is not None:
+ self.values.non_generative(self, values)
if prefixes:
self._setup_prefixes(prefixes)
@@ -416,59 +541,96 @@ class ValuesBase(UpdateBase):
:func:`~.expression.update` - produce an ``UPDATE`` statement
"""
- if self.select is not None:
+ if self._select_names:
raise exc.InvalidRequestError(
"This construct already inserts from a SELECT"
)
- if self._has_multi_parameters and kwargs:
- raise exc.InvalidRequestError(
- "This construct already has multiple parameter sets."
+ elif self._ordered_values:
+ raise exc.ArgumentError(
+ "This statement already has ordered values present"
)
if args:
- if len(args) > 1:
+ # positional case. this is currently expensive. we don't
+ # yet have positional-only args so we have to check the length.
+ # then we need to check multiparams vs. single dictionary.
+ # since the parameter format is needed in order to determine
+ # a cache key, we need to determine this up front.
+ arg = args[0]
+
+ if kwargs:
+ raise exc.ArgumentError(
+ "Can't pass positional and kwargs to values() "
+ "simultaneously"
+ )
+ elif len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
"dictionaries/tuples is accepted positionally."
)
- v = args[0]
- else:
- v = {}
- if self.parameters is None:
- (
- self.parameters,
- self._has_multi_parameters,
- ) = self._process_colparams(v)
- else:
- if self._has_multi_parameters:
- self.parameters = list(self.parameters)
- p, self._has_multi_parameters = self._process_colparams(v)
- if not self._has_multi_parameters:
- raise exc.ArgumentError(
- "Can't mix single-values and multiple values "
- "formats in one statement"
- )
+ elif not self._preserve_parameter_order and isinstance(
+ arg, collections_abc.Sequence
+ ):
- self.parameters.extend(p)
- else:
- self.parameters = self.parameters.copy()
- p, self._has_multi_parameters = self._process_colparams(v)
- if self._has_multi_parameters:
- raise exc.ArgumentError(
- "Can't mix single-values and multiple values "
- "formats in one statement"
- )
- self.parameters.update(p)
+ if arg and isinstance(arg[0], (list, dict, tuple)):
+ self._multi_values += (arg,)
+ return
- if kwargs:
- if self._has_multi_parameters:
+ # tuple values
+ arg = {c.key: value for c, value in zip(self.table.c, arg)}
+ elif self._preserve_parameter_order and not isinstance(
+ arg, collections_abc.Sequence
+ ):
+ raise ValueError(
+ "When preserve_parameter_order is True, "
+ "values() only accepts a list of 2-tuples"
+ )
+
+ else:
+ # kwarg path. this is the most common path for non-multi-params
+ # so this is fairly quick.
+ arg = kwargs
+ if args:
raise exc.ArgumentError(
- "Can't pass kwargs and multiple parameter sets "
- "simultaneously"
+ "Only a single dictionary/tuple or list of "
+ "dictionaries/tuples is accepted positionally."
)
+
+ # for top level values(), convert literals to anonymous bound
+ # parameters at statement construction time, so that these values can
+ # participate in the cache key process like any other ClauseElement.
+ # crud.py now intercepts bound parameters with unique=True from here
+ # and ensures they get the "crud"-style name when rendered.
+
+ if self._preserve_parameter_order:
+ arg = [
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ ),
+ )
+ for k, v in arg
+ ]
+ self._ordered_values = arg
+ else:
+ arg = {
+ k: coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ )
+ for k, v in arg.items()
+ }
+ if self._values:
+ self._values = self._values.union(arg)
else:
- self.parameters.update(kwargs)
+ self._values = util.immutabledict(arg)
@_generative
def return_defaults(self, *cols):
@@ -555,6 +717,25 @@ class Insert(ValuesBase):
_supports_multi_parameters = True
+ select = None
+ include_insert_from_select_defaults = False
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_select_names", InternalTraversal.dp_string_list),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_multi_values", InternalTraversal.dp_dml_multi_values),
+ ("select", InternalTraversal.dp_clauseelement),
+ ("_post_values_clause", InternalTraversal.dp_clauseelement),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ )
+
@ValuesBase._constructor_20_deprecations(
"insert",
"Insert",
@@ -626,18 +807,13 @@ class Insert(ValuesBase):
"""
super(Insert, self).__init__(table, values, prefixes)
self._bind = bind
- self.select = self.select_names = None
- self.include_insert_from_select_defaults = False
self._inline = inline
- self._returning = returning
- self._validate_dialect_kwargs(dialect_kw)
- self._return_defaults = return_defaults
+ if returning:
+ self._returning = returning
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
- def get_children(self, **kwargs):
- if self.select is not None:
- return (self.select,)
- else:
- return ()
+ self._return_defaults = return_defaults
@_generative
def inline(self):
@@ -702,25 +878,34 @@ class Insert(ValuesBase):
:attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
"""
- if self.parameters:
+
+ if self._values:
raise exc.InvalidRequestError(
"This construct already inserts value expressions"
)
- self.parameters, self._has_multi_parameters = self._process_colparams(
- {
- coercions.expect(roles.DMLColumnRole, n, as_key=True): Null()
- for n in names
- }
- )
-
- self.select_names = names
+ self._select_names = names
self._inline = True
self.include_insert_from_select_defaults = include_defaults
self.select = coercions.expect(roles.DMLSelectRole, select)
-class Update(ValuesBase):
+class DMLWhereBase(object):
+ _where_criteria = ()
+
+ @_generative
+ def where(self, whereclause):
+ """return a new construct with the given expression added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
+ """
+
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
+ )
+
+
+class Update(DMLWhereBase, ValuesBase):
"""Represent an Update construct.
The :class:`.Update` object is created using the :func:`update()`
@@ -730,6 +915,20 @@ class Update(ValuesBase):
__visit_name__ = "update"
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_ordered_values", InternalTraversal.dp_dml_ordered_values),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ )
+
@ValuesBase._constructor_20_deprecations(
"update",
"Update",
@@ -874,21 +1073,14 @@ class Update(ValuesBase):
self._bind = bind
self._returning = returning
if whereclause is not None:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
)
- else:
- self._whereclause = None
self._inline = inline
- self._validate_dialect_kwargs(dialect_kw)
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
self._return_defaults = return_defaults
- def get_children(self, **kwargs):
- if self._whereclause is not None:
- return (self._whereclause,)
- else:
- return ()
-
@_generative
def ordered_values(self, *args):
"""Specify the VALUES clause of this UPDATE statement with an explicit
@@ -912,22 +1104,27 @@ class Update(ValuesBase):
parameter, which will be removed in SQLAlchemy 2.0.
"""
- if self.select is not None:
- raise exc.InvalidRequestError(
- "This construct already inserts from a SELECT"
- )
-
- if self.parameters is None:
- (
- self.parameters,
- self._has_multi_parameters,
- ) = self._process_colparams(
- list(args), preserve_parameter_order=True
- )
- else:
+ if self._values:
raise exc.ArgumentError(
"This statement already has values present"
)
+ elif self._ordered_values:
+ raise exc.ArgumentError(
+ "This statement already has ordered values present"
+ )
+ arg = [
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ ),
+ )
+ for k, v in args
+ ]
+ self._ordered_values = arg
@_generative
def inline(self):
@@ -945,37 +1142,8 @@ class Update(ValuesBase):
"""
self._inline = True
- @_generative
- def where(self, whereclause):
- """return a new update() construct with the given expression added to
- its WHERE clause, joined to the existing clause via AND, if any.
-
- """
- if self._whereclause is not None:
- self._whereclause = and_(
- self._whereclause,
- coercions.expect(roles.WhereHavingRole, whereclause),
- )
- else:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
- )
-
- @property
- def _extra_froms(self):
- froms = []
- seen = {self.table}
-
- if self._whereclause is not None:
- for item in _from_objects(self._whereclause):
- if not seen.intersection(item._cloned_set):
- froms.append(item)
- seen.update(item._cloned_set)
-
- return froms
-
-class Delete(UpdateBase):
+class Delete(DMLWhereBase, UpdateBase):
"""Represent a DELETE construct.
The :class:`.Delete` object is created using the :func:`delete()`
@@ -985,6 +1153,17 @@ class Delete(UpdateBase):
__visit_name__ = "delete"
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ )
+
@ValuesBase._constructor_20_deprecations(
"delete",
"Delete",
@@ -1041,43 +1220,9 @@ class Delete(UpdateBase):
self._setup_prefixes(prefixes)
if whereclause is not None:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
- )
- else:
- self._whereclause = None
-
- self._validate_dialect_kwargs(dialect_kw)
-
- def get_children(self, **kwargs):
- if self._whereclause is not None:
- return (self._whereclause,)
- else:
- return ()
-
- @_generative
- def where(self, whereclause):
- """Add the given WHERE clause to a newly returned delete construct."""
-
- if self._whereclause is not None:
- self._whereclause = and_(
- self._whereclause,
+ self._where_criteria += (
coercions.expect(roles.WhereHavingRole, whereclause),
)
- else:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
- )
-
- @property
- def _extra_froms(self):
- froms = []
- seen = {self.table}
- if self._whereclause is not None:
- for item in _from_objects(self._whereclause):
- if not seen.intersection(item._cloned_set):
- froms.append(item)
- seen.update(item._cloned_set)
-
- return froms
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index d0babb1be..47739a37d 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -200,6 +200,7 @@ class ClauseElement(
_is_from_container = False
_is_select_container = False
_is_select_statement = False
+ _is_bind_parameter = False
_order_by_label_element = None
@@ -1010,6 +1011,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
_is_crud = False
_expanding_in_types = ()
+ _is_bind_parameter = True
def __init__(
self,
@@ -1025,6 +1027,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
literal_execute=False,
_compared_to_operator=None,
_compared_to_type=None,
+ _is_crud=False,
):
r"""Produce a "bound expression".
@@ -1303,6 +1306,8 @@ class BindParameter(roles.InElementRole, ColumnElement):
self.required = required
self.expanding = expanding
self.literal_execute = literal_execute
+ if _is_crud:
+ self._is_crud = True
if type_ is None:
if _compared_to_type is not None:
self.type = _compared_to_type.coerce_compared_value(
@@ -4264,21 +4269,12 @@ class ColumnClause(
else:
return other.proxy_set.intersection(self.proxy_set)
- def _get_table(self):
- return self.__dict__["table"]
-
- def _set_table(self, table):
- self._memoized_property.expire_instance(self)
- self.__dict__["table"] = table
-
def get_children(self, column_tables=False, **kw):
if column_tables and self.table is not None:
return [self.table]
else:
return []
- table = property(_get_table, _set_table)
-
@_memoized_property
def _from_objects(self):
t = self.table
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 5445a1bce..4c627c4cc 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -1413,6 +1413,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"Column must be constructed with a non-blank name or "
"assign a non-blank .name before adding to a Table."
)
+
+ Column._memoized_property.expire_instance(self)
+
if self.key is None:
self.key = self.name
@@ -2080,24 +2083,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self._set_target_column(_column)
-class _NotAColumnExpr(object):
- # the coercions system is not used in crud.py for the values passed in
- # the insert().values() and update().values() methods, so the usual
- # pathways to rejecting a coercion in the unlikely case of adding defaut
- # generator objects to insert() or update() constructs aren't available;
- # create a quick coercion rejection here that is specific to what crud.py
- # calls on value objects.
- def _not_a_column_expr(self):
- raise exc.InvalidRequestError(
- "This %s cannot be used directly "
- "as a column expression." % self.__class__.__name__
- )
-
- self_group = lambda self: self._not_a_column_expr() # noqa
- _from_objects = property(lambda self: self._not_a_column_expr())
-
-
-class DefaultGenerator(_NotAColumnExpr, SchemaItem):
+class DefaultGenerator(SchemaItem):
"""Base class for column *default* values."""
__visit_name__ = "default_generator"
@@ -2505,7 +2491,7 @@ class Sequence(roles.StatementRole, DefaultGenerator):
@inspection._self_inspects
-class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
+class FetchedValue(SchemaEventTarget):
"""A marker for a transparent database-side default.
Use :class:`.FetchedValue` when the database is configured
@@ -2528,6 +2514,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
is_server_default = True
reflected = False
has_argument = False
+ is_clause_element = False
def __init__(self, for_update=False):
self.for_update = for_update
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index b972c13be..965ac6e7f 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -3145,8 +3145,6 @@ class Select(
__visit_name__ = "select"
- _prefixes = ()
- _suffixes = ()
_hints = util.immutabledict()
_statement_hints = ()
_distinct = False
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 03ff7c439..c29a04ee0 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -200,6 +200,9 @@ class _CacheKey(ExtendedInternalTraversal):
attrname, inspect(obj), parent, anon_map, bindparams
)
+ def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ return tuple(obj)
+
def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
return (
attrname,
@@ -336,6 +339,25 @@ class _CacheKey(ExtendedInternalTraversal):
def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
+ def visit_dialect_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ dialect_name,
+ tuple(
+ [
+ (key, obj[dialect_name][key])
+ for key in sorted(obj[dialect_name])
+ ]
+ ),
+ )
+ for dialect_name in sorted(obj)
+ ),
+ )
+
def visit_string_clauseelement_dict(
self, attrname, obj, parent, anon_map, bindparams
):
@@ -366,9 +388,13 @@ class _CacheKey(ExtendedInternalTraversal):
def visit_fromclause_canonical_column_collection(
self, attrname, obj, parent, anon_map, bindparams
):
+ # inlining into the internals of ColumnCollection
return (
attrname,
- tuple(col._gen_cache_key(anon_map, bindparams) for col in obj),
+ tuple(
+ col._gen_cache_key(anon_map, bindparams)
+ for k, col in obj._collection
+ ),
)
def visit_unknown_structure(
@@ -377,6 +403,48 @@ class _CacheKey(ExtendedInternalTraversal):
anon_map[NO_CACHE] = True
return ()
+ def visit_dml_ordered_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key._gen_cache_key(anon_map, bindparams)
+ if hasattr(key, "__clause_element__")
+ else key,
+ value._gen_cache_key(anon_map, bindparams),
+ )
+ for key, value in obj
+ ),
+ )
+
+ def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+
+ expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+ if expr_values:
+ # expr values can't be sorted deterministically right now,
+ # so no cache
+ anon_map[NO_CACHE] = True
+ return ()
+
+ str_values = expr_values.symmetric_difference(obj)
+
+ return (
+ attrname,
+ tuple(
+ (k, obj[k]._gen_cache_key(anon_map, bindparams))
+ for k in sorted(str_values)
+ ),
+ )
+
+ def visit_dml_multi_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # multivalues are simply not cacheable right now
+ anon_map[NO_CACHE] = True
+ return ()
+
_cache_key_traversal_visitor = _CacheKey()
@@ -404,6 +472,70 @@ class _CopyInternals(InternalTraversal):
(key, clone(value, **kw)) for key, value in element.items()
)
+ def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw):
+ # sequence of 2-tuples
+ return [
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key,
+ clone(value, **kw),
+ )
+ for key, value in element
+ ]
+
+ def visit_dml_values(self, parent, element, clone=_clone, **kw):
+ # sequence of dictionaries
+ return [
+ {
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key
+ ): clone(value, **kw)
+ for key, value in sub_element.items()
+ }
+ for sub_element in element
+ ]
+
+ def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
+ # sequence of sequences, each sequence contains a list/dict/tuple
+
+ def copy(elem):
+ if isinstance(elem, (list, tuple)):
+ return [
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key,
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value,
+ )
+ for key, value in elem
+ ]
+ elif isinstance(elem, dict):
+ return {
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key
+ ): (
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ for key, value in elem
+ }
+ else:
+ # TODO: use abc classes
+ assert False
+
+ return [
+ [copy(sub_element) for sub_element in sequence]
+ for sequence in element
+ ]
+
_copy_internals = _CopyInternals()
@@ -442,6 +574,25 @@ class _GetChildren(InternalTraversal):
def visit_clauseelement_unordered_set(self, element, **kw):
return tuple(element)
+ def visit_dml_ordered_values(self, element, **kw):
+ for k, v in element:
+ if hasattr(k, "__clause_element__"):
+ yield k
+ yield v
+
+ def visit_dml_values(self, element, **kw):
+ expr_values = {k for k in element if hasattr(k, "__clause_element__")}
+ str_values = expr_values.symmetric_difference(element)
+
+ for k in sorted(str_values):
+ yield element[k]
+ for k in expr_values:
+ yield k
+ yield element[k]
+
+ def visit_dml_multi_values(self, element, **kw):
+ return ()
+
_get_children = _GetChildren()
@@ -644,6 +795,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
def visit_string(self, left_parent, left, right_parent, right, **kw):
return left == right
+ def visit_string_list(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
return _resolve_name_for_compare(
left_parent, left, self.anon_map[0], **kw
@@ -663,6 +817,11 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
return left == right
+ def visit_dialect_options(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
return left == right
@@ -713,6 +872,55 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
):
raise NotImplementedError()
+ def visit_dml_ordered_values(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ # sequence of tuple pairs
+
+ for (lk, lv), (rk, rv) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ lkce = hasattr(lk, "__clause_element__")
+ rkce = hasattr(rk, "__clause_element__")
+ if lkce != rkce:
+ return COMPARE_FAILED
+ elif lkce and not self.compare_inner(lk, rk, **kw):
+ return COMPARE_FAILED
+ elif not lkce and lk != rk:
+ return COMPARE_FAILED
+ elif not self.compare_inner(lv, rv, **kw):
+ return COMPARE_FAILED
+
+ def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
+ if left is None or right is None or len(left) != len(right):
+ return COMPARE_FAILED
+
+ for lk in left:
+ lv = left[lk]
+
+ if lk not in right:
+ return COMPARE_FAILED
+ rv = right[lk]
+
+ if not self.compare_inner(lv, rv, **kw):
+ return COMPARE_FAILED
+
+ def visit_dml_multi_values(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
+ if lseq is None or rseq is None:
+ return COMPARE_FAILED
+
+ for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
+ if (
+ self.visit_dml_values(
+ left_parent, ld, right_parent, rd, **kw
+ )
+ is COMPARE_FAILED
+ ):
+ return COMPARE_FAILED
+
def compare_clauselist(self, left, right, **kw):
if left.operator is right.operator:
if operators.is_associative(left.operator):
@@ -731,11 +939,11 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
if left.operator == right.operator:
if operators.is_commutative(left.operator):
if (
- compare(left.left, right.left, **kw)
- and compare(left.right, right.right, **kw)
+ self.compare_inner(left.left, right.left, **kw)
+ and self.compare_inner(left.right, right.right, **kw)
) or (
- compare(left.left, right.right, **kw)
- and compare(left.right, right.left, **kw)
+ self.compare_inner(left.left, right.right, **kw)
+ and self.compare_inner(left.right, right.left, **kw)
):
return ["operator", "negate", "left", "right"]
else:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index fda48c657..a049d9bb0 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -269,6 +269,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_string_list = symbol("SL")
+ """Visit a list of strings."""
+
dp_anon_name = symbol("AN")
"""Visit a potentially "anonymized" string value.
@@ -313,6 +316,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_dialect_options = symbol("DO")
+ """visit a dialect options structure."""
+
dp_string_clauseelement_dict = symbol("CD")
"""Visit a dictionary of string keys to :class:`.ClauseElement`
objects.
@@ -365,6 +371,21 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_dml_ordered_values = symbol("DML_OV")
+ """visit the values() ordered tuple list of an :class:`.Update` object."""
+
+ dp_dml_values = symbol("DML_V")
+ """visit the values() dictionary of a :class:`.ValuesBase
+ (e.g. Insert or Update) object.
+
+ """
+
+ dp_dml_multi_values = symbol("DML_MV")
+ """visit the values() multi-valued list of dictionaries of an
+ :class:`.Insert` object.
+
+ """
+
class ExtendedInternalTraversal(InternalTraversal):
"""defines additional symbols that are useful in caching applications.
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 2770cc239..b21eb44cf 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -47,17 +47,10 @@ class immutabledict(ImmutableContainer, dict):
return immutabledict, (dict(self),)
def union(self, d):
- if not d:
- return self
- elif not self:
- if isinstance(d, immutabledict):
- return d
- else:
- return immutabledict(d)
- else:
- d2 = immutabledict(self)
- dict.update(d2, d)
- return d2
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ dict.update(new, d)
+ return new
def __repr__(self):
return "immutabledict(%s)" % dict.__repr__(self)