summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-20 16:39:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-24 16:57:30 -0400
commit6f02d5edd88fe2475629438b0730181a2b00c5fe (patch)
treebbf9e9f3e8a2363659be35d59a7749c7fe35ef7c /lib/sqlalchemy/sql/base.py
parentc565c470517e1cc70a7f33d1ad3d3256935f1121 (diff)
downloadsqlalchemy-6f02d5edd88fe2475629438b0730181a2b00c5fe.tar.gz
pep484 - SQL internals
non-strict checking for mostly internal or semi-internal code Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py269
1 files changed, 173 insertions, 96 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 6a6b389de..8f5135915 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -12,22 +12,32 @@
from __future__ import annotations
-import collections.abc as collections_abc
from enum import Enum
from functools import reduce
import itertools
from itertools import zip_longest
import operator
import re
-import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
+from typing import Generic
from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NoReturn
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import roles
from . import visitors
@@ -36,17 +46,26 @@ from .cache_key import MemoizedHasCacheKey # noqa
from .traversals import HasCopyInternals # noqa
from .visitors import ClauseVisitor
from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
+from .. import event
from .. import exc
from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
+from ..util.typing import Protocol
from ..util.typing import Self
+from ..util.typing import TypeGuard
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+ from . import coercions
+ from . import elements
+ from . import type_api
from .elements import BindParameter
+ from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import SQLCoreOperations
from ..engine import Connection
from ..engine import Result
from ..engine.base import _CompiledCacheType
@@ -58,10 +77,12 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import CacheStats
from ..engine.interfaces import Compiled
from ..engine.interfaces import Dialect
+ from ..event import dispatcher
-coercions = None
-elements = None
-type_api = None
+if not TYPE_CHECKING:
+ coercions = None # noqa
+ elements = None # noqa
+ type_api = None # noqa
class _NoArg(Enum):
@@ -70,13 +91,24 @@ class _NoArg(Enum):
NO_ARG = _NoArg.NO_ARG
-# if I use sqlalchemy.util.typing, which has the exact same
-# symbols, mypy reports: "error: _Fn? not callable"
-_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Fn = TypeVar("_Fn", bound=Callable[..., Any])
_AmbiguousTableNameMap = MutableMapping[str, str]
+class _EntityNamespace(Protocol):
+ def __getattr__(self, key: str) -> SQLCoreOperations[Any]:
+ ...
+
+
+class _HasEntityNamespace(Protocol):
+ entity_namespace: _EntityNamespace
+
+
+def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
+ return hasattr(element, "entity_namespace")
+
+
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -107,10 +139,14 @@ class SingletonConstant(Immutable):
def __new__(cls, *arg, **kw):
return cls._singleton
+ @util.non_memoized_property
+ def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
+ raise NotImplementedError()
+
@classmethod
def _create_singleton(cls):
obj = object.__new__(cls)
- obj.__init__()
+ obj.__init__() # type: ignore
# for a long time this was an empty frozenset, meaning
# a SingletonConstant would never be a "corresponding column" in
@@ -139,12 +175,11 @@ def _select_iterables(elements):
)
-_Self = typing.TypeVar("_Self", bound="_GenerativeType")
-_Args = compat_typing.ParamSpec("_Args")
+_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType")
class _GenerativeType(compat_typing.Protocol):
- def _generate(self: "_Self") -> "_Self":
+ def _generate(self: _SelfGenerativeType) -> _SelfGenerativeType:
...
@@ -158,8 +193,8 @@ def _generative(fn: _Fn) -> _Fn:
@util.decorator
def _generative(
- fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
- ) -> _Self:
+ fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
+ ) -> _SelfGenerativeType:
"""Mark a method as generative."""
self = self._generate()
@@ -167,9 +202,9 @@ def _generative(fn: _Fn) -> _Fn:
assert x is self, "generative methods must return self"
return self
- decorated = _generative(fn)
- decorated.non_generative = fn
- return decorated
+ decorated = _generative(fn) # type: ignore
+ decorated.non_generative = fn # type: ignore
+ return decorated # type: ignore
def _exclusive_against(*names, **kw):
@@ -233,7 +268,7 @@ def _cloned_difference(a, b):
)
-class _DialectArgView(collections_abc.MutableMapping):
+class _DialectArgView(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments in the form
<dialectname>_<argument_name>.
@@ -290,7 +325,7 @@ class _DialectArgView(collections_abc.MutableMapping):
)
-class _DialectArgDict(collections_abc.MutableMapping):
+class _DialectArgDict(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments for a specific
dialect.
@@ -343,6 +378,8 @@ class DialectKWArgs:
"""
+ __slots__ = ()
+
_dialect_kwargs_traverse_internals = [
("dialect_options", InternalTraversal.dp_dialect_options)
]
@@ -534,7 +571,7 @@ class CompileState:
__slots__ = ("statement", "_ambiguous_table_name_map")
- plugins = {}
+ plugins: Dict[Tuple[str, str], Type[CompileState]] = {}
_ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
@@ -639,9 +676,9 @@ class InPlaceGenerative(HasMemoized):
class HasCompileState(Generative):
"""A class that has a :class:`.CompileState` associated with it."""
- _compile_state_plugin = None
+ _compile_state_plugin: Optional[Type[CompileState]] = None
- _attributes = util.immutabledict()
+ _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT
_compile_state_factory = CompileState.create_for_statement
@@ -655,6 +692,8 @@ class _MetaOptions(type):
"""
+ _cache_attrs: Tuple[str, ...]
+
def __add__(self, other):
o1 = self()
@@ -674,6 +713,8 @@ class Options(metaclass=_MetaOptions):
__slots__ = ()
+ _cache_attrs: Tuple[str, ...]
+
def __init_subclass__(cls) -> None:
dict_ = cls.__dict__
cls._cache_attrs = tuple(
@@ -732,13 +773,13 @@ class Options(metaclass=_MetaOptions):
return self + {name: getattr(self, name) + value}
@hybridmethod
- def _state_dict(self):
+ def _state_dict_inst(self) -> Mapping[str, Any]:
return self.__dict__
- _state_dict_const = util.immutabledict()
+ _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT
- @_state_dict.classlevel
- def _state_dict(cls):
+ @_state_dict_inst.classlevel
+ def _state_dict(cls) -> Mapping[str, Any]:
return cls._state_dict_const
@classmethod
@@ -825,10 +866,10 @@ class CacheableOptions(Options, HasCacheKey):
__slots__ = ()
@hybridmethod
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key_inst(self, anon_map, bindparams):
return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
- @_gen_cache_key.classlevel
+ @_gen_cache_key_inst.classlevel
def _gen_cache_key(cls, anon_map, bindparams):
return (cls, ())
@@ -849,11 +890,11 @@ class ExecutableOption(HasCopyInternals):
def _clone(self, **kw):
"""Create a shallow copy of this ExecutableOption."""
c = self.__class__.__new__(self.__class__)
- c.__dict__ = dict(self.__dict__)
+ c.__dict__ = dict(self.__dict__) # type: ignore
return c
-SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable")
+SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
class Executable(roles.StatementRole, Generative):
@@ -866,9 +907,12 @@ class Executable(roles.StatementRole, Generative):
"""
supports_execution: bool = True
- _execution_options: _ImmutableExecuteOptions = util.immutabledict()
- _with_options = ()
- _with_context_options = ()
+ _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
+ _with_options: Tuple[ExecutableOption, ...] = ()
+ _with_context_options: Tuple[
+ Tuple[Callable[[CompileState], None], Any], ...
+ ] = ()
+ _compile_options: Optional[CacheableOptions]
_executable_traverse_internals = [
("_with_options", InternalTraversal.dp_executable_options),
@@ -886,7 +930,9 @@ class Executable(roles.StatementRole, Generative):
is_delete = False
is_dml = False
- if typing.TYPE_CHECKING:
+ if TYPE_CHECKING:
+
+ __visit_name__: str
def _compile_w_cache(
self,
@@ -916,11 +962,13 @@ class Executable(roles.StatementRole, Generative):
raise NotImplementedError()
@property
- def _effective_plugin_target(self):
+ def _effective_plugin_target(self) -> str:
return self.__visit_name__
@_generative
- def options(self: SelfExecutable, *options) -> SelfExecutable:
+ def options(
+ self: SelfExecutable, *options: ExecutableOption
+ ) -> SelfExecutable:
"""Apply options to this statement.
In the general sense, options are any kind of Python object
@@ -957,7 +1005,7 @@ class Executable(roles.StatementRole, Generative):
@_generative
def _set_compile_options(
- self: SelfExecutable, compile_options
+ self: SelfExecutable, compile_options: CacheableOptions
) -> SelfExecutable:
"""Assign the compile options to a new value.
@@ -970,16 +1018,19 @@ class Executable(roles.StatementRole, Generative):
@_generative
def _update_compile_options(
- self: SelfExecutable, options
+ self: SelfExecutable, options: CacheableOptions
) -> SelfExecutable:
"""update the _compile_options with new keys."""
+ assert self._compile_options is not None
self._compile_options += options
return self
@_generative
def _add_context_option(
- self: SelfExecutable, callable_, cache_args
+ self: SelfExecutable,
+ callable_: Callable[[CompileState], None],
+ cache_args: Any,
) -> SelfExecutable:
"""Add a context option to this statement.
@@ -995,7 +1046,7 @@ class Executable(roles.StatementRole, Generative):
return self
@_generative
- def execution_options(self: SelfExecutable, **kw) -> SelfExecutable:
+ def execution_options(self: SelfExecutable, **kw: Any) -> SelfExecutable:
"""Set non-SQL options for the statement which take effect during
execution.
@@ -1112,7 +1163,7 @@ class Executable(roles.StatementRole, Generative):
self._execution_options = self._execution_options.union(kw)
return self
- def get_execution_options(self):
+ def get_execution_options(self) -> _ExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
@@ -1124,7 +1175,7 @@ class Executable(roles.StatementRole, Generative):
return self._execution_options
-class SchemaEventTarget:
+class SchemaEventTarget(event.EventTarget):
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
@@ -1132,6 +1183,8 @@ class SchemaEventTarget:
"""
+ dispatch: dispatcher[SchemaEventTarget]
+
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
"""Associate with this SchemaEvent's parent object."""
@@ -1149,7 +1202,10 @@ class SchemaVisitor(ClauseVisitor):
__traverse_options__ = {"schema_visitor": True}
-class ColumnCollection:
+_COL = TypeVar("_COL", bound="ColumnClause[Any]")
+
+
+class ColumnCollection(Generic[_COL]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
:class:`_sql.FromClause` objects.
@@ -1260,32 +1316,36 @@ class ColumnCollection:
__slots__ = "_collection", "_index", "_colset"
- def __init__(self, columns=None):
+ _collection: List[Tuple[str, _COL]]
+ _index: Dict[Union[str, int], _COL]
+ _colset: Set[_COL]
+
+ def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
object.__setattr__(self, "_collection", [])
if columns:
self._initial_populate(columns)
- def _initial_populate(self, iter_):
+ def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None:
self._populate_separate_keys(iter_)
@property
- def _all_columns(self):
+ def _all_columns(self) -> List[_COL]:
return [col for (k, col) in self._collection]
- def keys(self):
+ def keys(self) -> List[str]:
"""Return a sequence of string key names for all columns in this
collection."""
return [k for (k, col) in self._collection]
- def values(self):
+ def values(self) -> List[_COL]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
return [col for (k, col) in self._collection]
- def items(self):
+ def items(self) -> List[Tuple[str, _COL]]:
"""Return a sequence of (key, column) tuples for all columns in this
collection each consisting of a string key name and a
:class:`_sql.ColumnClause` or
@@ -1294,17 +1354,17 @@ class ColumnCollection:
return list(self._collection)
- def __bool__(self):
+ def __bool__(self) -> bool:
return bool(self._collection)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._collection)
- def __iter__(self):
+ def __iter__(self) -> Iterator[_COL]:
# turn to a list first to maintain over a course of changes
return iter([col for k, col in self._collection])
- def __getitem__(self, key):
+ def __getitem__(self, key: Union[str, int]) -> _COL:
try:
return self._index[key]
except KeyError as err:
@@ -1313,13 +1373,13 @@ class ColumnCollection:
else:
raise
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> _COL:
try:
return self._index[key]
except KeyError as err:
raise AttributeError(key) from err
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
if key not in self._index:
if not isinstance(key, str):
raise exc.ArgumentError(
@@ -1329,7 +1389,7 @@ class ColumnCollection:
else:
return True
- def compare(self, other):
+ def compare(self, other: ColumnCollection[Any]) -> bool:
"""Compare this :class:`_expression.ColumnCollection` to another
based on the names of the keys"""
@@ -1339,10 +1399,10 @@ class ColumnCollection:
else:
return True
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return self.compare(other)
- def get(self, key, default=None):
+ def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
@@ -1352,39 +1412,40 @@ class ColumnCollection:
else:
return default
- def __str__(self):
+ def __str__(self) -> str:
return "%s(%s)" % (
self.__class__.__name__,
", ".join(str(c) for c in self),
)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: str, value: Any) -> NoReturn:
raise NotImplementedError()
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> NoReturn:
raise NotImplementedError()
- def __setattr__(self, key, obj):
+ def __setattr__(self, key: str, obj: Any) -> NoReturn:
raise NotImplementedError()
- def clear(self):
+ def clear(self) -> NoReturn:
"""Dictionary clear() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
- def remove(self, column):
- """Dictionary remove() is not implemented for
- :class:`_sql.ColumnCollection`."""
+ def remove(self, column: Any) -> None:
raise NotImplementedError()
- def update(self, iter_):
+ def update(self, iter_: Any) -> NoReturn:
"""Dictionary update() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
- __hash__ = None
+ # https://github.com/python/mypy/issues/4266
+ __hash__ = None # type: ignore
- def _populate_separate_keys(self, iter_):
+ def _populate_separate_keys(
+ self, iter_: Iterable[Tuple[str, _COL]]
+ ) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
self._collection[:] = cols
@@ -1394,7 +1455,7 @@ class ColumnCollection:
)
self._index.update({k: col for k, col in reversed(self._collection)})
- def add(self, column, key=None):
+ def add(self, column: _COL, key: Optional[str] = None) -> None:
"""Add a column to this :class:`_sql.ColumnCollection`.
.. note::
@@ -1416,17 +1477,17 @@ class ColumnCollection:
if key not in self._index:
self._index[key] = column
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {"_collection": self._collection, "_index": self._index}
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_index", state["_index"])
object.__setattr__(self, "_collection", state["_collection"])
object.__setattr__(
self, "_colset", {col for k, col in self._collection}
)
- def contains_column(self, col):
+ def contains_column(self, col: _COL) -> bool:
"""Checks if a column object exists in this collection"""
if col not in self._colset:
if isinstance(col, str):
@@ -1438,13 +1499,15 @@ class ColumnCollection:
else:
return True
- def as_immutable(self):
+ def as_immutable(self) -> ImmutableColumnCollection[_COL]:
"""Return an "immutable" form of this
:class:`_sql.ColumnCollection`."""
return ImmutableColumnCollection(self)
- def corresponding_column(self, column, require_embedded=False):
+ def corresponding_column(
+ self, column: _COL, require_embedded: bool = False
+ ) -> Optional[_COL]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from this
:class:`_expression.ColumnCollection`
@@ -1497,7 +1560,7 @@ class ColumnCollection:
not require_embedded
or embedded(expanded_proxy_set, target_set)
):
- if col is None:
+ if col is None or intersect is None:
# no corresponding column yet, pick this one.
@@ -1542,7 +1605,7 @@ class ColumnCollection:
return col
-class DedupeColumnCollection(ColumnCollection):
+class DedupeColumnCollection(ColumnCollection[_COL]):
"""A :class:`_expression.ColumnCollection`
that maintains deduplicating behavior.
@@ -1555,7 +1618,7 @@ class DedupeColumnCollection(ColumnCollection):
"""
- def add(self, column, key=None):
+ def add(self, column: _COL, key: Optional[str] = None) -> None:
if key is not None and column.key != key:
raise exc.ArgumentError(
@@ -1589,7 +1652,9 @@ class DedupeColumnCollection(ColumnCollection):
self._index[l] = column
self._index[key] = column
- def _populate_separate_keys(self, iter_):
+ def _populate_separate_keys(
+ self, iter_: Iterable[Tuple[str, _COL]]
+ ) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
@@ -1614,10 +1679,10 @@ class DedupeColumnCollection(ColumnCollection):
for col in replace_col:
self.replace(col)
- def extend(self, iter_):
+ def extend(self, iter_: Iterable[_COL]) -> None:
self._populate_separate_keys((col.key, col) for col in iter_)
- def remove(self, column):
+ def remove(self, column: _COL) -> None:
if column not in self._colset:
raise ValueError(
"Can't remove column %r; column is not in this collection"
@@ -1634,7 +1699,7 @@ class DedupeColumnCollection(ColumnCollection):
# delete higher index
del self._index[len(self._collection)]
- def replace(self, column):
+ def replace(self, column: _COL) -> None:
"""add the given column to this collection, removing unaliased
versions of this column as well as existing columns with the
same key.
@@ -1687,7 +1752,9 @@ class DedupeColumnCollection(ColumnCollection):
self._index.update(self._collection)
-class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+class ImmutableColumnCollection(
+ util.ImmutableContainer, ColumnCollection[_COL]
+):
__slots__ = ("_parent",)
def __init__(self, collection):
@@ -1701,12 +1768,19 @@ class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
def __setstate__(self, state):
parent = state["_parent"]
- self.__init__(parent)
+ self.__init__(parent) # type: ignore
- add = extend = remove = util.ImmutableContainer._immutable
+ def add(self, column: Any, key: Any = ...) -> Any:
+ self._immutable()
+ def extend(self, elements: Any) -> None:
+ self._immutable()
-class ColumnSet(util.ordered_column_set):
+ def remove(self, item: Any) -> None:
+ self._immutable()
+
+
+class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
def contains_column(self, col):
return col in self
@@ -1714,9 +1788,6 @@ class ColumnSet(util.ordered_column_set):
for col in cols:
self.add(col)
- def __add__(self, other):
- return list(self) + list(other)
-
def __eq__(self, other):
l = []
for c in other:
@@ -1729,7 +1800,9 @@ class ColumnSet(util.ordered_column_set):
return hash(tuple(x for x in self))
-def _entity_namespace(entity):
+def _entity_namespace(
+ entity: Union[_HasEntityNamespace, ExternallyTraversible]
+) -> _EntityNamespace:
"""Return the nearest .entity_namespace for the given entity.
If not immediately available, does an iterate to find a sub-element
@@ -1737,16 +1810,20 @@ def _entity_namespace(entity):
"""
try:
- return entity.entity_namespace
+ return cast(_HasEntityNamespace, entity).entity_namespace
except AttributeError:
- for elem in visitors.iterate(entity):
- if hasattr(elem, "entity_namespace"):
+ for elem in visitors.iterate(cast(ExternallyTraversible, entity)):
+ if _is_has_entity_namespace(elem):
return elem.entity_namespace
else:
raise
-def _entity_namespace_key(entity, key, default=NO_ARG):
+def _entity_namespace_key(
+ entity: Union[_HasEntityNamespace, ExternallyTraversible],
+ key: str,
+ default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG,
+) -> SQLCoreOperations[Any]:
"""Return an entry from an entity_namespace.
@@ -1760,7 +1837,7 @@ def _entity_namespace_key(entity, key, default=NO_ARG):
if default is not NO_ARG:
return getattr(ns, key, default)
else:
- return getattr(ns, key)
+ return getattr(ns, key) # type: ignore
except AttributeError as err:
raise exc.InvalidRequestError(
'Entity namespace for "%s" has no property "%s"' % (entity, key)