summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-08 17:14:41 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-13 15:29:20 -0400
commit769fa67d842035dd852ab8b6a26ea3f110a51131 (patch)
tree5c121caca336071091c6f5ea4c54743c92d6458a
parent77fc8216a74e6b2d0efc6591c6c735687bd10002 (diff)
downloadsqlalchemy-769fa67d842035dd852ab8b6a26ea3f110a51131.tar.gz
pep-484: sqlalchemy.sql pass one
sqlalchemy.sql will require many passes to get all modules even gradually typed. Will have to pick and choose what modules can be strictly typed vs. which can be gradual. in this patch, emphasis is on visitors.py, cache_key.py, annotations.py for strict typing, compiler.py is on gradual typing but has much more structure, in particular where it connects with the outside world. The work within compiler.py also reached back out to engine/cursor.py , default.py quite a bit. References: #6810 Change-Id: I6e8a29f6013fd216e43d45091bc193f8be0368fd
-rw-r--r--lib/sqlalchemy/engine/cursor.py8
-rw-r--r--lib/sqlalchemy/engine/default.py28
-rw-r--r--lib/sqlalchemy/engine/interfaces.py10
-rw-r--r--lib/sqlalchemy/engine/result.py3
-rw-r--r--lib/sqlalchemy/exc.py5
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py5
-rw-r--r--lib/sqlalchemy/sql/_py_util.py25
-rw-r--r--lib/sqlalchemy/sql/annotation.py305
-rw-r--r--lib/sqlalchemy/sql/base.py17
-rw-r--r--lib/sqlalchemy/sql/cache_key.py354
-rw-r--r--lib/sqlalchemy/sql/coercions.py235
-rw-r--r--lib/sqlalchemy/sql/compiler.py447
-rw-r--r--lib/sqlalchemy/sql/dml.py8
-rw-r--r--lib/sqlalchemy/sql/elements.py34
-rw-r--r--lib/sqlalchemy/sql/functions.py3
-rw-r--r--lib/sqlalchemy/sql/roles.py33
-rw-r--r--lib/sqlalchemy/sql/schema.py6
-rw-r--r--lib/sqlalchemy/sql/selectable.py29
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py2
-rw-r--r--lib/sqlalchemy/sql/traversals.py64
-rw-r--r--lib/sqlalchemy/sql/util.py6
-rw-r--r--lib/sqlalchemy/sql/visitors.py598
-rw-r--r--lib/sqlalchemy/util/langhelpers.py10
-rw-r--r--lib/sqlalchemy/util/typing.py24
-rw-r--r--pyproject.toml33
25 files changed, 1579 insertions, 713 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index 821c0cb8e..f776e5975 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -604,18 +604,20 @@ class CursorResultMetaData(ResultMetaData):
cls,
result_columns: List[ResultColumnsEntry],
loose_column_name_matching: bool = False,
- ) -> Dict[Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]]:
+ ) -> Dict[
+ Union[str, object], Tuple[str, Tuple[Any, ...], TypeEngine[Any], int]
+ ]:
"""when matching cursor.description to a set of names that are present
in a Compiled object, as is the case with TextualSelect, get all the
names we expect might match those in cursor.description.
"""
d: Dict[
- Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]
+ Union[str, object],
+ Tuple[str, Tuple[Any, ...], TypeEngine[Any], int],
] = {}
for ridx, elem in enumerate(result_columns):
key = elem[RM_RENDERED_NAME]
-
if key in d:
# conflicting keyname - just add the column-linked objects
# to the existing record. if there is a duplicate column
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 2579f573c..c9fb1ebf2 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -46,6 +46,7 @@ from .interfaces import ExecutionContext
from .. import event
from .. import exc
from .. import pool
+from .. import TupleType
from .. import types as sqltypes
from .. import util
from ..sql import compiler
@@ -76,6 +77,8 @@ if typing.TYPE_CHECKING:
from ..sql.compiler import Compiled
from ..sql.compiler import ResultColumnsEntry
from ..sql.compiler import TypeCompiler
+ from ..sql.dml import DMLState
+ from ..sql.elements import BindParameter
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
@@ -820,7 +823,7 @@ class DefaultExecutionContext(ExecutionContext):
cursor: DBAPICursor
compiled_parameters: List[_MutableCoreSingleExecuteParams]
parameters: _DBAPIMultiExecuteParams
- extracted_parameters: _CoreSingleExecuteParams
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]]
_empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT)
@@ -878,7 +881,7 @@ class DefaultExecutionContext(ExecutionContext):
compiled: SQLCompiler,
parameters: _CoreMultiExecuteParams,
invoked_statement: Executable,
- extracted_parameters: _CoreSingleExecuteParams,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]],
cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
) -> ExecutionContext:
"""Initialize execution context for a Compiled construct."""
@@ -1513,9 +1516,10 @@ class DefaultExecutionContext(ExecutionContext):
inputsizes, self.cursor, self.statement, self.parameters, self
)
- has_escaped_names = bool(compiled.escaped_bind_names)
- if has_escaped_names:
+ if compiled.escaped_bind_names:
escaped_bind_names = compiled.escaped_bind_names
+ else:
+ escaped_bind_names = None
if dialect.positional:
items = [
@@ -1535,17 +1539,18 @@ class DefaultExecutionContext(ExecutionContext):
if key in self._expanded_parameters:
if bindparam.type._is_tuple_type:
- num = len(bindparam.type.types)
+ tup_type = cast(TupleType, bindparam.type)
+ num = len(tup_type.types)
dbtypes = inputsizes[bindparam]
generic_inputsizes.extend(
(
(
escaped_bind_names.get(paramname, paramname)
- if has_escaped_names
+ if escaped_bind_names is not None
else paramname
),
dbtypes[idx % num],
- bindparam.type.types[idx % num],
+ tup_type.types[idx % num],
)
for idx, paramname in enumerate(
self._expanded_parameters[key]
@@ -1557,7 +1562,7 @@ class DefaultExecutionContext(ExecutionContext):
(
(
escaped_bind_names.get(paramname, paramname)
- if has_escaped_names
+ if escaped_bind_names is not None
else paramname
),
dbtype,
@@ -1570,7 +1575,7 @@ class DefaultExecutionContext(ExecutionContext):
escaped_name = (
escaped_bind_names.get(key, key)
- if has_escaped_names
+ if escaped_bind_names is not None
else key
)
@@ -1702,7 +1707,9 @@ class DefaultExecutionContext(ExecutionContext):
else:
assert column is not None
assert parameters is not None
- compile_state = cast(SQLCompiler, self.compiled).compile_state
+ compile_state = cast(
+ "DMLState", cast(SQLCompiler, self.compiled).compile_state
+ )
assert compile_state is not None
if (
isolate_multiinsert_groups
@@ -1715,6 +1722,7 @@ class DefaultExecutionContext(ExecutionContext):
else:
d = {column.key: parameters[column.key]}
index = 0
+ assert compile_state._dict_parameters is not None
keys = compile_state._dict_parameters.keys()
d.update(
(key, parameters["%s_m%d" % (key, index)]) for key in keys
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index e65546eb7..e13295d6d 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -54,9 +54,10 @@ if TYPE_CHECKING:
from ..sql.compiler import IdentifierPreparer
from ..sql.compiler import Linting
from ..sql.compiler import SQLCompiler
+ from ..sql.elements import BindParameter
from ..sql.elements import ClauseElement
from ..sql.schema import Column
- from ..sql.schema import ColumnDefault
+ from ..sql.schema import DefaultGenerator
from ..sql.schema import Sequence as Sequence_SchemaItem
from ..sql.sqltypes import Integer
from ..sql.type_api import TypeEngine
@@ -813,6 +814,9 @@ class Dialect(EventTarget):
"""
+ _supports_statement_cache: bool
+ """internal evaluation for supports_statement_cache"""
+
bind_typing = BindTyping.NONE
"""define a means of passing typing information to the database and/or
driver for bound parameters.
@@ -2269,7 +2273,7 @@ class ExecutionContext:
compiled: SQLCompiler,
parameters: _CoreMultiExecuteParams,
invoked_statement: Executable,
- extracted_parameters: _CoreSingleExecuteParams,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]],
cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
) -> ExecutionContext:
raise NotImplementedError()
@@ -2299,7 +2303,7 @@ class ExecutionContext:
def _exec_default(
self,
column: Optional[Column[Any]],
- default: ColumnDefault,
+ default: DefaultGenerator,
type_: Optional[TypeEngine[Any]],
) -> Any:
raise NotImplementedError()
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 87d3cac1c..d428b8a9d 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -17,6 +17,7 @@ import typing
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import NoReturn
@@ -326,7 +327,7 @@ class SimpleResultMetaData(ResultMetaData):
def result_tuple(
fields: Sequence[str], extra: Optional[Any] = None
-) -> Callable[[_RawRowType], Row]:
+) -> Callable[[Iterable[Any]], Row]:
parent = SimpleResultMetaData(fields, extra)
return functools.partial(
Row, parent, parent._processors, parent._keymap, Row._default_key_style
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index cc78e0971..8f4b963eb 100644
--- a/lib/sqlalchemy/exc.py
+++ b/lib/sqlalchemy/exc.py
@@ -33,6 +33,7 @@ if typing.TYPE_CHECKING:
from .engine.interfaces import _DBAPIAnyExecuteParams
from .engine.interfaces import Dialect
from .sql.compiler import Compiled
+ from .sql.compiler import TypeCompiler
from .sql.elements import ClauseElement
if typing.TYPE_CHECKING:
@@ -221,8 +222,8 @@ class UnsupportedCompilationError(CompileError):
def __init__(
self,
- compiler: "Compiled",
- element_type: Type["ClauseElement"],
+ compiler: Union[Compiled, TypeCompiler],
+ element_type: Type[ClauseElement],
message: Optional[str] = None,
):
super(UnsupportedCompilationError, self).__init__(
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 709c13c14..e490a4f03 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -59,6 +59,7 @@ from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import SupportsIndex
+from ..util.typing import SupportsKeysAndGetItem
if typing.TYPE_CHECKING:
from ..orm.attributes import InstrumentedAttribute
@@ -1660,7 +1661,9 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]):
return (item[0], self._get(item[1]))
@overload
- def update(self, __m: Mapping[_KT, _VT], **kwargs: _VT) -> None:
+ def update(
+ self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT
+ ) -> None:
...
@overload
diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py
index 96e8f6b2c..9f18b882d 100644
--- a/lib/sqlalchemy/sql/_py_util.py
+++ b/lib/sqlalchemy/sql/_py_util.py
@@ -7,7 +7,16 @@
from __future__ import annotations
+import typing
+from typing import Any
from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from .cache_key import CacheConst
class prefix_anon_map(Dict[str, str]):
@@ -22,16 +31,18 @@ class prefix_anon_map(Dict[str, str]):
"""
- def __missing__(self, key):
+ def __missing__(self, key: str) -> str:
(ident, derived) = key.split(" ", 1)
anonymous_counter = self.get(derived, 1)
- self[derived] = anonymous_counter + 1
+ self[derived] = anonymous_counter + 1 # type: ignore
value = f"{derived}_{anonymous_counter}"
self[key] = value
return value
-class cache_anon_map(Dict[int, str]):
+class cache_anon_map(
+ Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]]
+):
"""A map that creates new keys for missing key access.
Produces an incrementing sequence given a series of unique keys.
@@ -45,11 +56,13 @@ class cache_anon_map(Dict[int, str]):
_index = 0
- def get_anon(self, object_):
+ def get_anon(self, object_: Any) -> Tuple[str, bool]:
idself = id(object_)
if idself in self:
- return self[idself], True
+ s_val = self[idself]
+ assert s_val is not True
+ return s_val, True
else:
# inline of __missing__
self[idself] = id_ = str(self._index)
@@ -57,7 +70,7 @@ class cache_anon_map(Dict[int, str]):
return id_, False
- def __missing__(self, key):
+ def __missing__(self, key: int) -> str:
self[key] = val = str(self._index)
self._index += 1
return val
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index b76393ad6..7afc2de97 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -13,22 +13,77 @@ associations.
from __future__ import annotations
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Mapping
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+
from . import operators
-from .base import HasCacheKey
-from .traversals import anon_map
+from .cache_key import HasCacheKey
+from .visitors import anon_map
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .. import util
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from .visitors import _TraverseInternalsType
+ from ..util.typing import Self
+
+_AnnotationDict = Mapping[str, Any]
+
+EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT
+
-EMPTY_ANNOTATIONS = util.immutabledict()
+SelfSupportsAnnotations = TypeVar(
+ "SelfSupportsAnnotations", bound="SupportsAnnotations"
+)
-class SupportsAnnotations:
+class SupportsAnnotations(ExternallyTraversible):
__slots__ = ()
- _annotations = EMPTY_ANNOTATIONS
+ _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS
+ proxy_set: Set[SupportsAnnotations]
+ _is_immutable: bool
+
+ def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations:
+ raise NotImplementedError()
+
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
+ raise NotImplementedError()
@util.memoized_property
- def _annotations_cache_key(self):
+ def _annotations_cache_key(self) -> Tuple[Any, ...]:
anon_map_ = anon_map()
return (
"_annotations",
@@ -47,14 +102,22 @@ class SupportsAnnotations:
)
+SelfSupportsCloneAnnotations = TypeVar(
+ "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations"
+)
+
+
class SupportsCloneAnnotations(SupportsAnnotations):
- __slots__ = ()
+ if not typing.TYPE_CHECKING:
+ __slots__ = ()
- _clone_annotations_traverse_internals = [
+ _clone_annotations_traverse_internals: _TraverseInternalsType = [
("_annotations", InternalTraversal.dp_annotations_key)
]
- def _annotate(self, values):
+ def _annotate(
+ self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsCloneAnnotations:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -65,7 +128,9 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new.__dict__.pop("_generate_cache_key", None)
return new
- def _with_annotations(self, values):
+ def _with_annotations(
+ self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsCloneAnnotations:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
@@ -76,7 +141,27 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new.__dict__.pop("_generate_cache_key", None)
return new
- def _deannotate(self, values=None, clone=False):
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
"""return a copy of this :class:`_expression.ClauseElement`
with annotations
removed.
@@ -96,24 +181,52 @@ class SupportsCloneAnnotations(SupportsAnnotations):
return self
+SelfSupportsWrappingAnnotations = TypeVar(
+ "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations"
+)
+
+
class SupportsWrappingAnnotations(SupportsAnnotations):
__slots__ = ()
- def _annotate(self, values):
+ _constructor: Callable[..., SupportsWrappingAnnotations]
+ entity_namespace: Mapping[str, Any]
+
+ def _annotate(self, values: _AnnotationDict) -> Annotated:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
"""
- return Annotated(self, values)
+ return Annotated._as_annotated_instance(self, values)
- def _with_annotations(self, values):
+ def _with_annotations(self, values: _AnnotationDict) -> Annotated:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
"""
- return Annotated(self, values)
-
- def _deannotate(self, values=None, clone=False):
+ return Annotated._as_annotated_instance(self, values)
+
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
"""return a copy of this :class:`_expression.ClauseElement`
with annotations
removed.
@@ -129,8 +242,11 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
return self
-class Annotated:
- """clones a SupportsAnnotated and applies an 'annotations' dictionary.
+SelfAnnotated = TypeVar("SelfAnnotated", bound="Annotated")
+
+
+class Annotated(SupportsAnnotations):
+ """clones a SupportsAnnotations and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
@@ -151,21 +267,26 @@ class Annotated:
_is_column_operators = False
- def __new__(cls, *args):
- if not args:
- # clone constructor
- return object.__new__(cls)
- else:
- element, values = args
- # pull appropriate subclass from registry of annotated
- # classes
- try:
- cls = annotated_classes[element.__class__]
- except KeyError:
- cls = _new_annotation_type(element.__class__, cls)
- return object.__new__(cls)
-
- def __init__(self, element, values):
+ @classmethod
+ def _as_annotated_instance(
+ cls, element: SupportsWrappingAnnotations, values: _AnnotationDict
+ ) -> Annotated:
+ try:
+ cls = annotated_classes[element.__class__]
+ except KeyError:
+ cls = _new_annotation_type(element.__class__, cls)
+ return cls(element, values)
+
+ _annotations: util.immutabledict[str, Any]
+ __element: SupportsWrappingAnnotations
+ _hash: int
+
+ def __new__(cls: Type[SelfAnnotated], *args: Any) -> SelfAnnotated:
+ return object.__new__(cls)
+
+ def __init__(
+ self, element: SupportsWrappingAnnotations, values: _AnnotationDict
+ ):
self.__dict__ = element.__dict__.copy()
self.__dict__.pop("_annotations_cache_key", None)
self.__dict__.pop("_generate_cache_key", None)
@@ -173,11 +294,15 @@ class Annotated:
self._annotations = util.immutabledict(values)
self._hash = hash(element)
- def _annotate(self, values):
+ def _annotate(
+ self: SelfAnnotated, values: _AnnotationDict
+ ) -> SelfAnnotated:
_values = self._annotations.union(values)
return self._with_annotations(_values)
- def _with_annotations(self, values):
+ def _with_annotations(
+ self: SelfAnnotated, values: util.immutabledict[str, Any]
+ ) -> SelfAnnotated:
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotations_cache_key", None)
@@ -185,7 +310,27 @@ class Annotated:
clone._annotations = values
return clone
- def _deannotate(self, values=None, clone=True):
+ @overload
+ def _deannotate(
+ self: SelfAnnotated,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfAnnotated:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> Annotated:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = True,
+ ) -> SupportsAnnotations:
if values is None:
return self.__element
else:
@@ -199,14 +344,18 @@ class Annotated:
)
)
- def _compiler_dispatch(self, visitor, **kw):
- return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
+ if not typing.TYPE_CHECKING:
+ # manually proxy some methods that need extra attention
+ def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any:
+ return self.__element.__class__._compiler_dispatch(
+ self, visitor, **kw
+ )
- @property
- def _constructor(self):
- return self.__element._constructor
+ @property
+ def _constructor(self):
+ return self.__element._constructor
- def _clone(self, **kw):
+ def _clone(self: SelfAnnotated, **kw: Any) -> SelfAnnotated:
clone = self.__element._clone(**kw)
if clone is self.__element:
# detect immutable, don't change anything
@@ -217,22 +366,25 @@ class Annotated:
clone.__dict__.update(self.__dict__)
return self.__class__(clone, self._annotations)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]:
return self.__class__, (self.__element, self._annotations)
- def __hash__(self):
+ def __hash__(self) -> int:
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if self._is_column_operators:
return self.__element.__class__.__eq__(self, other)
else:
return hash(other) == hash(self)
@property
- def entity_namespace(self):
+ def entity_namespace(self) -> Mapping[str, Any]:
if "entity_namespace" in self._annotations:
- return self._annotations["entity_namespace"].entity_namespace
+ return cast(
+ SupportsWrappingAnnotations,
+ self._annotations["entity_namespace"],
+ ).entity_namespace
else:
return self.__element.entity_namespace
@@ -242,12 +394,19 @@ class Annotated:
# so that the resulting objects are pickleable; additionally, other
# decisions can be made up front about the type of object being annotated
# just once per class rather than per-instance.
-annotated_classes = {}
+annotated_classes: Dict[
+ Type[SupportsWrappingAnnotations], Type[Annotated]
+] = {}
+
+_SA = TypeVar("_SA", bound="SupportsAnnotations")
def _deep_annotate(
- element, annotations, exclude=None, detect_subquery_cols=False
-):
+ element: _SA,
+ annotations: _AnnotationDict,
+ exclude: Optional[Sequence[SupportsAnnotations]] = None,
+ detect_subquery_cols: bool = False,
+) -> _SA:
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
@@ -258,9 +417,9 @@ def _deep_annotate(
# annotated objects hack the __hash__() method so if we want to
# uniquely process them we have to use id()
- cloned_ids = {}
+ cloned_ids: Dict[int, SupportsAnnotations] = {}
- def clone(elem, **kw):
+ def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
kw["detect_subquery_cols"] = detect_subquery_cols
id_ = id(elem)
@@ -285,17 +444,20 @@ def _deep_annotate(
return newelem
if element is not None:
- element = clone(element)
- clone = None # remove gc cycles
+ element = cast(_SA, clone(element))
+ clone = None # type: ignore # remove gc cycles
return element
-def _deep_deannotate(element, values=None):
+def _deep_deannotate(
+ element: _SA, values: Optional[Sequence[str]] = None
+) -> _SA:
"""Deep copy the given element, removing annotations."""
- cloned = {}
+ cloned: Dict[Any, SupportsAnnotations] = {}
- def clone(elem, **kw):
+ def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
+ key: Any
if values:
key = id(elem)
else:
@@ -310,12 +472,14 @@ def _deep_deannotate(element, values=None):
return cloned[key]
if element is not None:
- element = clone(element)
- clone = None # remove gc cycles
+ element = cast(_SA, clone(element))
+ clone = None # type: ignore # remove gc cycles
return element
-def _shallow_annotate(element, annotations):
+def _shallow_annotate(
+ element: SupportsAnnotations, annotations: _AnnotationDict
+) -> SupportsAnnotations:
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
@@ -328,7 +492,13 @@ def _shallow_annotate(element, annotations):
return element
-def _new_annotation_type(cls, base_cls):
+def _new_annotation_type(
+ cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated]
+) -> Type[Annotated]:
+ """Generates a new class that subclasses Annotated and proxies a given
+ element type.
+
+ """
if issubclass(cls, Annotated):
return cls
elif cls in annotated_classes:
@@ -342,8 +512,9 @@ def _new_annotation_type(cls, base_cls):
base_cls = annotated_classes[super_]
break
- annotated_classes[cls] = anno_cls = type(
- "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ annotated_classes[cls] = anno_cls = cast(
+ Type[Annotated],
+ type("Annotated%s" % cls.__name__, (base_cls, cls), {}),
)
globals()["Annotated%s" % cls.__name__] = anno_cls
@@ -359,13 +530,15 @@ def _new_annotation_type(cls, base_cls):
# some classes include this even if they have traverse_internals
# e.g. BindParameter, add it if present.
if cls.__dict__.get("inherit_cache", False):
- anno_cls.inherit_cache = True
+ anno_cls.inherit_cache = True # type: ignore
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
return anno_cls
-def _prepare_annotations(target_hierarchy, base_cls):
+def _prepare_annotations(
+ target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated]
+) -> None:
for cls in util.walk_subclasses(target_hierarchy):
_new_annotation_type(cls, base_cls)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index a94590da1..a408a010a 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -19,8 +19,10 @@ from itertools import zip_longest
import operator
import re
import typing
+from typing import MutableMapping
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import TypeVar
from . import roles
@@ -36,14 +38,9 @@ from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
-from ..util._has_cy import HAS_CYEXTENSION
-
-if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_util import prefix_anon_map # noqa
-else:
- from sqlalchemy.cyextension.util import prefix_anon_map # noqa
if typing.TYPE_CHECKING:
+ from .elements import ColumnElement
from ..engine import Connection
from ..engine import Result
from ..engine.interfaces import _CoreMultiExecuteParams
@@ -63,6 +60,8 @@ NO_ARG = util.symbol("NO_ARG")
# symbols, mypy reports: "error: _Fn? not callable"
_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_AmbiguousTableNameMap = MutableMapping[str, str]
+
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -87,6 +86,10 @@ class SingletonConstant(Immutable):
_is_singleton_constant = True
+ _singleton: SingletonConstant
+
+ proxy_set: Set[ColumnElement]
+
def __new__(cls, *arg, **kw):
return cls._singleton
@@ -519,6 +522,8 @@ class CompileState:
plugins = {}
+ _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
+
@classmethod
def create_for_statement(cls, statement, compiler, **kw):
# factory construction.
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index ff659b77d..fca58f98e 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -11,21 +11,41 @@ import enum
from itertools import zip_longest
import typing
from typing import Any
-from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
from typing import NamedTuple
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import Union
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import HasTraversalDispatch
+from .visitors import HasTraverseInternals
from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
from .. import util
from ..inspection import inspect
from ..util import HasMemoized
from ..util.typing import Literal
-
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from .elements import BindParameter
+ from .elements import ClauseElement
+ from .visitors import _TraverseInternalsType
+ from ..engine.base import _CompiledCacheType
+ from ..engine.interfaces import _CoreSingleExecuteParams
+
+
+class _CacheKeyTraversalDispatchType(Protocol):
+ def __call__(
+ s, self: HasCacheKey, visitor: _CacheKeyTraversal
+ ) -> CacheKey:
+ ...
class CacheConst(enum.Enum):
@@ -70,7 +90,9 @@ class HasCacheKey:
__slots__ = ()
- _cache_key_traversal = NO_CACHE
+ _cache_key_traversal: Union[
+ _TraverseInternalsType, Literal[CacheConst.NO_CACHE]
+ ] = NO_CACHE
_is_has_cache_key = True
@@ -83,7 +105,7 @@ class HasCacheKey:
"""
- inherit_cache = None
+ inherit_cache: Optional[bool] = None
"""Indicate if this :class:`.HasCacheKey` instance should make use of the
cache key generation scheme used by its immediate superclass.
@@ -106,8 +128,12 @@ class HasCacheKey:
__slots__ = ()
+ _generated_cache_key_traversal: Any
+
@classmethod
- def _generate_cache_attrs(cls):
+ def _generate_cache_attrs(
+ cls,
+ ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]:
"""generate cache key dispatcher for a new class.
This sets the _generated_cache_key_traversal attribute once called
@@ -121,8 +147,11 @@ class HasCacheKey:
_cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
if _cache_key_traversal is None:
try:
- # this would be HasTraverseInternals
- _cache_key_traversal = cls._traverse_internals
+ # check for _traverse_internals, which is part of
+ # HasTraverseInternals
+ _cache_key_traversal = cast(
+ "Type[HasTraverseInternals]", cls
+ )._traverse_internals
except AttributeError:
cls._generated_cache_key_traversal = NO_CACHE
return NO_CACHE
@@ -138,7 +167,9 @@ class HasCacheKey:
# more complicated, so for the moment this is a little less
# efficient on startup but simpler.
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
else:
_cache_key_traversal = cls.__dict__.get(
@@ -170,11 +201,15 @@ class HasCacheKey:
return NO_CACHE
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
@util.preload_module("sqlalchemy.sql.elements")
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key(
+ self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+ ) -> Optional[Tuple[Any, ...]]:
"""return an optional cache key.
The cache key is a tuple which can contain any series of
@@ -202,15 +237,15 @@ class HasCacheKey:
dispatcher: Union[
Literal[CacheConst.NO_CACHE],
- Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"],
+ _CacheKeyTraversalDispatchType,
]
try:
dispatcher = cls.__dict__["_generated_cache_key_traversal"]
except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
+ # traversals.py -> _preconfigure_traversals()
+ # may be used to run these ahead of time, but
+ # is not enabled right now.
# this block will generate any remaining dispatchers.
dispatcher = cls._generate_cache_attrs()
@@ -218,7 +253,7 @@ class HasCacheKey:
anon_map[NO_CACHE] = True
return None
- result = (id_, cls)
+ result: Tuple[Any, ...] = (id_, cls)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
@@ -268,7 +303,7 @@ class HasCacheKey:
# Columns, this should be long lived. For select()
# statements, not so much, but they usually won't have
# annotations.
- result += self._annotations_cache_key
+ result += self._annotations_cache_key # type: ignore
elif (
meth is InternalTraversal.dp_clauseelement_list
or meth is InternalTraversal.dp_clauseelement_tuple
@@ -290,7 +325,7 @@ class HasCacheKey:
)
return result
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
"""return a cache key.
The cache key is a tuple which can contain any series of
@@ -322,32 +357,40 @@ class HasCacheKey:
"""
- bindparams = []
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = self._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
@classmethod
- def _generate_cache_key_for_object(cls, obj):
- bindparams = []
+ def _generate_cache_key_for_object(
+ cls, obj: HasCacheKey
+ ) -> Optional[CacheKey]:
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = obj._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
+class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey):
+ pass
+
+
class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
__slots__ = ()
@HasMemoized.memoized_instancemethod
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
return HasCacheKey._generate_cache_key(self)
@@ -362,14 +405,22 @@ class CacheKey(NamedTuple):
"""
key: Tuple[Any, ...]
- bindparams: Sequence[BindParameter]
+ bindparams: Sequence[BindParameter[Any]]
- def __hash__(self):
+ # can't set __hash__ attribute because it interferes
+ # with namedtuple
+ # can't use "if not TYPE_CHECKING" because mypy rejects it
+ # inside of a NamedTuple
+ def __hash__(self) -> Optional[int]: # type: ignore
"""CacheKey itself is not hashable - hash the .key portion"""
-
return None
- def to_offline_string(self, statement_cache, statement, parameters):
+ def to_offline_string(
+ self,
+ statement_cache: _CompiledCacheType,
+ statement: ClauseElement,
+ parameters: _CoreSingleExecuteParams,
+ ) -> str:
"""Generate an "offline string" form of this :class:`.CacheKey`
The "offline string" is basically the string SQL for the
@@ -400,21 +451,21 @@ class CacheKey(NamedTuple):
return repr((sql_str, param_tuple))
- def __eq__(self, other):
- return self.key == other.key
+ def __eq__(self, other: Any) -> bool:
+ return bool(self.key == other.key)
@classmethod
- def _diff_tuples(cls, left, right):
+ def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
ck1 = CacheKey(left, [])
ck2 = CacheKey(right, [])
return ck1._diff(ck2)
- def _whats_different(self, other):
+ def _whats_different(self, other: CacheKey) -> Iterator[str]:
k1 = self.key
k2 = other.key
- stack = []
+ stack: List[int] = []
pickup_index = 0
while True:
s1, s2 = k1, k2
@@ -440,11 +491,11 @@ class CacheKey(NamedTuple):
pickup_index = stack.pop(-1)
break
- def _diff(self, other):
+ def _diff(self, other: CacheKey) -> str:
return ", ".join(self._whats_different(other))
- def __str__(self):
- stack = [self.key]
+ def __str__(self) -> str:
+ stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key]
output = []
sentinel = object()
@@ -473,15 +524,15 @@ class CacheKey(NamedTuple):
return "CacheKey(key=%s)" % ("\n".join(output),)
- def _generate_param_dict(self):
+ def _generate_param_dict(self) -> Dict[str, Any]:
"""used for testing"""
- from .compiler import prefix_anon_map
-
_anon_map = prefix_anon_map()
return {b.key % _anon_map: b.effective_value for b in self.bindparams}
- def _apply_params_to_element(self, original_cache_key, target_element):
+ def _apply_params_to_element(
+ self, original_cache_key: CacheKey, target_element: ClauseElement
+ ) -> ClauseElement:
translate = {
k.key: v.value
for k, v in zip(original_cache_key.bindparams, self.bindparams)
@@ -490,7 +541,7 @@ class CacheKey(NamedTuple):
return target_element.params(translate)
-class _CacheKeyTraversal(ExtendedInternalTraversal):
+class _CacheKeyTraversal(HasTraversalDispatch):
# very common elements are inlined into the main _get_cache_key() method
# to produce a dramatic savings in Python function call overhead
@@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
visit_propagate_attrs = PROPAGATE_ATTRS
def visit_with_context_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple((fn.__code__, c_key) for fn, c_key in obj)
- def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_inspectable(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
- def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_string_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(obj)
- def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
obj._gen_cache_key(anon_map, bindparams)
@@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
else obj,
)
- def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_has_cache_key_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_has_cache_key_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_executable_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_inspectable_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_list(
attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
)
def visit_clauseelement_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_tuples(
attrname, obj, parent, anon_map, bindparams
)
def visit_fromclause_ordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_clauseelement_unordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
cache_keys = [
@@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_named_ddl_element(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, obj.name)
def visit_prefix_sequence(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
@@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_setup_join_tuple(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(
(
target._gen_cache_key(anon_map, bindparams),
@@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_table_hint_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
@@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
),
)
- def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_plain_dict(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
def visit_dialect_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_string_clauseelement_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_string_multi_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_fromclause_canonical_column_collection(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# inlining into the internals of ColumnCollection
return (
attrname,
@@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_unknown_structure(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
anon_map[NO_CACHE] = True
return ()
def visit_dml_ordered_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
),
)
- def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_dml_values(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# in py37 we can assume two dictionaries created in the same
# insert ordering will retain that sorting
return (
@@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_dml_multi_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# multivalues are simply not cacheable right now
anon_map[NO_CACHE] = True
return ()
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index d616417ab..834bfb75d 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -13,6 +13,9 @@ import re
import typing
from typing import Any
from typing import Any as TODO_Any
+from typing import Dict
+from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Type
from typing import TypeVar
@@ -42,6 +45,7 @@ if typing.TYPE_CHECKING:
from . import selectable
from . import traversals
from .elements import ClauseElement
+ from .elements import ColumnClause
_SR = TypeVar("_SR", bound=roles.SQLRole)
_StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
@@ -252,7 +256,7 @@ def expect_col_expression_collection(role, expressions):
if isinstance(resolved, str):
strname = resolved = expr
else:
- cols = []
+ cols: List[ColumnClause[Any]] = []
visitors.traverse(resolved, {}, {"column": cols.append})
if cols:
column = cols[0]
@@ -266,7 +270,7 @@ class RoleImpl:
def _literal_coercion(self, element, **kw):
raise NotImplementedError()
- _post_coercion = None
+ _post_coercion: Any = None
_resolve_literal_only = False
_skip_clauseelement_for_target_match = False
@@ -276,19 +280,24 @@ class RoleImpl:
self._use_inspection = issubclass(role_class, roles.UsesInspection)
def _implicit_coercions(
- self, element, resolved, argname=None, **kw
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
) -> Any:
self._raise_for_expected(element, argname, resolved)
def _raise_for_expected(
self,
- element,
- argname=None,
- resolved=None,
- advice=None,
- code=None,
- err=None,
- ):
+ element: Any,
+ argname: Optional[str] = None,
+ resolved: Optional[Any] = None,
+ advice: Optional[str] = None,
+ code: Optional[str] = None,
+ err: Optional[Exception] = None,
+ **kw: Any,
+ ) -> NoReturn:
if resolved is not None and resolved is not element:
got = "%r object resolved from %r object" % (resolved, element)
else:
@@ -324,22 +333,20 @@ class _StringOnly:
_resolve_literal_only = True
-class _ReturnsStringKey:
+class _ReturnsStringKey(RoleImpl):
__slots__ = ()
- def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, str):
- return original_element
+ def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ if isinstance(element, str):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
-class _ColumnCoercions:
+class _ColumnCoercions(RoleImpl):
__slots__ = ()
def _warn_for_scalar_subquery_coercion(self):
@@ -368,8 +375,12 @@ class _ColumnCoercions:
def _no_text_coercion(
- element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None
-):
+ element: Any,
+ argname: Optional[str] = None,
+ exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError,
+ extra: Optional[str] = None,
+ err: Optional[Exception] = None,
+) -> NoReturn:
raise exc_cls(
"%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
"explicitly declared as text(%(expr)r)"
@@ -381,7 +392,7 @@ def _no_text_coercion(
) from err
-class _NoTextCoercion:
+class _NoTextCoercion(RoleImpl):
__slots__ = ()
def _literal_coercion(self, element, argname=None, **kw):
@@ -393,7 +404,7 @@ class _NoTextCoercion:
self._raise_for_expected(element, argname)
-class _CoerceLiterals:
+class _CoerceLiterals(RoleImpl):
__slots__ = ()
_coerce_consts = False
_coerce_star = False
@@ -440,12 +451,19 @@ class LiteralValueImpl(RoleImpl):
return element
-class _SelectIsNotFrom:
+class _SelectIsNotFrom(RoleImpl):
__slots__ = ()
def _raise_for_expected(
- self, element, argname=None, resolved=None, advice=None, **kw
- ):
+ self,
+ element: Any,
+ argname: Optional[str] = None,
+ resolved: Optional[Any] = None,
+ advice: Optional[str] = None,
+ code: Optional[str] = None,
+ err: Optional[Exception] = None,
+ **kw: Any,
+ ) -> NoReturn:
if (
not advice
and isinstance(element, roles.SelectStatementRole)
@@ -460,26 +478,33 @@ class _SelectIsNotFrom:
else:
code = None
- return super(_SelectIsNotFrom, self)._raise_for_expected(
+ super()._raise_for_expected(
element,
argname=argname,
resolved=resolved,
advice=advice,
code=code,
+ err=err,
**kw,
)
+ # never reached
+ assert False
class HasCacheKeyImpl(RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, traversals.HasCacheKey):
- return original_element
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, HasCacheKey):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
@@ -489,12 +514,16 @@ class ExecutableOptionImpl(RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, ExecutableOption):
- return original_element
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, ExecutableOption):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
@@ -560,8 +589,12 @@ class InElementImpl(RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
@@ -573,7 +606,7 @@ class InElementImpl(RoleImpl):
self._warn_for_implicit_coercion(resolved)
return self._post_coercion(resolved.select(), **kw)
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _warn_for_implicit_coercion(self, elem):
util.warn(
@@ -586,12 +619,16 @@ class InElementImpl(RoleImpl):
if isinstance(element, collections_abc.Iterable) and not isinstance(
element, str
):
- non_literal_expressions = {}
+ non_literal_expressions: Dict[
+ Optional[operators.ColumnOperators[Any]],
+ operators.ColumnOperators[Any],
+ ] = {}
element = list(element)
for o in element:
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
self._raise_for_expected(element, **kw)
+
else:
non_literal_expressions[o] = o
elif o is None:
@@ -712,8 +749,12 @@ class GroupByImpl(ByOfImpl, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if isinstance(resolved, roles.StrictFromClauseRole):
return elements.ClauseList(*resolved.c)
else:
@@ -748,12 +789,16 @@ class TruncatedLabelImpl(_StringOnly, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, str):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, str):
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, argname=None, **kw):
"""coerce the given value to :class:`._truncated_label`.
@@ -794,7 +839,13 @@ class DDLReferredColumnImpl(DDLConstraintColumnImpl):
class LimitOffsetImpl(RoleImpl):
__slots__ = ()
- def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ def _implicit_coercions(
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved is None:
return None
else:
@@ -814,18 +865,22 @@ class LabeledColumnExprImpl(ExpressionElementImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if isinstance(resolved, roles.ExpressionElementRole):
return resolved.label(None)
else:
new = super(LabeledColumnExprImpl, self)._implicit_coercions(
- original_element, resolved, argname=argname, **kw
+ element, resolved, argname=argname, **kw
)
if isinstance(new, roles.ExpressionElementRole):
return new.label(None)
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
@@ -899,13 +954,17 @@ class StatementImpl(_CoerceLiterals, RoleImpl):
return resolved
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_lambda_element:
return resolved
else:
- return super(StatementImpl, self)._implicit_coercions(
- original_element, resolved, argname=argname, **kw
+ return super()._implicit_coercions(
+ element, resolved, argname=argname, **kw
)
@@ -913,12 +972,16 @@ class SelectStatementImpl(_NoTextCoercion, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_text_clause:
return resolved.columns()
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class HasCTEImpl(ReturnsRowsImpl):
@@ -938,13 +1001,18 @@ class JoinTargetImpl(RoleImpl):
self._raise_for_expected(element, argname)
def _implicit_coercions(
- self, original_element, resolved, argname=None, legacy=False, **kw
- ):
- if isinstance(original_element, roles.JoinTargetRole):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ legacy: bool = False,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, roles.JoinTargetRole):
# note that this codepath no longer occurs as of
# #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
# were set to False.
- return original_element
+ return element
elif legacy and resolved._is_select_statement:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT "
@@ -959,7 +1027,7 @@ class JoinTargetImpl(RoleImpl):
# in _ORMJoin->Join
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
@@ -967,13 +1035,13 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
def _implicit_coercions(
self,
- original_element,
- resolved,
- argname=None,
- explicit_subquery=False,
- allow_select=True,
- **kw,
- ):
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ explicit_subquery: bool = False,
+ allow_select: bool = True,
+ **kw: Any,
+ ) -> Any:
if resolved._is_select_statement:
if explicit_subquery:
return resolved.subquery()
@@ -989,7 +1057,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
elif resolved._is_text_clause:
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _post_coercion(self, element, deannotate=False, **kw):
if deannotate:
@@ -1003,12 +1071,13 @@ class StrictFromClauseImpl(FromClauseImpl):
def _implicit_coercions(
self,
- original_element,
- resolved,
- argname=None,
- allow_select=False,
- **kw,
- ):
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ explicit_subquery: bool = False,
+ allow_select: bool = False,
+ **kw: Any,
+ ) -> Any:
if resolved._is_select_statement and allow_select:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT constructs "
@@ -1019,7 +1088,7 @@ class StrictFromClauseImpl(FromClauseImpl):
)
return resolved._implicit_subquery
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class AnonymizedFromClauseImpl(StrictFromClauseImpl):
@@ -1045,8 +1114,12 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
@@ -1056,7 +1129,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
else:
return resolved.select()
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class CompoundElementImpl(_NoTextCoercion, RoleImpl):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 423c3d446..f28dceefc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -35,14 +35,19 @@ from time import perf_counter
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import FrozenSet
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
from typing import Union
from . import base
@@ -54,19 +59,42 @@ from . import operators
from . import schema
from . import selectable
from . import sqltypes
+from .base import _from_objects
from .base import NO_ARG
-from .base import prefix_anon_map
from .elements import quoted_name
from .schema import Column
+from .sqltypes import TupleType
from .type_api import TypeEngine
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .base import _AmbiguousTableNameMap
+ from .base import CompileState
+ from .cache_key import CacheKey
+ from .elements import BindParameter
+ from .elements import ColumnClause
+ from .elements import Label
+ from .functions import Function
+ from .selectable import Alias
+ from .selectable import AliasedReturnsRows
+ from .selectable import CompoundSelectState
from .selectable import CTE
from .selectable import FromClause
+ from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
+ from .selectable import Select
+ from .selectable import SelectState
+ from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _MutableCoreSingleExecuteParams
+ from ..engine.interfaces import _SchemaTranslateMapType
from ..engine.result import _ProcessorType
_FromHintsType = Dict["FromClause", str]
@@ -236,7 +264,7 @@ OPERATORS = {
operators.nulls_last_op: " NULLS LAST",
}
-FUNCTIONS = {
+FUNCTIONS: Dict[Type[Function], str] = {
functions.coalesce: "coalesce",
functions.current_date: "CURRENT_DATE",
functions.current_time: "CURRENT_TIME",
@@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple):
name: str
"""column name, may be labeled"""
- objects: List[Any]
- """list of objects that should be able to locate this column
+ objects: Tuple[Any, ...]
+ """sequence of objects that should be able to locate this column
in a RowMapping. This is typically string names and aliases
as well as Column objects.
@@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple):
"""
+class _ResultMapAppender(Protocol):
+ def __call__(
+ self,
+ keyname: str,
+ name: str,
+ objects: Sequence[Any],
+ type_: TypeEngine[Any],
+ ) -> None:
+ ...
+
+
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0
@@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2
RM_TYPE: Literal[3] = 3
+class _BaseCompilerStackEntry(TypedDict):
+ asfrom_froms: Set[FromClause]
+ correlate_froms: Set[FromClause]
+ selectable: ReturnsRows
+
+
+class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
+ compile_state: CompileState
+ need_result_map_for_nested: bool
+ need_result_map_for_compound: bool
+ select_0: ReturnsRows
+ insert_from_select: Select
+
+
class ExpandedState(NamedTuple):
statement: str
additional_parameters: _CoreSingleExecuteParams
@@ -427,21 +480,23 @@ class Compiled:
defaults.
"""
- _cached_metadata = None
+ _cached_metadata: Optional[CursorResultMetaData] = None
_result_columns: Optional[List[ResultColumnsEntry]] = None
- schema_translate_map = None
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None
- execution_options = util.EMPTY_DICT
+ execution_options: _ExecuteOptions = util.EMPTY_DICT
"""
Execution options propagated from the statement. In some cases,
sub-elements of the statement can modify these.
"""
- _annotations = util.EMPTY_DICT
+ preparer: IdentifierPreparer
+
+ _annotations: _AnnotationDict = util.EMPTY_DICT
- compile_state = None
+ compile_state: Optional[CompileState] = None
"""Optional :class:`.CompileState` object that maintains additional
state used by the compiler.
@@ -457,9 +512,21 @@ class Compiled:
"""
- cache_key = None
+ cache_key: Optional[CacheKey] = None
+ """The :class:`.CacheKey` that was generated ahead of creating this
+ :class:`.Compiled` object.
+
+ This is used for routines that need access to the original
+ :class:`.CacheKey` instance generated when the :class:`.Compiled`
+ instance was first cached, typically in order to reconcile
+ the original list of :class:`.BindParameter` objects with a
+ per-statement list that's generated on each call.
+
+ """
_gen_time: float
+ """Generation time of this :class:`.Compiled`, used for reporting
+ cache stats."""
def __init__(
self,
@@ -543,7 +610,11 @@ class Compiled:
return self.string or ""
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
"""Return the bind params for this compiled object.
:param params: a dict of string/object pairs whose values will
@@ -646,6 +717,17 @@ class SQLCompiler(Compiled):
isplaintext: bool = False
+ binds: Dict[str, BindParameter[Any]]
+ """a dictionary of bind parameter keys to BindParameter instances."""
+
+ bind_names: Dict[BindParameter[Any], str]
+ """a dictionary of BindParameter instances to "compiled" names
+ that are actually present in the generated SQL"""
+
+ stack: List[_CompilerStackEntry]
+ """major statements such as SELECT, INSERT, UPDATE, DELETE are
+ tracked in this stack using an entry format."""
+
result_columns: List[ResultColumnsEntry]
"""relates label names in the final SQL to a tuple of local
column/label name, ColumnElement object (if any) and
@@ -709,7 +791,7 @@ class SQLCompiler(Compiled):
"""
- insert_single_values_expr = None
+ insert_single_values_expr: Optional[str] = None
"""When an INSERT is compiled with a single set of parameters inside
a VALUES expression, the string is assigned here, where it can be
used for insert batching schemes to rewrite the VALUES expression.
@@ -718,19 +800,19 @@ class SQLCompiler(Compiled):
"""
- literal_execute_params = frozenset()
+ literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as literal values at statement
execution time.
"""
- post_compile_params = frozenset()
+ post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as bound parameter placeholders
at statement execution time.
"""
- escaped_bind_names = util.EMPTY_DICT
+ escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
"""Late escaping of bound parameter names that has to be converted
to the original name when looking in the parameter dictionary.
@@ -744,14 +826,25 @@ class SQLCompiler(Compiled):
"""if True, and this in insert, use cursor.lastrowid to populate
result.inserted_primary_key. """
- _cache_key_bind_match = None
+ _cache_key_bind_match: Optional[
+ Tuple[
+ Dict[
+ BindParameter[Any],
+ List[BindParameter[Any]],
+ ],
+ Dict[
+ str,
+ BindParameter[Any],
+ ],
+ ]
+ ] = None
"""a mapping that will relate the BindParameter object we compile
to those that are part of the extracted collection of parameters
in the cache key, if we were given a cache key.
"""
- positiontup: Optional[Sequence[str]] = None
+ positiontup: Optional[List[str]] = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
@@ -768,6 +861,19 @@ class SQLCompiler(Compiled):
inline: bool = False
+ ctes: Optional[MutableMapping[CTE, str]]
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ ctes_by_level_name: Dict[Tuple[int, str], CTE]
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name, cte_opts)]
+ level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
+
+ ctes_recursive: bool
+ cte_positional: Dict[CTE, List[str]]
+
def __init__(
self,
dialect,
@@ -804,10 +910,9 @@ class SQLCompiler(Compiled):
self.cache_key = cache_key
if cache_key:
- self._cache_key_bind_match = ckbm = {
- b.key: b for b in cache_key[1]
- }
- ckbm.update({b: [b] for b in cache_key[1]})
+ cksm = {b.key: b for b in cache_key[1]}
+ ckbm = {b: [b] for b in cache_key[1]}
+ self._cache_key_bind_match = (ckbm, cksm)
# compile INSERT/UPDATE defaults/sequences to expect executemany
# style execution, which may mean no pre-execute of defaults,
@@ -911,14 +1016,14 @@ class SQLCompiler(Compiled):
@property
def prefetch(self):
- return list(self.insert_prefetch + self.update_prefetch)
+ return list(self.insert_prefetch) + list(self.update_prefetch)
@util.memoized_property
def _global_attributes(self):
return {}
@util.memoized_instancemethod
- def _init_cte_state(self) -> None:
+ def _init_cte_state(self) -> MutableMapping[CTE, str]:
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
@@ -926,21 +1031,22 @@ class SQLCompiler(Compiled):
"""
# collect CTEs to tack on top of a SELECT
# To store the query to print - Dict[cte, text_query]
- self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ self.ctes = ctes
# Detect same CTE references - Dict[(level, name), cte]
# Level is required for supporting nesting
- self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
+ self.ctes_by_level_name = {}
# To retrieve key/level in ctes_by_level_name -
# Dict[cte_reference, (level, cte_name, cte_opts)]
- self.level_name_by_cte: Dict[
- CTE, Tuple[int, str, selectable._CTEOpts]
- ] = {}
+ self.level_name_by_cte = {}
- self.ctes_recursive: bool = False
+ self.ctes_recursive = False
if self.positional:
- self.cte_positional: Dict[CTE, List[str]] = {}
+ self.cte_positional = {}
+
+ return ctes
@contextlib.contextmanager
def _nested_result(self):
@@ -985,7 +1091,7 @@ class SQLCompiler(Compiled):
if not bindparam.type._is_tuple_type
else tuple(
elem_type._cached_bind_processor(self.dialect)
- for elem_type in bindparam.type.types
+ for elem_type in cast(TupleType, bindparam.type).types
),
)
for bindparam in self.bind_names
@@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled):
def construct_params(
self,
- params=None,
- _group_number=None,
- _check=True,
- extracted_parameters=None,
- ):
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ _group_number: Optional[int] = None,
+ _check: bool = True,
+ ) -> _MutableCoreSingleExecuteParams:
"""return a dictionary of bind parameter keys and values"""
has_escaped_names = bool(self.escaped_bind_names)
@@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled):
# way. The parameters present in self.bind_names may be clones of
# these original cache key params in the case of DML but the .key
# will be guaranteed to match.
- try:
- orig_extracted = self.cache_key[1]
- except TypeError as err:
+ if self.cache_key is None:
raise exc.CompileError(
"This compiled object has no original cache key; "
"can't pass extracted_parameters to construct_params"
- ) from err
+ )
+ else:
+ orig_extracted = self.cache_key[1]
- ckbm = self._cache_key_bind_match
+ ckbm_tuple = self._cache_key_bind_match
+ assert ckbm_tuple is not None
+ ckbm, _ = ckbm_tuple
resolved_extracted = {
bind: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
@@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled):
if bindparam.type._is_tuple_type:
inputsizes[bindparam] = [
- lookup_type(typ) for typ in bindparam.type.types
+ lookup_type(typ)
+ for typ in cast(TupleType, bindparam.type).types
]
else:
inputsizes[bindparam] = lookup_type(bindparam.type)
@@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled):
def _process_parameters_for_postcompile(
self,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_MutableCoreSingleExecuteParams] = None,
_populate_self: bool = False,
) -> ExpandedState:
"""handle special post compile parameters.
@@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled):
parameters = self.construct_params()
expanded_parameters = {}
+ positiontup: Optional[List[str]]
+
if self.positional:
positiontup = []
else:
positiontup = None
processors = self._bind_processors
+ single_processors = cast("Mapping[str, _ProcessorType]", processors)
+ tuple_processors = cast(
+ "Mapping[str, Sequence[_ProcessorType]]", processors
+ )
- new_processors = {}
+ new_processors: Dict[str, _ProcessorType] = {}
if self.positional and self._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled):
"the 'numeric' paramstyle at this time."
)
- replacement_expressions = {}
- to_update_sets = {}
+ replacement_expressions: Dict[str, Any] = {}
+ to_update_sets: Dict[str, Any] = {}
# notes:
# *unescaped* parameter names in:
@@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled):
# *escaped* parameter names in:
# construct_params(), replacement_expressions
- for name in (
- self.positiontup if self.positional else self.bind_names.values()
- ):
+ if self.positional and self.positiontup is not None:
+ names: Iterable[str] = self.positiontup
+ else:
+ names = self.bind_names.values()
+
+ for name in names:
escaped_name = (
self.escaped_bind_names.get(name, name)
if self.escaped_bind_names
@@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled):
if parameter in self.post_compile_params:
if escaped_name in replacement_expressions:
to_update = to_update_sets[escaped_name]
+ values = None
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
@@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled):
if not parameter.literal_execute:
parameters.update(to_update)
if parameter.type._is_tuple_type:
+ assert values is not None
new_processors.update(
(
"%s_%s_%s" % (name, i, j),
- processors[name][j - 1],
+ tuple_processors[name][j - 1],
)
for i, tuple_element in enumerate(values, 1)
- for j, value in enumerate(tuple_element, 1)
- if name in processors
- and processors[name][j - 1] is not None
+ for j, _ in enumerate(tuple_element, 1)
+ if name in tuple_processors
+ and tuple_processors[name][j - 1] is not None
)
else:
new_processors.update(
- (key, processors[name])
- for key, value in to_update
- if name in processors
+ (key, single_processors[name])
+ for key, _ in to_update
+ if name in single_processors
)
- if self.positional:
- positiontup.extend(name for name, value in to_update)
+ if positiontup is not None:
+ positiontup.extend(name for name, _ in to_update)
expanded_parameters[name] = [
- expand_key for expand_key, value in to_update
+ expand_key for expand_key, _ in to_update
]
- elif self.positional:
+ elif positiontup is not None:
positiontup.append(name)
def process_expanding(m):
@@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled):
# special use cases.
self.string = expanded_state.statement
self._bind_processors.update(expanded_state.processors)
- self.positiontup = expanded_state.positiontup
+ self.positiontup = list(expanded_state.positiontup or ())
self.post_compile_params = frozenset()
for key in expanded_state.parameter_expansion:
bind = self.binds.pop(key)
@@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled):
self._result_columns
)
+ _key_getters_for_crud_column: Tuple[
+ Callable[[Union[str, Column[Any]]], str],
+ Callable[[Column[Any]], str],
+ Callable[[Column[Any]], str],
+ ]
+
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
getter = self._key_getters_for_crud_column[2]
@@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled):
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_returning_getter(self):
- result = util.preloaded.engine_result
+ if typing.TYPE_CHECKING:
+ from ..engine import result
+ else:
+ result = util.preloaded.engine_result
param_key_getter = self._within_exec_param_key_getter
table = self.statement.table
- ret = {col: idx for idx, col in enumerate(self.returning)}
+ returning = self.returning
+ assert returning is not None
+ ret = {col: idx for idx, col in enumerate(returning)}
- getters = [
- (operator.itemgetter(ret[col]), True)
- if col in ret
- else (
- operator.methodcaller("get", param_key_getter(col), None),
- False,
- )
- for col in table.primary_key
- ]
+ getters = cast(
+ "List[Tuple[Callable[[Any], Any], bool]]",
+ [
+ (operator.itemgetter(ret[col]), True)
+ if col in ret
+ else (
+ operator.methodcaller("get", param_key_getter(col), None),
+ False,
+ )
+ for col in table.primary_key
+ ],
+ )
row_fn = result.result_tuple([col.key for col in table.primary_key])
@@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled):
self, element, within_columns_clause=False, **kwargs
):
if self.stack and self.dialect.supports_simple_order_by_label:
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ raise exc.CompileError(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ) from ke
(
with_cols,
@@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled):
# compiling the element outside of the context of a SELECT
return self.process(element._text_clause)
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ coercions._no_text_coercion(
+ element.element,
+ extra=(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ),
+ exc_cls=exc.CompileError,
+ err=ke,
+ )
+
with_cols, only_froms, only_cols = compile_state._label_resolve_dict
try:
if within_columns_clause:
@@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled):
def visit_column(
self,
- column,
- add_to_result_map=None,
- include_table=True,
- result_map_targets=(),
- ambiguous_table_name_map=None,
- **kwargs,
- ):
+ column: ColumnClause[Any],
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ include_table: bool = True,
+ result_map_targets: Tuple[Any, ...] = (),
+ ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+ **kwargs: Any,
+ ) -> str:
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled):
)
else:
schema_prefix = ""
- tablename = table.name
+
+ tablename = cast("NamedFromClause", table).name
if (
not effective_schema
@@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"correlate_froms": set(),
"asfrom_froms": set(),
"selectable": taf,
@@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled):
compiled_col = self.visit_column(element, **kw)
return "(%s).%s" % (compiled_fn, compiled_col)
- def visit_function(self, func, add_to_result_map=None, **kwargs):
+ def visit_function(
+ self,
+ func: Function,
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ **kwargs: Any,
+ ) -> str:
if add_to_result_map is not None:
add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+
+ text: str
+
if disp:
text = disp(func, **kwargs)
else:
@@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled):
if compound_stmt._independent_ctes:
self._dispatch_independent_ctes(compound_stmt, kwargs)
- keyword = self.compound_keywords.get(cs.keyword)
+ keyword = self.compound_keywords[cs.keyword]
text = (" " + keyword + " ").join(
(
@@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled):
# a different set of parameter values. here, we accommodate for
# parameters that may have been cloned both before and after the cache
# key was been generated.
- ckbm = self._cache_key_bind_match
- if ckbm:
+ ckbm_tuple = self._cache_key_bind_match
+
+ if ckbm_tuple:
+ ckbm, cksm = ckbm_tuple
for bp in bindparam._cloned_set:
- if bp.key in ckbm:
- cb = ckbm[bp.key]
+ if bp.key in cksm:
+ cb = cksm[bp.key]
ckbm[cb].append(bindparam)
if bindparam.isoutparam:
@@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled):
if positional_names is not None:
positional_names.append(name)
else:
- self.positiontup.append(name)
+ self.positiontup.append(name) # type: ignore[union-attr]
elif not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
@@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled):
name = new_name
if escaped_from:
- if not self.escaped_bind_names:
- self.escaped_bind_names = {}
- self.escaped_bind_names[escaped_from] = name
+ self.escaped_bind_names = self.escaped_bind_names.union(
+ {escaped_from: name}
+ )
if post_compile:
return "__[POSTCOMPILE_%s]" % name
@@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled):
cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
**kwargs: Any,
) -> Optional[str]:
- self._init_cte_state()
+ self_ctes = self._init_cte_state()
+ assert self_ctes is self.ctes
kwargs["visiting_cte"] = cte
@@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled):
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
- del self.ctes[existing_cte]
+ del self_ctes[existing_cte]
existing_cte_reference_cte = existing_cte._get_reference_cte()
@@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled):
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
- if not cte_pre_alias_name and cte not in self.ctes:
+ if not cte_pre_alias_name and cte not in self_ctes:
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
@@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled):
cte, cte._suffixes, **kwargs
)
- self.ctes[cte] = text
+ self_ctes[cte] = text
if asfrom:
if from_linter:
from_linter.froms[cte] = cte_name
if not is_new_cte and embedded_in_current_named_cte:
- return self.preparer.format_alias(cte, cte_name)
+ return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501
if cte_pre_alias_name:
text = self.preparer.format_alias(cte, cte_pre_alias_name)
@@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
+ return None
+
def visit_table_valued_alias(self, element, **kw):
if element._is_lateral:
return self.visit_lateral(element, **kw)
@@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled):
self,
keyname: str,
name: str,
- objects: List[Any],
+ objects: Tuple[Any, ...],
type_: TypeEngine[Any],
) -> None:
if keyname is None or keyname == "*":
@@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled):
def get_statement_hint_text(self, hint_texts):
return " ".join(hint_texts)
- _default_stack_entry = util.immutabledict(
- [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
- )
+ _default_stack_entry: _CompilerStackEntry
+
+ if not typing.TYPE_CHECKING:
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(
self, select_stmt, asfrom, lateral=False, **kw
@@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled):
)
return froms
- translate_select_structure = None
+ translate_select_structure: Any = None
"""if not ``None``, should be a callable which accepts ``(select_stmt,
**kw)`` and returns a select object. this is used for structural changes
mostly to accommodate for LIMIT/OFFSET schemes
@@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled):
)
self._result_columns = [
- (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ ResultColumnsEntry(
+ key, name, tuple(translate.get(o, o) for o in obj), type_
+ )
for key, name, obj, type_ in self._result_columns
]
@@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled):
implicit_correlate_froms=asfrom_froms,
)
- new_correlate_froms = set(selectable._from_objects(*froms))
+ new_correlate_froms = set(_from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"asfrom_froms": new_correlate_froms,
"correlate_froms": all_correlate_froms,
"selectable": select,
@@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled):
text += " \nWHERE " + t
if warn_linting:
+ assert from_linter is not None
from_linter.warn()
if select._group_by_clauses:
@@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled):
if not self.ctes:
return ""
+ ctes: MutableMapping[CTE, str]
+
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
@@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled):
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
+ assert self.positiontup is not None
self.positiontup = (
- sum([self.cte_positional[cte] for cte in ctes], [])
+ list(
+ itertools.chain.from_iterable(
+ self.cte_positional[cte] for cte in ctes
+ )
+ )
+ self.positiontup
)
+
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
@@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled):
if is_multitable:
# main table might be a JOIN
- main_froms = set(selectable._from_objects(update_stmt.table))
+ main_froms = set(_from_objects(update_stmt.table))
render_extra_froms = [
f for f in extra_froms if f not in main_froms
]
@@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled):
def type_compiler(self):
return self.dialect.type_compiler
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
return None
def visit_ddl(self, ddl, **kwargs):
@@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler):
return get_col_spec(**kw)
+class _SchemaForObjectCallable(Protocol):
+ def __call__(self, obj: Any) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
@@ -5209,7 +5404,13 @@ class IdentifierPreparer:
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- schema_for_object = operator.attrgetter("schema")
+ initial_quote: str
+
+ final_quote: str
+
+ _strings: MutableMapping[str, str]
+
+ schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
"""Return the .schema attribute for an object.
For the default IdentifierPreparer, the schema for an object is always
@@ -5297,7 +5498,7 @@ class IdentifierPreparer:
return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
- def _escape_identifier(self, value):
+ def _escape_identifier(self, value: str) -> str:
"""Escape an identifier.
Subclasses should override this to provide database-dependent
@@ -5309,7 +5510,7 @@ class IdentifierPreparer:
value = value.replace("%", "%%")
return value
- def _unescape_identifier(self, value):
+ def _unescape_identifier(self, value: str) -> str:
"""Canonicalize an escaped identifier.
Subclasses should override this to provide database-dependent
@@ -5336,7 +5537,7 @@ class IdentifierPreparer:
)
return element
- def quote_identifier(self, value):
+ def quote_identifier(self, value: str) -> str:
"""Quote an identifier.
Subclasses should override this to provide database-dependent
@@ -5349,7 +5550,7 @@ class IdentifierPreparer:
+ self.final_quote
)
- def _requires_quotes(self, value):
+ def _requires_quotes(self, value: str) -> bool:
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (
@@ -5364,7 +5565,7 @@ class IdentifierPreparer:
not taking case convention into account."""
return not self.legal_characters.match(str(value))
- def quote_schema(self, schema, force=None):
+ def quote_schema(self, schema: str, force: Any = None) -> str:
"""Conditionally quote a schema name.
@@ -5403,7 +5604,7 @@ class IdentifierPreparer:
return self.quote(schema)
- def quote(self, ident, force=None):
+ def quote(self, ident: str, force: Any = None) -> str:
"""Conditionally quote an identifier.
The identifier is quoted if it is a reserved word, contains
@@ -5474,11 +5675,19 @@ class IdentifierPreparer:
name = self.quote_schema(effective_schema) + "." + name
return name
- def format_label(self, label, name=None):
+ def format_label(
+ self, label: Label[Any], name: Optional[str] = None
+ ) -> str:
return self.quote(name or label.name)
- def format_alias(self, alias, name=None):
- return self.quote(name or alias.name)
+ def format_alias(
+ self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
+ ) -> str:
+ if name is None:
+ assert alias is not None
+ return self.quote(alias.name)
+ else:
+ return self.quote(name)
def format_savepoint(self, savepoint, name=None):
# Running the savepoint name through quoting is unnecessary
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 5aded307b..96e90b0ea 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -13,6 +13,10 @@ from __future__ import annotations
import collections.abc as collections_abc
import typing
+from typing import Any
+from typing import List
+from typing import MutableMapping
+from typing import Optional
from . import coercions
from . import roles
@@ -40,8 +44,8 @@ from .. import util
class DMLState(CompileState):
_no_parameters = True
- _dict_parameters = None
- _multi_parameters = None
+ _dict_parameters: Optional[MutableMapping[str, Any]] = None
+ _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None
_ordered_values = None
_parameter_ordering = None
_has_multi_parameters = False
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 168da17cc..08d632afd 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -18,7 +18,9 @@ import re
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import Generic
+from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
@@ -47,6 +49,7 @@ from .coercions import _document_text_coercion # noqa
from .operators import ColumnOperators
from .traversals import HasCopyInternals
from .visitors import cloned_traverse
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .visitors import traverse
from .visitors import Visitable
@@ -68,6 +71,8 @@ if typing.TYPE_CHECKING:
from ..engine import Connection
from ..engine import Dialect
from ..engine import Engine
+ from ..engine.base import _CompiledCacheType
+ from ..engine.base import _SchemaTranslateMapType
_NUMERIC = Union[complex, "Decimal"]
@@ -238,6 +243,7 @@ class ClauseElement(
SupportsWrappingAnnotations,
MemoizedHasCacheKey,
HasCopyInternals,
+ ExternallyTraversible,
CompilerElement,
):
"""Base class for elements of a programmatically constructed SQL
@@ -398,7 +404,9 @@ class ClauseElement(
"""
return self._replace_params(True, optionaldict, kwargs)
- def params(self, *optionaldict, **kwargs):
+ def params(
+ self, *optionaldict: Dict[str, Any], **kwargs: Any
+ ) -> ClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
replaced.
@@ -415,7 +423,12 @@ class ClauseElement(
"""
return self._replace_params(False, optionaldict, kwargs)
- def _replace_params(self, unique, optionaldict, kwargs):
+ def _replace_params(
+ self,
+ unique: bool,
+ optionaldict: Optional[Dict[str, Any]],
+ kwargs: Dict[str, Any],
+ ) -> ClauseElement:
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
@@ -487,12 +500,12 @@ class ClauseElement(
def _compile_w_cache(
self,
- dialect,
- compiled_cache=None,
- column_keys=None,
- for_executemany=False,
- schema_translate_map=None,
- **kw,
+ dialect: Dialect,
+ compiled_cache: Optional[_CompiledCacheType] = None,
+ column_keys: Optional[List[str]] = None,
+ for_executemany: bool = False,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ **kw: Any,
):
if compiled_cache is not None and dialect._supports_statement_cache:
elem_cache_key = self._generate_cache_key()
@@ -1383,7 +1396,7 @@ class ColumnElement(
"""
return Cast(self, type_)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[_T]:
"""Produce a column label, i.e. ``<columnname> AS <name>``.
This is a shortcut to the :func:`_expression.label` function.
@@ -1608,6 +1621,9 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
("value", InternalTraversal.dp_plain_obj),
]
+ key: str
+ type: TypeEngine
+
_is_crud = False
_is_bind_parameter = True
_key_is_anon = False
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index eb3d17ee4..6e5eec127 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,6 +12,7 @@
from __future__ import annotations
from typing import Any
+from typing import Sequence
from typing import TypeVar
from . import annotation
@@ -839,6 +840,8 @@ class Function(FunctionElement):
identifier: str
+ packagenames: Sequence[str]
+
type: TypeEngine = sqltypes.NULLTYPE
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 64bd4b951..1a7a5f4d4 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -7,14 +7,22 @@
from __future__ import annotations
import typing
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
-from sqlalchemy.util.langhelpers import TypingOnly
from .. import util
-
+from ..util import TypingOnly
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
+ from .base import ColumnCollection
from .elements import ClauseElement
+ from .elements import Label
from .selectable import FromClause
+ from .selectable import Subquery
class SQLRole:
@@ -35,7 +43,7 @@ class SQLRole:
class UsesInspection:
__slots__ = ()
- _post_inspect = None
+ _post_inspect: Literal[None] = None
uses_inspection = True
@@ -96,7 +104,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
_role_name = "Column expression or FROM clause"
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> Sequence[ColumnsClauseRole]:
raise NotImplementedError()
@@ -150,6 +158,9 @@ class ExpressionElementRole(SQLRole):
__slots__ = ()
_role_name = "SQL expression element"
+ def label(self, name: Optional[str]) -> Label[Any]:
+ raise NotImplementedError()
+
class ConstExprRole(ExpressionElementRole):
__slots__ = ()
@@ -187,7 +198,7 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
_is_subquery = False
@property
- def _hide_froms(self):
+ def _hide_froms(self) -> Iterable[FromClause]:
raise NotImplementedError()
@@ -195,8 +206,10 @@ class StrictFromClauseRole(FromClauseRole):
__slots__ = ()
# does not allow text() or select() objects
+ c: ColumnCollection
+
@property
- def description(self):
+ def description(self) -> str:
raise NotImplementedError()
@@ -204,7 +217,9 @@ class AnonymizedFromClauseRole(StrictFromClauseRole):
__slots__ = ()
# calls .alias() as a post processor
- def _anonymous_fromclause(self, name=None, flat=False):
+ def _anonymous_fromclause(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> FromClause:
raise NotImplementedError()
@@ -220,14 +235,14 @@ class StatementRole(SQLRole):
__slots__ = ()
_role_name = "Executable SQL or text() construct"
- _propagate_attrs = util.immutabledict()
+ _propagate_attrs: Mapping[str, Any] = util.immutabledict()
class SelectStatementRole(StatementRole, ReturnsRowsRole):
__slots__ = ()
_role_name = "SELECT construct or equivalent text() construct"
- def subquery(self):
+ def subquery(self) -> Subquery:
raise NotImplementedError(
"All SelectStatementRole objects should implement a "
".subquery() method."
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index c270e1564..33e300bf6 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -51,7 +51,7 @@ from . import visitors
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
-from .base import SchemaEventTarget
+from .base import SchemaEventTarget as SchemaEventTarget
from .coercions import _document_text_coercion
from .elements import ClauseElement
from .elements import ColumnClause
@@ -2676,6 +2676,10 @@ class DefaultGenerator(Executable, SchemaItem):
def __init__(self, for_update=False):
self.for_update = for_update
+ @util.memoized_property
+ def is_callable(self):
+ raise NotImplementedError()
+
def _set_parent(self, column, **kw):
self.column = column
if self.for_update:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index e5c2bef68..09befb078 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -53,7 +53,6 @@ from .base import Generative
from .base import HasCompileState
from .base import HasMemoized
from .base import Immutable
-from .base import prefix_anon_map
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import BindParameter
@@ -69,10 +68,10 @@ from .elements import literal_column
from .elements import TableValuedColumn
from .elements import UnaryExpression
from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
-
and_ = BooleanClauseList.and_
_T = TypeVar("_T", bound=Any)
@@ -855,6 +854,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.alias(name=name)
+class NamedFromClause(FromClause):
+ named_with_column = True
+
+ name: str
+
+
class SelectLabelStyle(Enum):
"""Label style constants that may be passed to
:meth:`_sql.Select.set_label_style`."""
@@ -1317,15 +1322,16 @@ class NoInit:
# -> Lateral -> FromClause, but we accept SelectBase
# w/ non-deprecated coercion
# -> TableSample -> only for FromClause
-class AliasedReturnsRows(NoInit, FromClause):
+class AliasedReturnsRows(NoInit, NamedFromClause):
"""Base class of aliases against tables, subqueries, and other
selectables."""
_is_from_container = True
- named_with_column = True
_supports_derived_columns = False
+ element: ClauseElement
+
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
@@ -1423,6 +1429,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
inherit_cache = True
+ element: FromClause
+
@classmethod
def _factory(cls, selectable, name=None, flat=False):
return coercions.expect(
@@ -1689,6 +1697,8 @@ class CTE(
+ HasSuffixes._has_suffixes_traverse_internals
)
+ element: HasCTE
+
@classmethod
def _factory(cls, selectable, name=None, recursive=False):
r"""Return a new :class:`_expression.CTE`,
@@ -1819,7 +1829,7 @@ class _CTEOpts(NamedTuple):
nesting: bool
-class HasCTE(roles.HasCTERole):
+class HasCTE(roles.HasCTERole, ClauseElement):
"""Mixin that declares a class to include CTE support.
.. versionadded:: 1.1
@@ -2247,6 +2257,8 @@ class Subquery(AliasedReturnsRows):
inherit_cache = True
+ element: Select
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`.Subquery` object."""
@@ -2331,7 +2343,7 @@ class FromGrouping(GroupedElement, FromClause):
self.element = state["element"]
-class TableClause(roles.DMLTableRole, Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
"""Represents a minimal "table" construct.
This is a lightweight table object that has only a name, a
@@ -2371,8 +2383,6 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause):
("name", InternalTraversal.dp_string),
]
- named_with_column = True
-
_is_table = True
implicit_returning = False
@@ -2542,7 +2552,7 @@ class ForUpdateArg(ClauseElement):
SelfValues = typing.TypeVar("SelfValues", bound="Values")
-class Values(Generative, FromClause):
+class Values(Generative, NamedFromClause):
"""Represent a ``VALUES`` construct that can be used as a FROM element
in a statement.
@@ -2553,7 +2563,6 @@ class Values(Generative, FromClause):
"""
- named_with_column = True
__visit_name__ = "values"
_data = ()
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 7d21f1262..b2b1d9bc2 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -35,13 +35,13 @@ from .elements import _NONE_NAME
from .elements import quoted_name
from .elements import Slice
from .elements import TypeCoerce as type_coerce # noqa
-from .traversals import InternalTraversal
from .type_api import Emulated
from .type_api import NativeForEmulated # noqa
from .type_api import to_instance
from .type_api import TypeDecorator
from .type_api import TypeEngine
from .type_api import Variant # noqa
+from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import inspection
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 4fa23d370..cf9487f93 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -15,7 +15,10 @@ import operator
import typing
from typing import Any
from typing import Callable
+from typing import Deque
from typing import Dict
+from typing import Set
+from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -23,9 +26,9 @@ from . import operators
from .cache_key import HasCacheKey
from .visitors import _TraverseInternalsType
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
+from .visitors import HasTraversalDispatch
from .visitors import HasTraverseInternals
-from .visitors import InternalTraversal
from .. import util
from ..util import langhelpers
@@ -35,6 +38,7 @@ COMPARE_SUCCEEDED = True
def compare(obj1, obj2, **kw):
+ strategy: TraversalComparatorStrategy
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
else:
@@ -45,16 +49,18 @@ def compare(obj1, obj2, **kw):
def _preconfigure_traversals(target_hierarchy):
for cls in util.walk_subclasses(target_hierarchy):
- if hasattr(cls, "_traverse_internals"):
- cls._generate_cache_attrs()
+ if hasattr(cls, "_generate_cache_attrs") and hasattr(
+ cls, "_traverse_internals"
+ ):
+ cls._generate_cache_attrs() # type: ignore
_copy_internals.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_copy_internals_traversal",
)
_get_children.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_get_children_traversal",
)
@@ -125,54 +131,58 @@ class HasShallowCopy(HasTraverseInternals):
meth_text = f"def {method_name}(self, d):\n{code}\n"
return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- def _shallow_from_dict(self, d: Dict) -> None:
+ def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
cls = self.__class__
+ shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
try:
shallow_from_dict = cls.__dict__[
"_generated_shallow_from_dict_traversal"
]
except KeyError:
- shallow_from_dict = (
- cls._generated_shallow_from_dict_traversal # type: ignore
- ) = self._generate_shallow_from_dict(
+ shallow_from_dict = self._generate_shallow_from_dict(
cls._traverse_internals,
"_generated_shallow_from_dict_traversal",
)
+ cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa E501
+
shallow_from_dict(self, d)
def _shallow_to_dict(self) -> Dict[str, Any]:
cls = self.__class__
+ shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]
+
try:
shallow_to_dict = cls.__dict__[
"_generated_shallow_to_dict_traversal"
]
except KeyError:
- shallow_to_dict = (
- cls._generated_shallow_to_dict_traversal # type: ignore
- ) = self._generate_shallow_to_dict(
+ shallow_to_dict = self._generate_shallow_to_dict(
cls._traverse_internals, "_generated_shallow_to_dict_traversal"
)
+ cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa E501
return shallow_to_dict(self)
- def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy):
+ def _shallow_copy_to(
+ self: SelfHasShallowCopy, other: SelfHasShallowCopy
+ ) -> None:
cls = self.__class__
+ shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]
try:
shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
except KeyError:
- shallow_copy = (
- cls._generated_shallow_copy_traversal # type: ignore
- ) = self._generate_shallow_copy(
+ shallow_copy = self._generate_shallow_copy(
cls._traverse_internals, "_generated_shallow_copy_traversal"
)
+ cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501
shallow_copy(self, other)
- def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy:
+ def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy:
"""Create a shallow copy"""
c = self.__class__.__new__(self.__class__)
self._shallow_copy_to(c)
@@ -246,7 +256,7 @@ class HasCopyInternals(HasTraverseInternals):
setattr(self, attrname, result)
-class _CopyInternalsTraversal(InternalTraversal):
+class _CopyInternalsTraversal(HasTraversalDispatch):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
@@ -381,7 +391,7 @@ def _flatten_clauseelement(element):
return element
-class _GetChildrenTraversal(InternalTraversal):
+class _GetChildrenTraversal(HasTraversalDispatch):
"""Generate a _children_traversal internal traversal dispatch for classes
with a _traverse_internals collection."""
@@ -463,13 +473,13 @@ def _resolve_name_for_compare(element, name, anon_map, **kw):
return name
-class TraversalComparatorStrategy(
- ExtendedInternalTraversal, util.MemoizedSlots
-):
+class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
__slots__ = "stack", "cache", "anon_map"
def __init__(self):
- self.stack = deque()
+ self.stack: Deque[
+ Tuple[ExternallyTraversible, ExternallyTraversible]
+ ] = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
@@ -653,7 +663,7 @@ class TraversalComparatorStrategy(
if seq1 is None:
return seq2 is None
- completed = set()
+ completed: Set[object] = set()
for clause in seq1:
for other_clause in set(seq2).difference(completed):
if self.compare_inner(clause, other_clause, **kw):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index e0248adf0..5114a2431 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -21,9 +21,9 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
-from .annotation import _deep_annotate # noqa
-from .annotation import _deep_deannotate # noqa
-from .annotation import _shallow_annotate # noqa
+from .annotation import _deep_annotate as _deep_annotate
+from .annotation import _deep_deannotate as _deep_deannotate
+from .annotation import _shallow_annotate as _shallow_annotate
from .base import _expand_cloned
from .base import _from_objects
from .base import ColumnSet
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 111ecd32e..0c41e440e 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -7,43 +7,46 @@
"""Visitor/traversal interface and library functions.
-SQLAlchemy schema and expression constructs rely on a Python-centric
-version of the classic "visitor" pattern as the primary way in which
-they apply functionality. The most common use of this pattern
-is statement compilation, where individual expression classes match
-up to rendering methods that produce a string result. Beyond this,
-the visitor system is also used to inspect expressions for various
-information and patterns, as well as for the purposes of applying
-transformations to expressions.
-
-Examples of how the visit system is used can be seen in the source code
-of for example the ``sqlalchemy.sql.util`` and the ``sqlalchemy.sql.compiler``
-modules. Some background on clause adaption is also at
-https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
"""
from __future__ import annotations
from collections import deque
+from enum import Enum
import itertools
import operator
import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import ClassVar
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
+from typing import Optional
from typing import Tuple
+from typing import Type
+from typing import TypeVar
+from typing import Union
from .. import exc
from .. import util
from ..util import langhelpers
-from ..util import symbol
from ..util._has_cy import HAS_CYEXTENSION
-from ..util.langhelpers import _symbol
+from ..util.typing import Protocol
+from ..util.typing import Self
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_util import cache_anon_map as anon_map # noqa
+ from ._py_util import prefix_anon_map as prefix_anon_map
+ from ._py_util import cache_anon_map as anon_map
else:
- from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
+ from sqlalchemy.cyextension.util import prefix_anon_map as prefix_anon_map
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map
+
__all__ = [
"iterate",
@@ -54,57 +57,23 @@ __all__ = [
"Visitable",
"ExternalTraversal",
"InternalTraversal",
+ "anon_map",
]
-_TraverseInternalsType = List[Tuple[str, _symbol]]
-
-
-class HasTraverseInternals:
- """base for classes that have a "traverse internals" element,
- which defines all kinds of ways of traversing the elements of an object.
-
- """
-
- __slots__ = ()
-
- _traverse_internals: _TraverseInternalsType
-
- @util.preload_module("sqlalchemy.sql.traversals")
- def get_children(self, omit_attrs=(), **kw):
- r"""Return immediate child :class:`.visitors.Visitable`
- elements of this :class:`.visitors.Visitable`.
-
- This is used for visit traversal.
-
- \**kw may contain flags that change the collection that is
- returned, for example to return a subset of items in order to
- cut down on larger traversals, or to return child items from a
- different context (such as schema-level collections instead of
- clause-level).
-
- """
-
- traversals = util.preloaded.sql_traversals
-
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return []
- dispatch = traversals._get_children.run_generated_dispatch
- return itertools.chain.from_iterable(
- meth(obj, **kw)
- for attrname, obj, meth in dispatch(
- self, traverse_internals, "_generated_get_children_traversal"
- )
- if attrname not in omit_attrs and obj is not None
- )
+class _CompilerDispatchType(Protocol):
+ def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any:
+ ...
class Visitable:
"""Base class for visitable objects.
+ :class:`.Visitable` is used to implement the SQL compiler dispatch
+ functions. Other forms of traversal such as for cache key generation
+ are implemented separately using the :class:`.HasTraverseInternals`
+ interface.
+
.. versionchanged:: 2.0 The :class:`.Visitable` class was named
:class:`.Traversible` in the 1.4 series; the name is changed back
to :class:`.Visitable` in 2.0 which is what it was prior to 1.4.
@@ -117,32 +86,20 @@ class Visitable:
__visit_name__: str
+ _original_compiler_dispatch: _CompilerDispatchType
+
+ if typing.TYPE_CHECKING:
+
+ def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str:
+ ...
+
def __init_subclass__(cls) -> None:
if "__visit_name__" in cls.__dict__:
cls._generate_compiler_dispatch()
super().__init_subclass__()
@classmethod
- def _generate_compiler_dispatch(cls):
- """Assign dispatch attributes to various kinds of
- "visitable" classes.
-
- Attributes include:
-
- * The ``_compiler_dispatch`` method, corresponding to
- ``__visit_name__``. This is called "external traversal" because the
- caller of each visit() method is responsible for sub-traversing the
- inner elements of each object. This is appropriate for string
- compilers and other traversals that need to call upon the inner
- elements in a specific pattern.
-
- * internal traversal collections ``_children_traversal``,
- ``_cache_key_traversal``, ``_copy_internals_traversal``, generated
- from an optional ``_traverse_internals`` collection of symbols which
- comes from the :class:`.InternalTraversal` list of symbols. This is
- called "internal traversal".
-
- """
+ def _generate_compiler_dispatch(cls) -> None:
visit_name = cls.__visit_name__
if "_compiler_dispatch" in cls.__dict__:
@@ -161,7 +118,9 @@ class Visitable:
name = "visit_%s" % visit_name
getter = operator.attrgetter(name)
- def _compiler_dispatch(self, visitor, **kw):
+ def _compiler_dispatch(
+ self: Visitable, visitor: Any, **kw: Any
+ ) -> str:
"""Look for an attribute named "visit_<visit_name>" on the
visitor, and call it with the same kw params.
@@ -169,105 +128,20 @@ class Visitable:
try:
meth = getter(visitor)
except AttributeError as err:
- return visitor.visit_unsupported_compilation(self, err, **kw)
+ return visitor.visit_unsupported_compilation(self, err, **kw) # type: ignore # noqa E501
else:
- return meth(self, **kw)
+ return meth(self, **kw) # type: ignore # noqa E501
- cls._compiler_dispatch = (
+ cls._compiler_dispatch = ( # type: ignore
cls._original_compiler_dispatch
) = _compiler_dispatch
- def __class_getitem__(cls, key):
+ def __class_getitem__(cls, key: str) -> Any:
# allow generic classes in py3.9+
return cls
-class _HasTraversalDispatch:
- r"""Define infrastructure for the :class:`.InternalTraversal` class.
-
- .. versionadded:: 2.0
-
- """
-
- __slots__ = ()
-
- def __init_subclass__(cls) -> None:
- cls._generate_traversal_dispatch()
- super().__init_subclass__()
-
- def dispatch(self, visit_symbol):
- """Given a method from :class:`._HasTraversalDispatch`, return the
- corresponding method on a subclass.
-
- """
- name = self._dispatch_lookup[visit_symbol]
- return getattr(self, name, None)
-
- def run_generated_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
- ):
- try:
- dispatcher = target.__class__.__dict__[generate_dispatcher_name]
- except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
- # this block will generate any remaining dispatchers.
- dispatcher = self.generate_dispatch(
- target.__class__, internal_dispatch, generate_dispatcher_name
- )
- return dispatcher(target, self)
-
- def generate_dispatch(
- self, target_cls, internal_dispatch, generate_dispatcher_name
- ):
- dispatcher = self._generate_dispatcher(
- internal_dispatch, generate_dispatcher_name
- )
- # assert isinstance(target_cls, type)
- setattr(target_cls, generate_dispatcher_name, dispatcher)
- return dispatcher
-
- @classmethod
- def _generate_traversal_dispatch(cls):
- lookup = {}
- clsdict = cls.__dict__
- for key, sym in clsdict.items():
- if key.startswith("dp_"):
- visit_key = key.replace("dp_", "visit_")
- sym_name = sym.name
- assert sym_name not in lookup, sym_name
- lookup[sym] = lookup[sym_name] = visit_key
- if hasattr(cls, "_dispatch_lookup"):
- lookup.update(cls._dispatch_lookup)
- cls._dispatch_lookup = lookup
-
- def _generate_dispatcher(self, internal_dispatch, method_name):
- names = []
- for attrname, visit_sym in internal_dispatch:
- meth = self.dispatch(visit_sym)
- if meth:
- visit_name = ExtendedInternalTraversal._dispatch_lookup[
- visit_sym
- ]
- names.append((attrname, visit_name))
-
- code = (
- (" return [\n")
- + (
- ", \n".join(
- " (%r, self.%s, visitor.%s)"
- % (attrname, attrname, visit_name)
- for attrname, visit_name in names
- )
- )
- + ("\n ]\n")
- )
- meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
-
-
-class InternalTraversal(_HasTraversalDispatch):
+class InternalTraversal(Enum):
r"""Defines visitor symbols used for internal traversal.
The :class:`.InternalTraversal` class is used in two ways. One is that
@@ -306,18 +180,16 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- __slots__ = ()
-
- dp_has_cache_key = symbol("HC")
+ dp_has_cache_key = "HC"
"""Visit a :class:`.HasCacheKey` object."""
- dp_has_cache_key_list = symbol("HL")
+ dp_has_cache_key_list = "HL"
"""Visit a list of :class:`.HasCacheKey` objects."""
- dp_clauseelement = symbol("CE")
+ dp_clauseelement = "CE"
"""Visit a :class:`_expression.ClauseElement` object."""
- dp_fromclause_canonical_column_collection = symbol("FC")
+ dp_fromclause_canonical_column_collection = "FC"
"""Visit a :class:`_expression.FromClause` object in the context of the
``columns`` attribute.
@@ -329,30 +201,30 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_clauseelement_tuples = symbol("CTS")
+ dp_clauseelement_tuples = "CTS"
"""Visit a list of tuples which contain :class:`_expression.ClauseElement`
objects.
"""
- dp_clauseelement_list = symbol("CL")
+ dp_clauseelement_list = "CL"
"""Visit a list of :class:`_expression.ClauseElement` objects.
"""
- dp_clauseelement_tuple = symbol("CT")
+ dp_clauseelement_tuple = "CT"
"""Visit a tuple of :class:`_expression.ClauseElement` objects.
"""
- dp_executable_options = symbol("EO")
+ dp_executable_options = "EO"
- dp_with_context_options = symbol("WC")
+ dp_with_context_options = "WC"
- dp_fromclause_ordered_set = symbol("CO")
+ dp_fromclause_ordered_set = "CO"
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
- dp_string = symbol("S")
+ dp_string = "S"
"""Visit a plain string value.
Examples include table and column names, bound parameter keys, special
@@ -363,10 +235,10 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_string_list = symbol("SL")
+ dp_string_list = "SL"
"""Visit a list of strings."""
- dp_anon_name = symbol("AN")
+ dp_anon_name = "AN"
"""Visit a potentially "anonymized" string value.
The string value is considered to be significant for cache key
@@ -374,7 +246,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_boolean = symbol("B")
+ dp_boolean = "B"
"""Visit a boolean value.
The boolean value is considered to be significant for cache key
@@ -382,7 +254,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_operator = symbol("O")
+ dp_operator = "O"
"""Visit an operator.
The operator is a function from the :mod:`sqlalchemy.sql.operators`
@@ -393,7 +265,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_type = symbol("T")
+ dp_type = "T"
"""Visit a :class:`.TypeEngine` object
The type object is considered to be significant for cache key
@@ -401,7 +273,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_plain_dict = symbol("PD")
+ dp_plain_dict = "PD"
"""Visit a dictionary with string keys.
The keys of the dictionary should be strings, the values should
@@ -410,22 +282,22 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_dialect_options = symbol("DO")
+ dp_dialect_options = "DO"
"""Visit a dialect options structure."""
- dp_string_clauseelement_dict = symbol("CD")
+ dp_string_clauseelement_dict = "CD"
"""Visit a dictionary of string keys to :class:`_expression.ClauseElement`
objects.
"""
- dp_string_multi_dict = symbol("MD")
+ dp_string_multi_dict = "MD"
"""Visit a dictionary of string keys to values which may either be
plain immutable/hashable or :class:`.HasCacheKey` objects.
"""
- dp_annotations_key = symbol("AK")
+ dp_annotations_key = "AK"
"""Visit the _annotations_cache_key element.
This is a dictionary of additional information about a ClauseElement
@@ -436,7 +308,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_plain_obj = symbol("PO")
+ dp_plain_obj = "PO"
"""Visit a plain python object.
The value should be immutable and hashable, such as an integer.
@@ -444,7 +316,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_named_ddl_element = symbol("DD")
+ dp_named_ddl_element = "DD"
"""Visit a simple named DDL element.
The current object used by this method is the :class:`.Sequence`.
@@ -454,57 +326,56 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_prefix_sequence = symbol("PS")
+ dp_prefix_sequence = "PS"
"""Visit the sequence represented by :class:`_expression.HasPrefixes`
or :class:`_expression.HasSuffixes`.
"""
- dp_table_hint_list = symbol("TH")
+ dp_table_hint_list = "TH"
"""Visit the ``_hints`` collection of a :class:`_expression.Select`
object.
"""
- dp_setup_join_tuple = symbol("SJ")
+ dp_setup_join_tuple = "SJ"
- dp_memoized_select_entities = symbol("ME")
+ dp_memoized_select_entities = "ME"
- dp_statement_hint_list = symbol("SH")
+ dp_statement_hint_list = "SH"
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`
object.
"""
- dp_unknown_structure = symbol("UK")
+ dp_unknown_structure = "UK"
"""Visit an unknown structure.
"""
- dp_dml_ordered_values = symbol("DML_OV")
+ dp_dml_ordered_values = "DML_OV"
"""Visit the values() ordered tuple list of an
:class:`_expression.Update` object."""
- dp_dml_values = symbol("DML_V")
+ dp_dml_values = "DML_V"
"""Visit the values() dictionary of a :class:`.ValuesBase`
(e.g. Insert or Update) object.
"""
- dp_dml_multi_values = symbol("DML_MV")
+ dp_dml_multi_values = "DML_MV"
"""Visit the values() multi-valued list of dictionaries of an
:class:`_expression.Insert` object.
"""
- dp_propagate_attrs = symbol("PA")
+ dp_propagate_attrs = "PA"
"""Visit the propagate attrs dict. This hardcodes to the particular
elements we care about right now."""
-
-class ExtendedInternalTraversal(InternalTraversal):
- """Defines additional symbols that are useful in caching applications.
+ """Symbols that follow are additional symbols that are useful in
+ caching applications.
Traversals for :class:`_expression.ClauseElement` objects only need to use
those symbols present in :class:`.InternalTraversal`. However, for
@@ -513,9 +384,7 @@ class ExtendedInternalTraversal(InternalTraversal):
"""
- __slots__ = ()
-
- dp_ignore = symbol("IG")
+ dp_ignore = "IG"
"""Specify an object that should be ignored entirely.
This currently applies function call argument caching where some
@@ -523,29 +392,235 @@ class ExtendedInternalTraversal(InternalTraversal):
"""
- dp_inspectable = symbol("IS")
+ dp_inspectable = "IS"
"""Visit an inspectable object where the return value is a
:class:`.HasCacheKey` object."""
- dp_multi = symbol("M")
+ dp_multi = "M"
"""Visit an object that may be a :class:`.HasCacheKey` or may be a
plain hashable object."""
- dp_multi_list = symbol("MT")
+ dp_multi_list = "MT"
"""Visit a tuple containing elements that may be :class:`.HasCacheKey` or
may be a plain hashable object."""
- dp_has_cache_key_tuples = symbol("HT")
+ dp_has_cache_key_tuples = "HT"
"""Visit a list of tuples which contain :class:`.HasCacheKey`
objects.
"""
- dp_inspectable_list = symbol("IL")
+ dp_inspectable_list = "IL"
"""Visit a list of inspectable objects which upon inspection are
HasCacheKey objects."""
+_TraverseInternalsType = List[Tuple[str, InternalTraversal]]
+"""a structure that defines how a HasTraverseInternals should be
+traversed.
+
+This structure consists of a list of (attributename, internaltraversal)
+tuples, where the "attributename" refers to the name of an attribute on an
+instance of the HasTraverseInternals object, and "internaltraversal" refers
+to an :class:`.InternalTraversal` enumeration symbol defining what kind
+of data this attribute stores, which indicates to the traverser how it should
+be handled.
+
+"""
+
+
+class HasTraverseInternals:
+ """base for classes that have a "traverse internals" element,
+ which defines all kinds of ways of traversing the elements of an object.
+
+ Compared to :class:`.Visitable`, which relies upon an external visitor to
+ define how the object is travered (i.e. the :class:`.SQLCompiler`), the
+ :class:`.HasTraverseInternals` interface allows classes to define their own
+ traversal, that is, what attributes are accessed and in what order.
+
+ """
+
+ __slots__ = ()
+
+ _traverse_internals: _TraverseInternalsType
+
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(
+ self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Iterable[HasTraverseInternals]:
+ r"""Return immediate child :class:`.visitors.HasTraverseInternals`
+ elements of this :class:`.visitors.HasTraverseInternals`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
+
+class _InternalTraversalDispatchType(Protocol):
+ def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any:
+ ...
+
+
+class HasTraversalDispatch:
+ r"""Define infrastructure for classes that perform internal traversals
+
+ .. versionadded:: 2.0
+
+ """
+
+ __slots__ = ()
+
+ _dispatch_lookup: ClassVar[Dict[Union[InternalTraversal, str], str]] = {}
+
+ def dispatch(self, visit_symbol: InternalTraversal) -> Callable[..., Any]:
+ """Given a method from :class:`.HasTraversalDispatch`, return the
+ corresponding method on a subclass.
+
+ """
+ name = _dispatch_lookup[visit_symbol]
+ return getattr(self, name, None) # type: ignore
+
+ def run_generated_dispatch(
+ self,
+ target: object,
+ internal_dispatch: _TraverseInternalsType,
+ generate_dispatcher_name: str,
+ ) -> Any:
+ dispatcher: _InternalTraversalDispatchType
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ # traversals.py -> _preconfigure_traversals()
+ # may be used to run these ahead of time, but
+ # is not enabled right now.
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.generate_dispatch(
+ target.__class__, internal_dispatch, generate_dispatcher_name
+ )
+ return dispatcher(target, self)
+
+ def generate_dispatch(
+ self,
+ target_cls: Type[object],
+ internal_dispatch: _TraverseInternalsType,
+ generate_dispatcher_name: str,
+ ) -> _InternalTraversalDispatchType:
+ dispatcher = self._generate_dispatcher(
+ internal_dispatch, generate_dispatcher_name
+ )
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
+ return dispatcher
+
+ def _generate_dispatcher(
+ self, internal_dispatch: _TraverseInternalsType, method_name: str
+ ) -> _InternalTraversalDispatchType:
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = self.dispatch(visit_sym)
+ if meth:
+ visit_name = _dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ return cast(
+ _InternalTraversalDispatchType,
+ langhelpers._exec_code_in_env(meth_text, {}, method_name),
+ )
+
+
+ExtendedInternalTraversal = InternalTraversal
+
+
+def _generate_traversal_dispatch() -> None:
+ lookup = _dispatch_lookup
+
+ for sym in InternalTraversal:
+ key = sym.name
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.value
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+
+
+_dispatch_lookup = HasTraversalDispatch._dispatch_lookup
+_generate_traversal_dispatch()
+
+
+class ExternallyTraversible(HasTraverseInternals, Visitable):
+ __slots__ = ()
+
+ _annotations: Collection[Any] = ()
+
+ if typing.TYPE_CHECKING:
+
+ def get_children(
+ self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Iterable[ExternallyTraversible]:
+ ...
+
+ def _clone(self: Self, **kw: Any) -> Self:
+ """clone this element"""
+ raise NotImplementedError()
+
+ def _copy_internals(
+ self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Self:
+ """Reassign internal elements to be clones of themselves.
+
+ Called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy.
+
+ The given clone function should be used, which may be applying
+ additional transformations to the element (i.e. replacement
+ traversal, cloned traversal, annotations).
+
+ """
+ raise NotImplementedError()
+
+
+_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_TraverseCallableType = Callable[[_ET], None]
+_TraverseTransformCallableType = Callable[
+ [ExternallyTraversible], Optional[ExternallyTraversible]
+]
+
+
class ExternalTraversal:
"""Base class for visitor objects which can traverse externally using
the :func:`.visitors.traverse` function.
@@ -555,7 +630,8 @@ class ExternalTraversal:
"""
- __traverse_options__ = {}
+ __traverse_options__: Dict[str, Any] = {}
+ _next: Optional[ExternalTraversal]
def traverse_single(self, obj: Visitable, **kw: Any) -> Any:
for v in self.visitor_iterator:
@@ -563,20 +639,22 @@ class ExternalTraversal:
if meth:
return meth(obj, **kw)
- def iterate(self, obj):
+ def iterate(
+ self, obj: ExternallyTraversible
+ ) -> Iterator[ExternallyTraversible]:
"""Traverse the given expression structure, returning an iterator
of all elements.
"""
return iterate(obj, self.__traverse_options__)
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@util.memoized_property
- def _visitor_dict(self):
+ def _visitor_dict(self) -> Dict[str, _TraverseCallableType[Any]]:
visitors = {}
for name in dir(self):
@@ -585,16 +663,16 @@ class ExternalTraversal:
return visitors
@property
- def visitor_iterator(self):
+ def visitor_iterator(self) -> Iterator[ExternalTraversal]:
"""Iterate through this visitor and each 'chained' visitor."""
- v = self
+ v: Optional[ExternalTraversal] = self
while v:
yield v
v = getattr(v, "_next", None)
- def chain(self, visitor):
- """'Chain' an additional ClauseVisitor onto this ClauseVisitor.
+ def chain(self, visitor: ExternalTraversal) -> ExternalTraversal:
+ """'Chain' an additional ExternalTraversal onto this ExternalTraversal
The chained visitor will receive all visit events after this one.
@@ -614,14 +692,16 @@ class CloningExternalTraversal(ExternalTraversal):
"""
- def copy_and_process(self, list_):
+ def copy_and_process(
+ self, list_: List[ExternallyTraversible]
+ ) -> List[ExternallyTraversible]:
"""Apply cloned traversal to the given list of elements, and return
the new list.
"""
return [self.traverse(x) for x in list_]
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
return cloned_traverse(
@@ -638,7 +718,9 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
"""
- def replace(self, elem):
+ def replace(
+ self, elem: ExternallyTraversible
+ ) -> Optional[ExternallyTraversible]:
"""Receive pre-copied elements during a cloning traversal.
If the method returns a new element, the element is used
@@ -647,15 +729,19 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
"""
return None
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
- def replace(elem):
+ def replace(
+ elem: ExternallyTraversible,
+ ) -> Optional[ExternallyTraversible]:
for v in self.visitor_iterator:
- e = v.replace(elem)
+ e = cast(ReplacingExternalTraversal, v).replace(elem)
if e is not None:
return e
+ return None
+
return replacement_traverse(obj, self.__traverse_options__, replace)
@@ -667,7 +753,9 @@ CloningVisitor = CloningExternalTraversal
ReplacingCloningVisitor = ReplacingExternalTraversal
-def iterate(obj, opts=util.immutabledict()):
+def iterate(
+ obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT
+) -> Iterator[ExternallyTraversible]:
r"""Traverse the given expression structure, returning an iterator.
Traversal is configured to be breadth-first.
@@ -702,7 +790,11 @@ def iterate(obj, opts=util.immutabledict()):
stack.append(t.get_children(**opts))
-def traverse_using(iterator, obj, visitors):
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: ExternallyTraversible,
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
"""Visit the given expression structure using the given iterator of
objects.
@@ -734,7 +826,11 @@ def traverse_using(iterator, obj, visitors):
return obj
-def traverse(obj, opts, visitors):
+def traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
"""Traverse and visit the given expression structure using the default
iterator.
@@ -767,7 +863,11 @@ def traverse(obj, opts, visitors):
return traverse_using(iterate(obj, opts), obj, visitors)
-def cloned_traverse(obj, opts, visitors):
+def cloned_traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseTransformCallableType],
+) -> ExternallyTraversible:
"""Clone the given expression structure, allowing modifications by
visitors.
@@ -794,20 +894,24 @@ def cloned_traverse(obj, opts, visitors):
"""
- cloned = {}
+ cloned: Dict[int, ExternallyTraversible] = {}
stop_on = set(opts.get("stop_on", []))
- def deferred_copy_internals(obj):
+ def deferred_copy_internals(
+ obj: ExternallyTraversible,
+ ) -> ExternallyTraversible:
return cloned_traverse(obj, opts, visitors)
- def clone(elem, **kw):
+ def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
if "replace" in kw:
- newelem = kw["replace"](elem)
+ newelem = cast(
+ Optional[ExternallyTraversible], kw["replace"](elem)
+ )
if newelem is not None:
cloned[id(elem)] = newelem
return newelem
@@ -823,11 +927,15 @@ def cloned_traverse(obj, opts, visitors):
obj = clone(
obj, deferred_copy_internals=deferred_copy_internals, **opts
)
- clone = None # remove gc cycles
+ clone = None # type: ignore[assignment] # remove gc cycles
return obj
-def replacement_traverse(obj, opts, replace):
+def replacement_traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> ExternallyTraversible:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
@@ -854,10 +962,12 @@ def replacement_traverse(obj, opts, replace):
cloned = {}
stop_on = {id(x) for x in opts.get("stop_on", [])}
- def deferred_copy_internals(obj):
+ def deferred_copy_internals(
+ obj: ExternallyTraversible,
+ ) -> ExternallyTraversible:
return replacement_traverse(obj, opts, replace)
- def clone(elem, **kw):
+ def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
if (
id(elem) in stop_on
or "no_replacement_traverse" in elem._annotations
@@ -888,5 +998,5 @@ def replacement_traverse(obj, opts, replace):
obj = clone(
obj, deferred_copy_internals=deferred_copy_internals, **opts
)
- clone = None # remove gc cycles
+ clone = None # type: ignore[assignment] # remove gc cycles
return obj
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 8cb84f73f..ae61155ff 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -305,9 +305,11 @@ def _update_argspec_defaults_into_env(spec, env):
return spec
-def _exec_code_in_env(code, env, fn_name):
+def _exec_code_in_env(
+ code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
+) -> Callable[..., Any]:
exec(code, env)
- return env[fn_name]
+ return env[fn_name] # type: ignore[no-any-return]
_PF = TypeVar("_PF")
@@ -1181,7 +1183,7 @@ class memoized_property(Generic[_T]):
obj.__dict__.pop(name, None)
-def memoized_instancemethod(fn):
+def memoized_instancemethod(fn: _F) -> _F:
"""Decorate a method memoize its return value.
Best applied to no-arg methods: memoization is not sensitive to
@@ -1201,7 +1203,7 @@ def memoized_instancemethod(fn):
self.__dict__[fn.__name__] = memo
return result
- return update_wrapper(oneshot, fn)
+ return update_wrapper(oneshot, fn) # type: ignore
class HasMemoized:
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 160eabd85..c089616e4 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -3,10 +3,11 @@ from __future__ import annotations
import sys
import typing
from typing import Any
-from typing import Callable # noqa
from typing import cast
from typing import Dict
from typing import ForwardRef
+from typing import Iterable
+from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
@@ -16,6 +17,11 @@ from typing_extensions import NotRequired as NotRequired # noqa
from . import compat
_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT")
+_KT_co = TypeVar("_KT_co", covariant=True)
+_KT_contra = TypeVar("_KT_contra", contravariant=True)
+_VT = TypeVar("_VT")
+_VT_co = TypeVar("_VT_co", covariant=True)
Self = TypeVar("Self", bound=Any)
@@ -45,6 +51,18 @@ else:
from typing_extensions import Protocol as Protocol # noqa F401
from typing_extensions import TypedDict as TypedDict # noqa F401
+# copied from TypeShed, required in order to implement
+# MutableMapping.update()
+
+
+class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
+ def keys(self) -> Iterable[_KT]:
+ ...
+
+ def __getitem__(self, __k: _KT) -> _VT_co:
+ ...
+
+
# work around https://github.com/microsoft/pyright/issues/3025
_LiteralStar = Literal["*"]
@@ -120,7 +138,9 @@ def make_union_type(*types):
return cast(Any, Union).__getitem__(types)
-def expand_unions(type_, include_union=False, discard_none=False):
+def expand_unions(
+ type_: Type[Any], include_union: bool = False, discard_none: bool = False
+) -> Tuple[Type[Any], ...]:
"""Return a type as as a tuple of individual types, expanding for
``Union`` types."""
diff --git a/pyproject.toml b/pyproject.toml
index b90feae49..407af71c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,30 +40,12 @@ markers = [
]
[tool.pyright]
-include = [
- "lib/sqlalchemy/engine/base.py",
- "lib/sqlalchemy/engine/events.py",
- "lib/sqlalchemy/engine/interfaces.py",
- "lib/sqlalchemy/engine/_py_row.py",
- "lib/sqlalchemy/engine/result.py",
- "lib/sqlalchemy/engine/row.py",
- "lib/sqlalchemy/engine/util.py",
- "lib/sqlalchemy/engine/url.py",
- "lib/sqlalchemy/pool/",
- "lib/sqlalchemy/event/",
- "lib/sqlalchemy/events.py",
- "lib/sqlalchemy/exc.py",
- "lib/sqlalchemy/log.py",
- "lib/sqlalchemy/inspection.py",
- "lib/sqlalchemy/schema.py",
- "lib/sqlalchemy/types.py",
- "lib/sqlalchemy/util/",
-]
+
reportPrivateUsage = "none"
reportUnusedClass = "none"
reportUnusedFunction = "none"
-
+reportTypedDictNotRequiredAccess = "warning"
[tool.mypy]
mypy_path = "./lib/"
@@ -99,6 +81,11 @@ ignore_errors = true
# strict checking
[[tool.mypy.overrides]]
module = [
+ "sqlalchemy.sql.annotation",
+ "sqlalchemy.sql.cache_key",
+ "sqlalchemy.sql.roles",
+ "sqlalchemy.sql.visitors",
+ "sqlalchemy.sql._py_util",
"sqlalchemy.connectors.*",
"sqlalchemy.engine.*",
"sqlalchemy.ext.associationproxy",
@@ -117,6 +104,12 @@ strict = true
[[tool.mypy.overrides]]
module = [
+ "sqlalchemy.sql.coercions",
+ "sqlalchemy.sql.compiler",
+ #"sqlalchemy.sql.crud",
+ #"sqlalchemy.sql.default_comparator",
+ "sqlalchemy.sql.naming",
+ "sqlalchemy.sql.traversals",
"sqlalchemy.util.*",
"sqlalchemy.engine.cursor",
"sqlalchemy.engine.default",