diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-20 16:39:36 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-24 16:57:30 -0400 |
| commit | 6f02d5edd88fe2475629438b0730181a2b00c5fe (patch) | |
| tree | bbf9e9f3e8a2363659be35d59a7749c7fe35ef7c /lib/sqlalchemy/sql/base.py | |
| parent | c565c470517e1cc70a7f33d1ad3d3256935f1121 (diff) | |
| download | sqlalchemy-6f02d5edd88fe2475629438b0730181a2b00c5fe.tar.gz | |
pep484 - SQL internals
non-strict checking for mostly internal or semi-internal
code
Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 269 |
1 files changed, 173 insertions, 96 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6a6b389de..8f5135915 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -12,22 +12,32 @@ from __future__ import annotations -import collections.abc as collections_abc from enum import Enum from functools import reduce import itertools from itertools import zip_longest import operator import re -import typing from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic from typing import Iterable +from typing import Iterator from typing import List +from typing import Mapping from typing import MutableMapping +from typing import NoReturn from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import roles from . import visitors @@ -36,17 +46,26 @@ from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa from .visitors import ClauseVisitor from .visitors import ExtendedInternalTraversal +from .visitors import ExternallyTraversible from .visitors import InternalTraversal +from .. import event from .. import exc from .. import util from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing +from ..util.typing import Protocol from ..util.typing import Self +from ..util.typing import TypeGuard -if typing.TYPE_CHECKING: +if TYPE_CHECKING: + from . import coercions + from . import elements + from . import type_api from .elements import BindParameter + from .elements import ColumnClause from .elements import ColumnElement + from .elements import SQLCoreOperations from ..engine import Connection from ..engine import Result from ..engine.base import _CompiledCacheType @@ -58,10 +77,12 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import CacheStats from ..engine.interfaces import Compiled from ..engine.interfaces import Dialect + from ..event import dispatcher -coercions = None -elements = None -type_api = None +if not TYPE_CHECKING: + coercions = None # noqa + elements = None # noqa + type_api = None # noqa class _NoArg(Enum): @@ -70,13 +91,24 @@ class _NoArg(Enum): NO_ARG = _NoArg.NO_ARG -# if I use sqlalchemy.util.typing, which has the exact same -# symbols, mypy reports: "error: _Fn? not callable" -_Fn = typing.TypeVar("_Fn", bound=typing.Callable) +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) _AmbiguousTableNameMap = MutableMapping[str, str] +class _EntityNamespace(Protocol): + def __getattr__(self, key: str) -> SQLCoreOperations[Any]: + ... + + +class _HasEntityNamespace(Protocol): + entity_namespace: _EntityNamespace + + +def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: + return hasattr(element, "entity_namespace") + + class Immutable: """mark a ClauseElement as 'immutable' when expressions are cloned.""" @@ -107,10 +139,14 @@ class SingletonConstant(Immutable): def __new__(cls, *arg, **kw): return cls._singleton + @util.non_memoized_property + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + raise NotImplementedError() + @classmethod def _create_singleton(cls): obj = object.__new__(cls) - obj.__init__() + obj.__init__() # type: ignore # for a long time this was an empty frozenset, meaning # a SingletonConstant would never be a "corresponding column" in @@ -139,12 +175,11 @@ def _select_iterables(elements): ) -_Self = typing.TypeVar("_Self", bound="_GenerativeType") -_Args = compat_typing.ParamSpec("_Args") +_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType") class _GenerativeType(compat_typing.Protocol): - def _generate(self: "_Self") -> "_Self": + def _generate(self: _SelfGenerativeType) -> _SelfGenerativeType: ... @@ -158,8 +193,8 @@ def _generative(fn: _Fn) -> _Fn: @util.decorator def _generative( - fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs - ) -> _Self: + fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any + ) -> _SelfGenerativeType: """Mark a method as generative.""" self = self._generate() @@ -167,9 +202,9 @@ def _generative(fn: _Fn) -> _Fn: assert x is self, "generative methods must return self" return self - decorated = _generative(fn) - decorated.non_generative = fn - return decorated + decorated = _generative(fn) # type: ignore + decorated.non_generative = fn # type: ignore + return decorated # type: ignore def _exclusive_against(*names, **kw): @@ -233,7 +268,7 @@ def _cloned_difference(a, b): ) -class _DialectArgView(collections_abc.MutableMapping): +class _DialectArgView(MutableMapping[str, Any]): """A dictionary view of dialect-level arguments in the form <dialectname>_<argument_name>. @@ -290,7 +325,7 @@ class _DialectArgView(collections_abc.MutableMapping): ) -class _DialectArgDict(collections_abc.MutableMapping): +class _DialectArgDict(MutableMapping[str, Any]): """A dictionary view of dialect-level arguments for a specific dialect. @@ -343,6 +378,8 @@ class DialectKWArgs: """ + __slots__ = () + _dialect_kwargs_traverse_internals = [ ("dialect_options", InternalTraversal.dp_dialect_options) ] @@ -534,7 +571,7 @@ class CompileState: __slots__ = ("statement", "_ambiguous_table_name_map") - plugins = {} + plugins: Dict[Tuple[str, str], Type[CompileState]] = {} _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] @@ -639,9 +676,9 @@ class InPlaceGenerative(HasMemoized): class HasCompileState(Generative): """A class that has a :class:`.CompileState` associated with it.""" - _compile_state_plugin = None + _compile_state_plugin: Optional[Type[CompileState]] = None - _attributes = util.immutabledict() + _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT _compile_state_factory = CompileState.create_for_statement @@ -655,6 +692,8 @@ class _MetaOptions(type): """ + _cache_attrs: Tuple[str, ...] + def __add__(self, other): o1 = self() @@ -674,6 +713,8 @@ class Options(metaclass=_MetaOptions): __slots__ = () + _cache_attrs: Tuple[str, ...] + def __init_subclass__(cls) -> None: dict_ = cls.__dict__ cls._cache_attrs = tuple( @@ -732,13 +773,13 @@ class Options(metaclass=_MetaOptions): return self + {name: getattr(self, name) + value} @hybridmethod - def _state_dict(self): + def _state_dict_inst(self) -> Mapping[str, Any]: return self.__dict__ - _state_dict_const = util.immutabledict() + _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT - @_state_dict.classlevel - def _state_dict(cls): + @_state_dict_inst.classlevel + def _state_dict(cls) -> Mapping[str, Any]: return cls._state_dict_const @classmethod @@ -825,10 +866,10 @@ class CacheableOptions(Options, HasCacheKey): __slots__ = () @hybridmethod - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key_inst(self, anon_map, bindparams): return HasCacheKey._gen_cache_key(self, anon_map, bindparams) - @_gen_cache_key.classlevel + @_gen_cache_key_inst.classlevel def _gen_cache_key(cls, anon_map, bindparams): return (cls, ()) @@ -849,11 +890,11 @@ class ExecutableOption(HasCopyInternals): def _clone(self, **kw): """Create a shallow copy of this ExecutableOption.""" c = self.__class__.__new__(self.__class__) - c.__dict__ = dict(self.__dict__) + c.__dict__ = dict(self.__dict__) # type: ignore return c -SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable") +SelfExecutable = TypeVar("SelfExecutable", bound="Executable") class Executable(roles.StatementRole, Generative): @@ -866,9 +907,12 @@ class Executable(roles.StatementRole, Generative): """ supports_execution: bool = True - _execution_options: _ImmutableExecuteOptions = util.immutabledict() - _with_options = () - _with_context_options = () + _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT + _with_options: Tuple[ExecutableOption, ...] = () + _with_context_options: Tuple[ + Tuple[Callable[[CompileState], None], Any], ... + ] = () + _compile_options: Optional[CacheableOptions] _executable_traverse_internals = [ ("_with_options", InternalTraversal.dp_executable_options), @@ -886,7 +930,9 @@ class Executable(roles.StatementRole, Generative): is_delete = False is_dml = False - if typing.TYPE_CHECKING: + if TYPE_CHECKING: + + __visit_name__: str def _compile_w_cache( self, @@ -916,11 +962,13 @@ class Executable(roles.StatementRole, Generative): raise NotImplementedError() @property - def _effective_plugin_target(self): + def _effective_plugin_target(self) -> str: return self.__visit_name__ @_generative - def options(self: SelfExecutable, *options) -> SelfExecutable: + def options( + self: SelfExecutable, *options: ExecutableOption + ) -> SelfExecutable: """Apply options to this statement. In the general sense, options are any kind of Python object @@ -957,7 +1005,7 @@ class Executable(roles.StatementRole, Generative): @_generative def _set_compile_options( - self: SelfExecutable, compile_options + self: SelfExecutable, compile_options: CacheableOptions ) -> SelfExecutable: """Assign the compile options to a new value. @@ -970,16 +1018,19 @@ class Executable(roles.StatementRole, Generative): @_generative def _update_compile_options( - self: SelfExecutable, options + self: SelfExecutable, options: CacheableOptions ) -> SelfExecutable: """update the _compile_options with new keys.""" + assert self._compile_options is not None self._compile_options += options return self @_generative def _add_context_option( - self: SelfExecutable, callable_, cache_args + self: SelfExecutable, + callable_: Callable[[CompileState], None], + cache_args: Any, ) -> SelfExecutable: """Add a context option to this statement. @@ -995,7 +1046,7 @@ class Executable(roles.StatementRole, Generative): return self @_generative - def execution_options(self: SelfExecutable, **kw) -> SelfExecutable: + def execution_options(self: SelfExecutable, **kw: Any) -> SelfExecutable: """Set non-SQL options for the statement which take effect during execution. @@ -1112,7 +1163,7 @@ class Executable(roles.StatementRole, Generative): self._execution_options = self._execution_options.union(kw) return self - def get_execution_options(self): + def get_execution_options(self) -> _ExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded:: 1.3 @@ -1124,7 +1175,7 @@ class Executable(roles.StatementRole, Generative): return self._execution_options -class SchemaEventTarget: +class SchemaEventTarget(event.EventTarget): """Base class for elements that are the targets of :class:`.DDLEvents` events. @@ -1132,6 +1183,8 @@ class SchemaEventTarget: """ + dispatch: dispatcher[SchemaEventTarget] + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: """Associate with this SchemaEvent's parent object.""" @@ -1149,7 +1202,10 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -class ColumnCollection: +_COL = TypeVar("_COL", bound="ColumnClause[Any]") + + +class ColumnCollection(Generic[_COL]): """Collection of :class:`_expression.ColumnElement` instances, typically for :class:`_sql.FromClause` objects. @@ -1260,32 +1316,36 @@ class ColumnCollection: __slots__ = "_collection", "_index", "_colset" - def __init__(self, columns=None): + _collection: List[Tuple[str, _COL]] + _index: Dict[Union[str, int], _COL] + _colset: Set[_COL] + + def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None): object.__setattr__(self, "_colset", set()) object.__setattr__(self, "_index", {}) object.__setattr__(self, "_collection", []) if columns: self._initial_populate(columns) - def _initial_populate(self, iter_): + def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None: self._populate_separate_keys(iter_) @property - def _all_columns(self): + def _all_columns(self) -> List[_COL]: return [col for (k, col) in self._collection] - def keys(self): + def keys(self) -> List[str]: """Return a sequence of string key names for all columns in this collection.""" return [k for (k, col) in self._collection] - def values(self): + def values(self) -> List[_COL]: """Return a sequence of :class:`_sql.ColumnClause` or :class:`_schema.Column` objects for all columns in this collection.""" return [col for (k, col) in self._collection] - def items(self): + def items(self) -> List[Tuple[str, _COL]]: """Return a sequence of (key, column) tuples for all columns in this collection each consisting of a string key name and a :class:`_sql.ColumnClause` or @@ -1294,17 +1354,17 @@ class ColumnCollection: return list(self._collection) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._collection) - def __len__(self): + def __len__(self) -> int: return len(self._collection) - def __iter__(self): + def __iter__(self) -> Iterator[_COL]: # turn to a list first to maintain over a course of changes return iter([col for k, col in self._collection]) - def __getitem__(self, key): + def __getitem__(self, key: Union[str, int]) -> _COL: try: return self._index[key] except KeyError as err: @@ -1313,13 +1373,13 @@ class ColumnCollection: else: raise - def __getattr__(self, key): + def __getattr__(self, key: str) -> _COL: try: return self._index[key] except KeyError as err: raise AttributeError(key) from err - def __contains__(self, key): + def __contains__(self, key: str) -> bool: if key not in self._index: if not isinstance(key, str): raise exc.ArgumentError( @@ -1329,7 +1389,7 @@ class ColumnCollection: else: return True - def compare(self, other): + def compare(self, other: ColumnCollection[Any]) -> bool: """Compare this :class:`_expression.ColumnCollection` to another based on the names of the keys""" @@ -1339,10 +1399,10 @@ class ColumnCollection: else: return True - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.compare(other) - def get(self, key, default=None): + def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]: """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object based on a string key name from this :class:`_expression.ColumnCollection`.""" @@ -1352,39 +1412,40 @@ class ColumnCollection: else: return default - def __str__(self): + def __str__(self) -> str: return "%s(%s)" % ( self.__class__.__name__, ", ".join(str(c) for c in self), ) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> NoReturn: raise NotImplementedError() - def __delitem__(self, key): + def __delitem__(self, key: str) -> NoReturn: raise NotImplementedError() - def __setattr__(self, key, obj): + def __setattr__(self, key: str, obj: Any) -> NoReturn: raise NotImplementedError() - def clear(self): + def clear(self) -> NoReturn: """Dictionary clear() is not implemented for :class:`_sql.ColumnCollection`.""" raise NotImplementedError() - def remove(self, column): - """Dictionary remove() is not implemented for - :class:`_sql.ColumnCollection`.""" + def remove(self, column: Any) -> None: raise NotImplementedError() - def update(self, iter_): + def update(self, iter_: Any) -> NoReturn: """Dictionary update() is not implemented for :class:`_sql.ColumnCollection`.""" raise NotImplementedError() - __hash__ = None + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore - def _populate_separate_keys(self, iter_): + def _populate_separate_keys( + self, iter_: Iterable[Tuple[str, _COL]] + ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) self._collection[:] = cols @@ -1394,7 +1455,7 @@ class ColumnCollection: ) self._index.update({k: col for k, col in reversed(self._collection)}) - def add(self, column, key=None): + def add(self, column: _COL, key: Optional[str] = None) -> None: """Add a column to this :class:`_sql.ColumnCollection`. .. note:: @@ -1416,17 +1477,17 @@ class ColumnCollection: if key not in self._index: self._index[key] = column - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"_collection": self._collection, "_index": self._index} - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: object.__setattr__(self, "_index", state["_index"]) object.__setattr__(self, "_collection", state["_collection"]) object.__setattr__( self, "_colset", {col for k, col in self._collection} ) - def contains_column(self, col): + def contains_column(self, col: _COL) -> bool: """Checks if a column object exists in this collection""" if col not in self._colset: if isinstance(col, str): @@ -1438,13 +1499,15 @@ class ColumnCollection: else: return True - def as_immutable(self): + def as_immutable(self) -> ImmutableColumnCollection[_COL]: """Return an "immutable" form of this :class:`_sql.ColumnCollection`.""" return ImmutableColumnCollection(self) - def corresponding_column(self, column, require_embedded=False): + def corresponding_column( + self, column: _COL, require_embedded: bool = False + ) -> Optional[_COL]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from this :class:`_expression.ColumnCollection` @@ -1497,7 +1560,7 @@ class ColumnCollection: not require_embedded or embedded(expanded_proxy_set, target_set) ): - if col is None: + if col is None or intersect is None: # no corresponding column yet, pick this one. @@ -1542,7 +1605,7 @@ class ColumnCollection: return col -class DedupeColumnCollection(ColumnCollection): +class DedupeColumnCollection(ColumnCollection[_COL]): """A :class:`_expression.ColumnCollection` that maintains deduplicating behavior. @@ -1555,7 +1618,7 @@ class DedupeColumnCollection(ColumnCollection): """ - def add(self, column, key=None): + def add(self, column: _COL, key: Optional[str] = None) -> None: if key is not None and column.key != key: raise exc.ArgumentError( @@ -1589,7 +1652,9 @@ class DedupeColumnCollection(ColumnCollection): self._index[l] = column self._index[key] = column - def _populate_separate_keys(self, iter_): + def _populate_separate_keys( + self, iter_: Iterable[Tuple[str, _COL]] + ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1614,10 +1679,10 @@ class DedupeColumnCollection(ColumnCollection): for col in replace_col: self.replace(col) - def extend(self, iter_): + def extend(self, iter_: Iterable[_COL]) -> None: self._populate_separate_keys((col.key, col) for col in iter_) - def remove(self, column): + def remove(self, column: _COL) -> None: if column not in self._colset: raise ValueError( "Can't remove column %r; column is not in this collection" @@ -1634,7 +1699,7 @@ class DedupeColumnCollection(ColumnCollection): # delete higher index del self._index[len(self._collection)] - def replace(self, column): + def replace(self, column: _COL) -> None: """add the given column to this collection, removing unaliased versions of this column as well as existing columns with the same key. @@ -1687,7 +1752,9 @@ class DedupeColumnCollection(ColumnCollection): self._index.update(self._collection) -class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): +class ImmutableColumnCollection( + util.ImmutableContainer, ColumnCollection[_COL] +): __slots__ = ("_parent",) def __init__(self, collection): @@ -1701,12 +1768,19 @@ class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): def __setstate__(self, state): parent = state["_parent"] - self.__init__(parent) + self.__init__(parent) # type: ignore - add = extend = remove = util.ImmutableContainer._immutable + def add(self, column: Any, key: Any = ...) -> Any: + self._immutable() + def extend(self, elements: Any) -> None: + self._immutable() -class ColumnSet(util.ordered_column_set): + def remove(self, item: Any) -> None: + self._immutable() + + +class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): def contains_column(self, col): return col in self @@ -1714,9 +1788,6 @@ class ColumnSet(util.ordered_column_set): for col in cols: self.add(col) - def __add__(self, other): - return list(self) + list(other) - def __eq__(self, other): l = [] for c in other: @@ -1729,7 +1800,9 @@ class ColumnSet(util.ordered_column_set): return hash(tuple(x for x in self)) -def _entity_namespace(entity): +def _entity_namespace( + entity: Union[_HasEntityNamespace, ExternallyTraversible] +) -> _EntityNamespace: """Return the nearest .entity_namespace for the given entity. If not immediately available, does an iterate to find a sub-element @@ -1737,16 +1810,20 @@ def _entity_namespace(entity): """ try: - return entity.entity_namespace + return cast(_HasEntityNamespace, entity).entity_namespace except AttributeError: - for elem in visitors.iterate(entity): - if hasattr(elem, "entity_namespace"): + for elem in visitors.iterate(cast(ExternallyTraversible, entity)): + if _is_has_entity_namespace(elem): return elem.entity_namespace else: raise -def _entity_namespace_key(entity, key, default=NO_ARG): +def _entity_namespace_key( + entity: Union[_HasEntityNamespace, ExternallyTraversible], + key: str, + default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG, +) -> SQLCoreOperations[Any]: """Return an entry from an entity_namespace. @@ -1760,7 +1837,7 @@ def _entity_namespace_key(entity, key, default=NO_ARG): if default is not NO_ARG: return getattr(ns, key, default) else: - return getattr(ns, key) + return getattr(ns, key) # type: ignore except AttributeError as err: raise exc.InvalidRequestError( 'Entity namespace for "%s" has no property "%s"' % (entity, key) |
