summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/dml.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/dml.py')
-rw-r--r--lib/sqlalchemy/sql/dml.py267
1 files changed, 178 insertions, 89 deletions
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index f5fb6b2f3..0c9056aee 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -15,18 +15,29 @@ import collections.abc as collections_abc
import operator
import typing
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterable
from typing import List
from typing import MutableMapping
+from typing import NoReturn
from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
+from typing import Union
from . import coercions
from . import roles
from . import util as sql_util
+from ._typing import is_column_element
+from ._typing import is_named_from_clause
from .base import _entity_namespace_key
from .base import _exclusive_against
from .base import _from_objects
from .base import _generative
+from .base import _select_iterables
from .base import ColumnCollection
from .base import CompileState
from .base import DialectKWArgs
@@ -34,7 +45,9 @@ from .base import Executable
from .base import HasCompileState
from .elements import BooleanClauseList
from .elements import ClauseElement
+from .elements import ColumnElement
from .elements import Null
+from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import ReturnsRows
@@ -45,16 +58,25 @@ from .. import exc
from .. import util
from ..util.typing import TypeGuard
-
if TYPE_CHECKING:
- def isupdate(dml) -> TypeGuard[UpdateDMLState]:
+ from ._typing import _ColumnsClauseArgument
+ from ._typing import _DMLColumnArgument
+ from ._typing import _FromClauseArgument
+ from ._typing import _HasClauseElement
+ from ._typing import _SelectIterable
+ from .base import ReadOnlyColumnCollection
+ from .compiler import SQLCompiler
+ from .elements import ColumnClause
+ from .selectable import Select
+
+ def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]:
...
- def isdelete(dml) -> TypeGuard[DeleteDMLState]:
+ def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]:
...
- def isinsert(dml) -> TypeGuard[InsertDMLState]:
+ def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]:
...
else:
@@ -63,27 +85,43 @@ else:
isinsert = operator.attrgetter("isinsert")
+_DMLColumnElement = Union[str, "ColumnClause[Any]"]
+
+
class DMLState(CompileState):
_no_parameters = True
- _dict_parameters: Optional[MutableMapping[str, Any]] = None
- _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None
- _ordered_values = None
- _parameter_ordering = None
+ _dict_parameters: Optional[MutableMapping[_DMLColumnElement, Any]] = None
+ _multi_parameters: Optional[
+ List[MutableMapping[_DMLColumnElement, Any]]
+ ] = None
+ _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None
+ _parameter_ordering: Optional[List[_DMLColumnElement]] = None
_has_multi_parameters = False
isupdate = False
isdelete = False
isinsert = False
- def __init__(self, statement, compiler, **kw):
+ statement: UpdateBase
+
+ def __init__(
+ self, statement: UpdateBase, compiler: SQLCompiler, **kw: Any
+ ):
raise NotImplementedError()
@classmethod
- def get_entity_description(cls, statement):
- return {"name": statement.table.name, "table": statement.table}
+ def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]:
+ return {
+ "name": statement.table.name
+ if is_named_from_clause(statement.table)
+ else None,
+ "table": statement.table,
+ }
@classmethod
- def get_returning_column_descriptions(cls, statement):
+ def get_returning_column_descriptions(
+ cls, statement: UpdateBase
+ ) -> List[Dict[str, Any]]:
return [
{
"name": c.key,
@@ -94,11 +132,21 @@ class DMLState(CompileState):
]
@property
- def dml_table(self):
+ def dml_table(self) -> roles.DMLTableRole:
return self.statement.table
+ if TYPE_CHECKING:
+
+ @classmethod
+ def get_plugin_class(cls, statement: Executable) -> Type[DMLState]:
+ ...
+
@classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator):
+ def _get_crud_kv_pairs(
+ cls,
+ statement: UpdateBase,
+ kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]],
+ ) -> List[Tuple[_DMLColumnElement, Any]]:
return [
(
coercions.expect(roles.DMLColumnRole, k),
@@ -112,8 +160,8 @@ class DMLState(CompileState):
for k, v in kv_iterator
]
- def _make_extra_froms(self, statement):
- froms = []
+ def _make_extra_froms(self, statement: DMLWhereBase) -> List[FromClause]:
+ froms: List[FromClause] = []
all_tables = list(sql_util.tables_from_leftmost(statement.table))
seen = {all_tables[0]}
@@ -127,7 +175,7 @@ class DMLState(CompileState):
froms.extend(all_tables[1:])
return froms
- def _process_multi_values(self, statement):
+ def _process_multi_values(self, statement: ValuesBase) -> None:
if not statement._supports_multi_parameters:
raise exc.InvalidRequestError(
"%s construct does not support "
@@ -135,7 +183,7 @@ class DMLState(CompileState):
)
for parameters in statement._multi_values:
- multi_parameters = [
+ multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [
{
c.key: value
for c, value in zip(statement.table.c, parameter_set)
@@ -153,9 +201,10 @@ class DMLState(CompileState):
elif not self._has_multi_parameters:
self._cant_mix_formats_error()
else:
+ assert self._multi_parameters
self._multi_parameters.extend(multi_parameters)
- def _process_values(self, statement):
+ def _process_values(self, statement: ValuesBase) -> None:
if self._no_parameters:
self._has_multi_parameters = False
self._dict_parameters = statement._values
@@ -163,11 +212,12 @@ class DMLState(CompileState):
elif self._has_multi_parameters:
self._cant_mix_formats_error()
- def _process_ordered_values(self, statement):
+ def _process_ordered_values(self, statement: ValuesBase) -> None:
parameters = statement._ordered_values
if self._no_parameters:
self._no_parameters = False
+ assert parameters is not None
self._dict_parameters = dict(parameters)
self._ordered_values = parameters
self._parameter_ordering = [key for key, value in parameters]
@@ -179,7 +229,8 @@ class DMLState(CompileState):
"with any other values() call"
)
- def _process_select_values(self, statement):
+ def _process_select_values(self, statement: ValuesBase) -> None:
+ assert statement._select_names is not None
parameters = {
coercions.expect(roles.DMLColumnRole, name, as_key=True): Null()
for name in statement._select_names
@@ -193,7 +244,7 @@ class DMLState(CompileState):
# does not allow this construction to occur
assert False, "This statement already has parameters"
- def _cant_mix_formats_error(self):
+ def _cant_mix_formats_error(self) -> NoReturn:
raise exc.InvalidRequestError(
"Can't mix single and multiple VALUES "
"formats in one INSERT statement; one style appends to a "
@@ -208,7 +259,7 @@ class InsertDMLState(DMLState):
include_table_with_column_exprs = False
- def __init__(self, statement, compiler, **kw):
+ def __init__(self, statement: Insert, compiler: SQLCompiler, **kw: Any):
self.statement = statement
self.isinsert = True
@@ -226,10 +277,9 @@ class UpdateDMLState(DMLState):
include_table_with_column_exprs = False
- def __init__(self, statement, compiler, **kw):
+ def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any):
self.statement = statement
self.isupdate = True
- self._preserve_parameter_order = statement._preserve_parameter_order
if statement._ordered_values is not None:
self._process_ordered_values(statement)
elif statement._values is not None:
@@ -238,7 +288,7 @@ class UpdateDMLState(DMLState):
self._process_multi_values(statement)
self._extra_froms = ef = self._make_extra_froms(statement)
self.is_multitable = mt = ef and self._dict_parameters
- self.include_table_with_column_exprs = (
+ self.include_table_with_column_exprs = bool(
mt and compiler.render_table_with_column_in_update_from
)
@@ -247,7 +297,7 @@ class UpdateDMLState(DMLState):
class DeleteDMLState(DMLState):
isdelete = True
- def __init__(self, statement, compiler, **kw):
+ def __init__(self, statement: Delete, compiler: SQLCompiler, **kw: Any):
self.statement = statement
self.isdelete = True
@@ -271,23 +321,31 @@ class UpdateBase(
__visit_name__ = "update_base"
- _hints = util.immutabledict()
+ _hints: util.immutabledict[
+ Tuple[roles.DMLTableRole, str], str
+ ] = util.EMPTY_DICT
named_with_column = False
- table: TableClause
+ table: roles.DMLTableRole
_return_defaults = False
- _return_defaults_columns = None
- _returning = ()
+ _return_defaults_columns: Optional[
+ Tuple[roles.ColumnsClauseRole, ...]
+ ] = None
+ _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
is_dml = True
- def _generate_fromclause_column_proxies(self, fromclause):
+ def _generate_fromclause_column_proxies(
+ self, fromclause: FromClause
+ ) -> None:
fromclause._columns._populate_separate_keys(
- col._make_proxy(fromclause) for col in self._returning
+ col._make_proxy(fromclause)
+ for col in self._all_selected_columns
+ if is_column_element(col)
)
- def params(self, *arg, **kw):
+ def params(self, *arg: Any, **kw: Any) -> NoReturn:
"""Set the parameters for the statement.
This method raises ``NotImplementedError`` on the base class,
@@ -302,7 +360,9 @@ class UpdateBase(
)
@_generative
- def with_dialect_options(self: SelfUpdateBase, **opt) -> SelfUpdateBase:
+ def with_dialect_options(
+ self: SelfUpdateBase, **opt: Any
+ ) -> SelfUpdateBase:
"""Add dialect options to this INSERT/UPDATE/DELETE object.
e.g.::
@@ -318,7 +378,9 @@ class UpdateBase(
return self
@_generative
- def returning(self: SelfUpdateBase, *cols) -> SelfUpdateBase:
+ def returning(
+ self: SelfUpdateBase, *cols: _ColumnsClauseArgument
+ ) -> SelfUpdateBase:
r"""Add a :term:`RETURNING` or equivalent clause to this statement.
e.g.:
@@ -397,26 +459,32 @@ class UpdateBase(
)
return self
- @property
- def _all_selected_columns(self):
- return self._returning
+ @util.non_memoized_property
+ def _all_selected_columns(self) -> _SelectIterable:
+ return [c for c in _select_iterables(self._returning)]
@property
- def exported_columns(self):
+ def exported_columns(
+ self,
+ ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]:
"""Return the RETURNING columns as a column collection for this
statement.
.. versionadded:: 1.4
"""
- # TODO: no coverage here
return ColumnCollection(
- (c.key, c) for c in self._all_selected_columns
- ).as_immutable()
+ (c.key, c)
+ for c in self._all_selected_columns
+ if is_column_element(c)
+ ).as_readonly()
@_generative
def with_hint(
- self: SelfUpdateBase, text, selectable=None, dialect_name="*"
+ self: SelfUpdateBase,
+ text: str,
+ selectable: Optional[roles.DMLTableRole] = None,
+ dialect_name: str = "*",
) -> SelfUpdateBase:
"""Add a table hint for a single table to this
INSERT/UPDATE/DELETE statement.
@@ -454,7 +522,7 @@ class UpdateBase(
return self
@property
- def entity_description(self):
+ def entity_description(self) -> Dict[str, Any]:
"""Return a :term:`plugin-enabled` description of the table and/or entity
which this DML construct is operating against.
@@ -490,7 +558,7 @@ class UpdateBase(
return meth(self)
@property
- def returning_column_descriptions(self):
+ def returning_column_descriptions(self) -> List[Dict[str, Any]]:
"""Return a :term:`plugin-enabled` description of the columns
which this DML construct is RETURNING against, in other words
the expressions established as part of :meth:`.UpdateBase.returning`.
@@ -547,18 +615,30 @@ class ValuesBase(UpdateBase):
__visit_name__ = "values_base"
_supports_multi_parameters = False
- _preserve_parameter_order = False
- select = None
- _post_values_clause = None
- _values = None
- _multi_values = ()
- _ordered_values = None
- _select_names = None
+ select: Optional[Select] = None
+ """SELECT statement for INSERT .. FROM SELECT"""
+
+ _post_values_clause: Optional[ClauseElement] = None
+ """used by extensions to Insert etc. to add additional syntacitcal
+ constructs, e.g. ON CONFLICT etc."""
+
+ _values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None
+ _multi_values: Tuple[
+ Union[
+ Sequence[Dict[_DMLColumnElement, Any]],
+ Sequence[Sequence[Any]],
+ ],
+ ...,
+ ] = ()
+
+ _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None
+
+ _select_names: Optional[List[str]] = None
_inline: bool = False
- _returning = ()
+ _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
- def __init__(self, table):
+ def __init__(self, table: _FromClauseArgument):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
@@ -573,7 +653,14 @@ class ValuesBase(UpdateBase):
"values present",
},
)
- def values(self: SelfValuesBase, *args, **kwargs) -> SelfValuesBase:
+ def values(
+ self: SelfValuesBase,
+ *args: Union[
+ Dict[_DMLColumnArgument, Any],
+ Sequence[Any],
+ ],
+ **kwargs: Any,
+ ) -> SelfValuesBase:
r"""Specify a fixed VALUES clause for an INSERT statement, or the SET
clause for an UPDATE.
@@ -704,9 +791,7 @@ class ValuesBase(UpdateBase):
"dictionaries/tuples is accepted positionally."
)
- elif not self._preserve_parameter_order and isinstance(
- arg, collections_abc.Sequence
- ):
+ elif isinstance(arg, collections_abc.Sequence):
if arg and isinstance(arg[0], (list, dict, tuple)):
self._multi_values += (arg,)
@@ -714,18 +799,11 @@ class ValuesBase(UpdateBase):
# tuple values
arg = {c.key: value for c, value in zip(self.table.c, arg)}
- elif self._preserve_parameter_order and not isinstance(
- arg, collections_abc.Sequence
- ):
- raise ValueError(
- "When preserve_parameter_order is True, "
- "values() only accepts a list of 2-tuples"
- )
else:
# kwarg path. this is the most common path for non-multi-params
# so this is fairly quick.
- arg = kwargs
+ arg = cast("Dict[_DMLColumnArgument, Any]", kwargs)
if args:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
@@ -739,15 +817,11 @@ class ValuesBase(UpdateBase):
# and ensures they get the "crud"-style name when rendered.
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
-
- if self._preserve_parameter_order:
- self._ordered_values = kv_generator(self, arg)
+ coerced_arg = {k: v for k, v in kv_generator(self, arg.items())}
+ if self._values:
+ self._values = self._values.union(coerced_arg)
else:
- arg = {k: v for k, v in kv_generator(self, arg.items())}
- if self._values:
- self._values = self._values.union(arg)
- else:
- self._values = util.immutabledict(arg)
+ self._values = util.immutabledict(coerced_arg)
return self
@_generative
@@ -758,7 +832,9 @@ class ValuesBase(UpdateBase):
},
defaults={"_returning": _returning},
)
- def return_defaults(self: SelfValuesBase, *cols) -> SelfValuesBase:
+ def return_defaults(
+ self: SelfValuesBase, *cols: _DMLColumnArgument
+ ) -> SelfValuesBase:
"""Make use of a :term:`RETURNING` clause for the purpose
of fetching server-side expressions and defaults.
@@ -843,7 +919,9 @@ class ValuesBase(UpdateBase):
"""
self._return_defaults = True
- self._return_defaults_columns = cols
+ self._return_defaults_columns = tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
return self
@@ -867,6 +945,8 @@ class Insert(ValuesBase):
is_insert = True
+ table: TableClause
+
_traverse_internals = (
[
("table", InternalTraversal.dp_clauseelement),
@@ -890,7 +970,7 @@ class Insert(ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table):
+ def __init__(self, table: roles.FromClauseRole):
super(Insert, self).__init__(table)
@_generative
@@ -916,7 +996,10 @@ class Insert(ValuesBase):
@_generative
def from_select(
- self: SelfInsert, names, select, include_defaults=True
+ self: SelfInsert,
+ names: List[str],
+ select: Select,
+ include_defaults: bool = True,
) -> SelfInsert:
"""Return a new :class:`_expression.Insert` construct which represents
an ``INSERT...FROM SELECT`` statement.
@@ -983,10 +1066,13 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
class DMLWhereBase:
- _where_criteria = ()
+ table: roles.DMLTableRole
+ _where_criteria: Tuple[ColumnElement[Any], ...] = ()
@_generative
- def where(self: SelfDMLWhereBase, *whereclause) -> SelfDMLWhereBase:
+ def where(
+ self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any]
+ ) -> SelfDMLWhereBase:
"""Return a new construct with the given expression(s) added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -1022,7 +1108,9 @@ class DMLWhereBase:
self._where_criteria += (where_criteria,)
return self
- def filter(self: SelfDMLWhereBase, *criteria) -> SelfDMLWhereBase:
+ def filter(
+ self: SelfDMLWhereBase, *criteria: roles.ExpressionElementRole[Any]
+ ) -> SelfDMLWhereBase:
"""A synonym for the :meth:`_dml.DMLWhereBase.where` method.
.. versionadded:: 1.4
@@ -1031,10 +1119,10 @@ class DMLWhereBase:
return self.where(*criteria)
- def _filter_by_zero(self):
+ def _filter_by_zero(self) -> roles.DMLTableRole:
return self.table
- def filter_by(self: SelfDMLWhereBase, **kwargs) -> SelfDMLWhereBase:
+ def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase:
r"""apply the given filtering criterion as a WHERE clause
to this select.
@@ -1048,7 +1136,7 @@ class DMLWhereBase:
return self.filter(*clauses)
@property
- def whereclause(self):
+ def whereclause(self) -> Optional[ColumnElement[Any]]:
"""Return the completed WHERE clause for this :class:`.DMLWhereBase`
statement.
@@ -1079,7 +1167,6 @@ class Update(DMLWhereBase, ValuesBase):
__visit_name__ = "update"
is_update = True
- _preserve_parameter_order = False
_traverse_internals = (
[
@@ -1102,11 +1189,13 @@ class Update(DMLWhereBase, ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table):
+ def __init__(self, table: roles.FromClauseRole):
super(Update, self).__init__(table)
@_generative
- def ordered_values(self: SelfUpdate, *args) -> SelfUpdate:
+ def ordered_values(
+ self: SelfUpdate, *args: Tuple[_DMLColumnArgument, Any]
+ ) -> SelfUpdate:
"""Specify the VALUES clause of this UPDATE statement with an explicit
parameter ordering that will be maintained in the SET clause of the
resulting UPDATE statement.
@@ -1190,7 +1279,7 @@ class Delete(DMLWhereBase, UpdateBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table):
+ def __init__(self, table: roles.FromClauseRole):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)