diff options
Diffstat (limited to 'lib/sqlalchemy/sql/dml.py')
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 267 |
1 files changed, 178 insertions, 89 deletions
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index f5fb6b2f3..0c9056aee 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -15,18 +15,29 @@ import collections.abc as collections_abc import operator import typing from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable from typing import List from typing import MutableMapping +from typing import NoReturn from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING +from typing import Union from . import coercions from . import roles from . import util as sql_util +from ._typing import is_column_element +from ._typing import is_named_from_clause from .base import _entity_namespace_key from .base import _exclusive_against from .base import _from_objects from .base import _generative +from .base import _select_iterables from .base import ColumnCollection from .base import CompileState from .base import DialectKWArgs @@ -34,7 +45,9 @@ from .base import Executable from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement +from .elements import ColumnElement from .elements import Null +from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import ReturnsRows @@ -45,16 +58,25 @@ from .. import exc from .. import util from ..util.typing import TypeGuard - if TYPE_CHECKING: - def isupdate(dml) -> TypeGuard[UpdateDMLState]: + from ._typing import _ColumnsClauseArgument + from ._typing import _DMLColumnArgument + from ._typing import _FromClauseArgument + from ._typing import _HasClauseElement + from ._typing import _SelectIterable + from .base import ReadOnlyColumnCollection + from .compiler import SQLCompiler + from .elements import ColumnClause + from .selectable import Select + + def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: ... - def isdelete(dml) -> TypeGuard[DeleteDMLState]: + def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: ... - def isinsert(dml) -> TypeGuard[InsertDMLState]: + def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: ... else: @@ -63,27 +85,43 @@ else: isinsert = operator.attrgetter("isinsert") +_DMLColumnElement = Union[str, "ColumnClause[Any]"] + + class DMLState(CompileState): _no_parameters = True - _dict_parameters: Optional[MutableMapping[str, Any]] = None - _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None - _ordered_values = None - _parameter_ordering = None + _dict_parameters: Optional[MutableMapping[_DMLColumnElement, Any]] = None + _multi_parameters: Optional[ + List[MutableMapping[_DMLColumnElement, Any]] + ] = None + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + _parameter_ordering: Optional[List[_DMLColumnElement]] = None _has_multi_parameters = False isupdate = False isdelete = False isinsert = False - def __init__(self, statement, compiler, **kw): + statement: UpdateBase + + def __init__( + self, statement: UpdateBase, compiler: SQLCompiler, **kw: Any + ): raise NotImplementedError() @classmethod - def get_entity_description(cls, statement): - return {"name": statement.table.name, "table": statement.table} + def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]: + return { + "name": statement.table.name + if is_named_from_clause(statement.table) + else None, + "table": statement.table, + } @classmethod - def get_returning_column_descriptions(cls, statement): + def get_returning_column_descriptions( + cls, statement: UpdateBase + ) -> List[Dict[str, Any]]: return [ { "name": c.key, @@ -94,11 +132,21 @@ class DMLState(CompileState): ] @property - def dml_table(self): + def dml_table(self) -> roles.DMLTableRole: return self.statement.table + if TYPE_CHECKING: + + @classmethod + def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: + ... + @classmethod - def _get_crud_kv_pairs(cls, statement, kv_iterator): + def _get_crud_kv_pairs( + cls, + statement: UpdateBase, + kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]], + ) -> List[Tuple[_DMLColumnElement, Any]]: return [ ( coercions.expect(roles.DMLColumnRole, k), @@ -112,8 +160,8 @@ class DMLState(CompileState): for k, v in kv_iterator ] - def _make_extra_froms(self, statement): - froms = [] + def _make_extra_froms(self, statement: DMLWhereBase) -> List[FromClause]: + froms: List[FromClause] = [] all_tables = list(sql_util.tables_from_leftmost(statement.table)) seen = {all_tables[0]} @@ -127,7 +175,7 @@ class DMLState(CompileState): froms.extend(all_tables[1:]) return froms - def _process_multi_values(self, statement): + def _process_multi_values(self, statement: ValuesBase) -> None: if not statement._supports_multi_parameters: raise exc.InvalidRequestError( "%s construct does not support " @@ -135,7 +183,7 @@ class DMLState(CompileState): ) for parameters in statement._multi_values: - multi_parameters = [ + multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ { c.key: value for c, value in zip(statement.table.c, parameter_set) @@ -153,9 +201,10 @@ class DMLState(CompileState): elif not self._has_multi_parameters: self._cant_mix_formats_error() else: + assert self._multi_parameters self._multi_parameters.extend(multi_parameters) - def _process_values(self, statement): + def _process_values(self, statement: ValuesBase) -> None: if self._no_parameters: self._has_multi_parameters = False self._dict_parameters = statement._values @@ -163,11 +212,12 @@ class DMLState(CompileState): elif self._has_multi_parameters: self._cant_mix_formats_error() - def _process_ordered_values(self, statement): + def _process_ordered_values(self, statement: ValuesBase) -> None: parameters = statement._ordered_values if self._no_parameters: self._no_parameters = False + assert parameters is not None self._dict_parameters = dict(parameters) self._ordered_values = parameters self._parameter_ordering = [key for key, value in parameters] @@ -179,7 +229,8 @@ class DMLState(CompileState): "with any other values() call" ) - def _process_select_values(self, statement): + def _process_select_values(self, statement: ValuesBase) -> None: + assert statement._select_names is not None parameters = { coercions.expect(roles.DMLColumnRole, name, as_key=True): Null() for name in statement._select_names @@ -193,7 +244,7 @@ class DMLState(CompileState): # does not allow this construction to occur assert False, "This statement already has parameters" - def _cant_mix_formats_error(self): + def _cant_mix_formats_error(self) -> NoReturn: raise exc.InvalidRequestError( "Can't mix single and multiple VALUES " "formats in one INSERT statement; one style appends to a " @@ -208,7 +259,7 @@ class InsertDMLState(DMLState): include_table_with_column_exprs = False - def __init__(self, statement, compiler, **kw): + def __init__(self, statement: Insert, compiler: SQLCompiler, **kw: Any): self.statement = statement self.isinsert = True @@ -226,10 +277,9 @@ class UpdateDMLState(DMLState): include_table_with_column_exprs = False - def __init__(self, statement, compiler, **kw): + def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any): self.statement = statement self.isupdate = True - self._preserve_parameter_order = statement._preserve_parameter_order if statement._ordered_values is not None: self._process_ordered_values(statement) elif statement._values is not None: @@ -238,7 +288,7 @@ class UpdateDMLState(DMLState): self._process_multi_values(statement) self._extra_froms = ef = self._make_extra_froms(statement) self.is_multitable = mt = ef and self._dict_parameters - self.include_table_with_column_exprs = ( + self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -247,7 +297,7 @@ class UpdateDMLState(DMLState): class DeleteDMLState(DMLState): isdelete = True - def __init__(self, statement, compiler, **kw): + def __init__(self, statement: Delete, compiler: SQLCompiler, **kw: Any): self.statement = statement self.isdelete = True @@ -271,23 +321,31 @@ class UpdateBase( __visit_name__ = "update_base" - _hints = util.immutabledict() + _hints: util.immutabledict[ + Tuple[roles.DMLTableRole, str], str + ] = util.EMPTY_DICT named_with_column = False - table: TableClause + table: roles.DMLTableRole _return_defaults = False - _return_defaults_columns = None - _returning = () + _return_defaults_columns: Optional[ + Tuple[roles.ColumnsClauseRole, ...] + ] = None + _returning: Tuple[roles.ColumnsClauseRole, ...] = () is_dml = True - def _generate_fromclause_column_proxies(self, fromclause): + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: fromclause._columns._populate_separate_keys( - col._make_proxy(fromclause) for col in self._returning + col._make_proxy(fromclause) + for col in self._all_selected_columns + if is_column_element(col) ) - def params(self, *arg, **kw): + def params(self, *arg: Any, **kw: Any) -> NoReturn: """Set the parameters for the statement. This method raises ``NotImplementedError`` on the base class, @@ -302,7 +360,9 @@ class UpdateBase( ) @_generative - def with_dialect_options(self: SelfUpdateBase, **opt) -> SelfUpdateBase: + def with_dialect_options( + self: SelfUpdateBase, **opt: Any + ) -> SelfUpdateBase: """Add dialect options to this INSERT/UPDATE/DELETE object. e.g.:: @@ -318,7 +378,9 @@ class UpdateBase( return self @_generative - def returning(self: SelfUpdateBase, *cols) -> SelfUpdateBase: + def returning( + self: SelfUpdateBase, *cols: _ColumnsClauseArgument + ) -> SelfUpdateBase: r"""Add a :term:`RETURNING` or equivalent clause to this statement. e.g.: @@ -397,26 +459,32 @@ class UpdateBase( ) return self - @property - def _all_selected_columns(self): - return self._returning + @util.non_memoized_property + def _all_selected_columns(self) -> _SelectIterable: + return [c for c in _select_iterables(self._returning)] @property - def exported_columns(self): + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]: """Return the RETURNING columns as a column collection for this statement. .. versionadded:: 1.4 """ - # TODO: no coverage here return ColumnCollection( - (c.key, c) for c in self._all_selected_columns - ).as_immutable() + (c.key, c) + for c in self._all_selected_columns + if is_column_element(c) + ).as_readonly() @_generative def with_hint( - self: SelfUpdateBase, text, selectable=None, dialect_name="*" + self: SelfUpdateBase, + text: str, + selectable: Optional[roles.DMLTableRole] = None, + dialect_name: str = "*", ) -> SelfUpdateBase: """Add a table hint for a single table to this INSERT/UPDATE/DELETE statement. @@ -454,7 +522,7 @@ class UpdateBase( return self @property - def entity_description(self): + def entity_description(self) -> Dict[str, Any]: """Return a :term:`plugin-enabled` description of the table and/or entity which this DML construct is operating against. @@ -490,7 +558,7 @@ class UpdateBase( return meth(self) @property - def returning_column_descriptions(self): + def returning_column_descriptions(self) -> List[Dict[str, Any]]: """Return a :term:`plugin-enabled` description of the columns which this DML construct is RETURNING against, in other words the expressions established as part of :meth:`.UpdateBase.returning`. @@ -547,18 +615,30 @@ class ValuesBase(UpdateBase): __visit_name__ = "values_base" _supports_multi_parameters = False - _preserve_parameter_order = False - select = None - _post_values_clause = None - _values = None - _multi_values = () - _ordered_values = None - _select_names = None + select: Optional[Select] = None + """SELECT statement for INSERT .. FROM SELECT""" + + _post_values_clause: Optional[ClauseElement] = None + """used by extensions to Insert etc. to add additional syntacitcal + constructs, e.g. ON CONFLICT etc.""" + + _values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None + _multi_values: Tuple[ + Union[ + Sequence[Dict[_DMLColumnElement, Any]], + Sequence[Sequence[Any]], + ], + ..., + ] = () + + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + + _select_names: Optional[List[str]] = None _inline: bool = False - _returning = () + _returning: Tuple[roles.ColumnsClauseRole, ...] = () - def __init__(self, table): + def __init__(self, table: _FromClauseArgument): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) @@ -573,7 +653,14 @@ class ValuesBase(UpdateBase): "values present", }, ) - def values(self: SelfValuesBase, *args, **kwargs) -> SelfValuesBase: + def values( + self: SelfValuesBase, + *args: Union[ + Dict[_DMLColumnArgument, Any], + Sequence[Any], + ], + **kwargs: Any, + ) -> SelfValuesBase: r"""Specify a fixed VALUES clause for an INSERT statement, or the SET clause for an UPDATE. @@ -704,9 +791,7 @@ class ValuesBase(UpdateBase): "dictionaries/tuples is accepted positionally." ) - elif not self._preserve_parameter_order and isinstance( - arg, collections_abc.Sequence - ): + elif isinstance(arg, collections_abc.Sequence): if arg and isinstance(arg[0], (list, dict, tuple)): self._multi_values += (arg,) @@ -714,18 +799,11 @@ class ValuesBase(UpdateBase): # tuple values arg = {c.key: value for c, value in zip(self.table.c, arg)} - elif self._preserve_parameter_order and not isinstance( - arg, collections_abc.Sequence - ): - raise ValueError( - "When preserve_parameter_order is True, " - "values() only accepts a list of 2-tuples" - ) else: # kwarg path. this is the most common path for non-multi-params # so this is fairly quick. - arg = kwargs + arg = cast("Dict[_DMLColumnArgument, Any]", kwargs) if args: raise exc.ArgumentError( "Only a single dictionary/tuple or list of " @@ -739,15 +817,11 @@ class ValuesBase(UpdateBase): # and ensures they get the "crud"-style name when rendered. kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs - - if self._preserve_parameter_order: - self._ordered_values = kv_generator(self, arg) + coerced_arg = {k: v for k, v in kv_generator(self, arg.items())} + if self._values: + self._values = self._values.union(coerced_arg) else: - arg = {k: v for k, v in kv_generator(self, arg.items())} - if self._values: - self._values = self._values.union(arg) - else: - self._values = util.immutabledict(arg) + self._values = util.immutabledict(coerced_arg) return self @_generative @@ -758,7 +832,9 @@ class ValuesBase(UpdateBase): }, defaults={"_returning": _returning}, ) - def return_defaults(self: SelfValuesBase, *cols) -> SelfValuesBase: + def return_defaults( + self: SelfValuesBase, *cols: _DMLColumnArgument + ) -> SelfValuesBase: """Make use of a :term:`RETURNING` clause for the purpose of fetching server-side expressions and defaults. @@ -843,7 +919,9 @@ class ValuesBase(UpdateBase): """ self._return_defaults = True - self._return_defaults_columns = cols + self._return_defaults_columns = tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) return self @@ -867,6 +945,8 @@ class Insert(ValuesBase): is_insert = True + table: TableClause + _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), @@ -890,7 +970,7 @@ class Insert(ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table): + def __init__(self, table: roles.FromClauseRole): super(Insert, self).__init__(table) @_generative @@ -916,7 +996,10 @@ class Insert(ValuesBase): @_generative def from_select( - self: SelfInsert, names, select, include_defaults=True + self: SelfInsert, + names: List[str], + select: Select, + include_defaults: bool = True, ) -> SelfInsert: """Return a new :class:`_expression.Insert` construct which represents an ``INSERT...FROM SELECT`` statement. @@ -983,10 +1066,13 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") class DMLWhereBase: - _where_criteria = () + table: roles.DMLTableRole + _where_criteria: Tuple[ColumnElement[Any], ...] = () @_generative - def where(self: SelfDMLWhereBase, *whereclause) -> SelfDMLWhereBase: + def where( + self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any] + ) -> SelfDMLWhereBase: """Return a new construct with the given expression(s) added to its WHERE clause, joined to the existing clause via AND, if any. @@ -1022,7 +1108,9 @@ class DMLWhereBase: self._where_criteria += (where_criteria,) return self - def filter(self: SelfDMLWhereBase, *criteria) -> SelfDMLWhereBase: + def filter( + self: SelfDMLWhereBase, *criteria: roles.ExpressionElementRole[Any] + ) -> SelfDMLWhereBase: """A synonym for the :meth:`_dml.DMLWhereBase.where` method. .. versionadded:: 1.4 @@ -1031,10 +1119,10 @@ class DMLWhereBase: return self.where(*criteria) - def _filter_by_zero(self): + def _filter_by_zero(self) -> roles.DMLTableRole: return self.table - def filter_by(self: SelfDMLWhereBase, **kwargs) -> SelfDMLWhereBase: + def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase: r"""apply the given filtering criterion as a WHERE clause to this select. @@ -1048,7 +1136,7 @@ class DMLWhereBase: return self.filter(*clauses) @property - def whereclause(self): + def whereclause(self) -> Optional[ColumnElement[Any]]: """Return the completed WHERE clause for this :class:`.DMLWhereBase` statement. @@ -1079,7 +1167,6 @@ class Update(DMLWhereBase, ValuesBase): __visit_name__ = "update" is_update = True - _preserve_parameter_order = False _traverse_internals = ( [ @@ -1102,11 +1189,13 @@ class Update(DMLWhereBase, ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table): + def __init__(self, table: roles.FromClauseRole): super(Update, self).__init__(table) @_generative - def ordered_values(self: SelfUpdate, *args) -> SelfUpdate: + def ordered_values( + self: SelfUpdate, *args: Tuple[_DMLColumnArgument, Any] + ) -> SelfUpdate: """Specify the VALUES clause of this UPDATE statement with an explicit parameter ordering that will be maintained in the SET clause of the resulting UPDATE statement. @@ -1190,7 +1279,7 @@ class Delete(DMLWhereBase, UpdateBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table): + def __init__(self, table: roles.FromClauseRole): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) |
