summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_elements_constructors.py165
-rw-r--r--lib/sqlalchemy/sql/_typing.py46
-rw-r--r--lib/sqlalchemy/sql/annotation.py6
-rw-r--r--lib/sqlalchemy/sql/base.py37
-rw-r--r--lib/sqlalchemy/sql/cache_key.py14
-rw-r--r--lib/sqlalchemy/sql/coercions.py72
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py82
-rw-r--r--lib/sqlalchemy/sql/elements.py1284
-rw-r--r--lib/sqlalchemy/sql/operators.py363
-rw-r--r--lib/sqlalchemy/sql/roles.py38
-rw-r--r--lib/sqlalchemy/sql/schema.py1
-rw-r--r--lib/sqlalchemy/sql/selectable.py50
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py9
-rw-r--r--lib/sqlalchemy/sql/traversals.py5
-rw-r--r--lib/sqlalchemy/sql/type_api.py27
-rw-r--r--lib/sqlalchemy/sql/visitors.py2
16 files changed, 1371 insertions, 830 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index 9d15cdcc3..770fbe40c 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -9,17 +9,19 @@ from __future__ import annotations
import typing
from typing import Any
-from typing import cast as _typing_cast
+from typing import Callable
+from typing import Iterable
+from typing import Mapping
from typing import Optional
from typing import overload
-from typing import Type
+from typing import Sequence
+from typing import Tuple as typing_Tuple
from typing import TypeVar
from typing import Union
from . import coercions
-from . import operators
from . import roles
-from .base import NO_ARG
+from .base import _NoArg
from .coercions import _document_text_coercion
from .elements import BindParameter
from .elements import BooleanClauseList
@@ -35,18 +37,20 @@ from .elements import FunctionFilter
from .elements import Label
from .elements import Null
from .elements import Over
-from .elements import SQLCoreOperations
from .elements import TextClause
from .elements import True_
from .elements import Tuple
from .elements import TypeCoerce
from .elements import UnaryExpression
from .elements import WithinGroup
+from .functions import FunctionElement
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
- from elements import BinaryExpression
-
from . import sqltypes
+ from ._typing import _ColumnExpression
+ from ._typing import _TypeEngineArgument
+ from .elements import BinaryExpression
from .functions import FunctionElement
from .selectable import FromClause
from .type_api import TypeEngine
@@ -54,7 +58,7 @@ if typing.TYPE_CHECKING:
_T = TypeVar("_T")
-def all_(expr):
+def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]:
"""Produce an ALL expression.
For dialects such as that of PostgreSQL, this operator applies
@@ -108,7 +112,7 @@ def all_(expr):
return CollectionAggregate._create_all(expr)
-def and_(*clauses):
+def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
r"""Produce a conjunction of expressions joined by ``AND``.
E.g.::
@@ -169,7 +173,7 @@ def and_(*clauses):
return BooleanClauseList.and_(*clauses)
-def any_(expr):
+def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]:
"""Produce an ANY expression.
For dialects such as that of PostgreSQL, this operator applies
@@ -223,7 +227,7 @@ def any_(expr):
return CollectionAggregate._create_any(expr)
-def asc(column):
+def asc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
"""Produce an ascending ``ORDER BY`` clause element.
e.g.::
@@ -261,7 +265,9 @@ def asc(column):
return UnaryExpression._create_asc(column)
-def collate(expression, collation):
+def collate(
+ expression: _ColumnExpression[str], collation: str
+) -> BinaryExpression[str]:
"""Return the clause ``expression COLLATE collation``.
e.g.::
@@ -282,7 +288,12 @@ def collate(expression, collation):
return CollationClause._create_collation_expression(expression, collation)
-def between(expr, lower_bound, upper_bound, symmetric=False):
+def between(
+ expr: _ColumnExpression[_T],
+ lower_bound: Any,
+ upper_bound: Any,
+ symmetric: bool = False,
+) -> BinaryExpression[bool]:
"""Produce a ``BETWEEN`` predicate clause.
E.g.::
@@ -338,7 +349,9 @@ def between(expr, lower_bound, upper_bound, symmetric=False):
return expr.between(lower_bound, upper_bound, symmetric=symmetric)
-def outparam(key, type_=None):
+def outparam(
+ key: str, type_: Optional[TypeEngine[_T]] = None
+) -> BindParameter[_T]:
"""Create an 'OUT' parameter for usage in functions (stored procedures),
for databases which support them.
@@ -352,16 +365,16 @@ def outparam(key, type_=None):
@overload
-def not_(clause: "BinaryExpression[_T]") -> "BinaryExpression[_T]":
+def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]:
...
@overload
-def not_(clause: "ColumnElement[_T]") -> "UnaryExpression[_T]":
+def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]:
...
-def not_(clause: "ColumnElement[_T]") -> "ColumnElement[_T]":
+def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]:
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
The ``~`` operator is also overloaded on all
@@ -370,29 +383,21 @@ def not_(clause: "ColumnElement[_T]") -> "ColumnElement[_T]":
"""
- return operators.inv(
- _typing_cast(
- "ColumnElement[_T]",
- coercions.expect(roles.ExpressionElementRole, clause),
- )
- )
+ return coercions.expect(roles.ExpressionElementRole, clause).__invert__()
def bindparam(
- key,
- value=NO_ARG,
- type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None,
- unique=False,
- required=NO_ARG,
- quote=None,
- callable_=None,
- expanding=False,
- isoutparam=False,
- literal_execute=False,
- _compared_to_operator=None,
- _compared_to_type=None,
- _is_crud=False,
-) -> "BindParameter[_T]":
+ key: str,
+ value: Any = _NoArg.NO_ARG,
+ type_: Optional[TypeEngine[_T]] = None,
+ unique: bool = False,
+ required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG,
+ quote: Optional[bool] = None,
+ callable_: Optional[Callable[[], Any]] = None,
+ expanding: bool = False,
+ isoutparam: bool = False,
+ literal_execute: bool = False,
+) -> BindParameter[_T]:
r"""Produce a "bound expression".
The return value is an instance of :class:`.BindParameter`; this
@@ -636,13 +641,16 @@ def bindparam(
expanding,
isoutparam,
literal_execute,
- _compared_to_operator,
- _compared_to_type,
- _is_crud,
)
-def case(*whens, value=None, else_=None) -> "Case[Any]":
+def case(
+ *whens: Union[
+ typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any]
+ ],
+ value: Optional[Any] = None,
+ else_: Optional[Any] = None,
+) -> Case[Any]:
r"""Produce a ``CASE`` expression.
The ``CASE`` construct in SQL is a conditional object that
@@ -767,9 +775,9 @@ def case(*whens, value=None, else_=None) -> "Case[Any]":
def cast(
- expression: ColumnElement,
- type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"],
-) -> "Cast[_T]":
+ expression: _ColumnExpression[Any],
+ type_: _TypeEngineArgument[_T],
+) -> Cast[_T]:
r"""Produce a ``CAST`` expression.
:func:`.cast` returns an instance of :class:`.Cast`.
@@ -826,10 +834,10 @@ def cast(
def column(
text: str,
- type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
is_literal: bool = False,
- _selectable: Optional["FromClause"] = None,
-) -> "ColumnClause[_T]":
+ _selectable: Optional[FromClause] = None,
+) -> ColumnClause[_T]:
"""Produce a :class:`.ColumnClause` object.
The :class:`.ColumnClause` is a lightweight analogue to the
@@ -921,12 +929,10 @@ def column(
:ref:`sqlexpression_literal_column`
"""
- self = ColumnClause.__new__(ColumnClause)
- self.__init__(text, type_, is_literal, _selectable)
- return self
+ return ColumnClause(text, type_, is_literal, _selectable)
-def desc(column):
+def desc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
"""Produce a descending ``ORDER BY`` clause element.
e.g.::
@@ -965,7 +971,7 @@ def desc(column):
return UnaryExpression._create_desc(column)
-def distinct(expr):
+def distinct(expr: _ColumnExpression[_T]) -> UnaryExpression[_T]:
"""Produce an column-expression-level unary ``DISTINCT`` clause.
This applies the ``DISTINCT`` keyword to an individual column
@@ -1004,7 +1010,7 @@ def distinct(expr):
return UnaryExpression._create_distinct(expr)
-def extract(field: str, expr: ColumnElement) -> "Extract[sqltypes.Integer]":
+def extract(field: str, expr: _ColumnExpression[Any]) -> Extract:
"""Return a :class:`.Extract` construct.
This is typically available as :func:`.extract`
@@ -1045,7 +1051,7 @@ def extract(field: str, expr: ColumnElement) -> "Extract[sqltypes.Integer]":
return Extract(field, expr)
-def false():
+def false() -> False_:
"""Return a :class:`.False_` construct.
E.g.::
@@ -1083,7 +1089,9 @@ def false():
return False_._instance()
-def funcfilter(func, *criterion) -> "FunctionFilter":
+def funcfilter(
+ func: FunctionElement[_T], *criterion: _ColumnExpression[bool]
+) -> FunctionFilter[_T]:
"""Produce a :class:`.FunctionFilter` object against a function.
Used against aggregate and window functions,
@@ -1114,8 +1122,8 @@ def funcfilter(func, *criterion) -> "FunctionFilter":
def label(
name: str,
- element: ColumnElement[_T],
- type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None,
+ element: _ColumnExpression[_T],
+ type_: Optional[_TypeEngineArgument[_T]] = None,
) -> "Label[_T]":
"""Return a :class:`Label` object for the
given :class:`_expression.ColumnElement`.
@@ -1135,13 +1143,13 @@ def label(
return Label(name, element, type_)
-def null():
+def null() -> Null:
"""Return a constant :class:`.Null` construct."""
return Null._instance()
-def nulls_first(column):
+def nulls_first(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
"""Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression.
:func:`.nulls_first` is intended to modify the expression produced
@@ -1185,7 +1193,7 @@ def nulls_first(column):
return UnaryExpression._create_nulls_first(column)
-def nulls_last(column):
+def nulls_last(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
"""Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression.
:func:`.nulls_last` is intended to modify the expression produced
@@ -1229,7 +1237,7 @@ def nulls_last(column):
return UnaryExpression._create_nulls_last(column)
-def or_(*clauses: SQLCoreOperations) -> BooleanClauseList:
+def or_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
"""Produce a conjunction of expressions joined by ``OR``.
E.g.::
@@ -1281,12 +1289,16 @@ def or_(*clauses: SQLCoreOperations) -> BooleanClauseList:
def over(
- element: "FunctionElement[_T]",
- partition_by=None,
- order_by=None,
- range_=None,
- rows=None,
-) -> "Over[_T]":
+ element: FunctionElement[_T],
+ partition_by: Optional[
+ Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ ] = None,
+ order_by: Optional[
+ Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ ] = None,
+ range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+ rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+) -> Over[_T]:
r"""Produce an :class:`.Over` object against a function.
Used against aggregate or so-called "window" functions,
@@ -1373,7 +1385,7 @@ def over(
@_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`")
-def text(text):
+def text(text: str) -> TextClause:
r"""Construct a new :class:`_expression.TextClause` clause,
representing
a textual SQL string directly.
@@ -1451,7 +1463,7 @@ def text(text):
return TextClause(text)
-def true():
+def true() -> True_:
"""Return a constant :class:`.True_` construct.
E.g.::
@@ -1489,7 +1501,10 @@ def true():
return True_._instance()
-def tuple_(*clauses: roles.ExpressionElementRole, types=None) -> "Tuple":
+def tuple_(
+ *clauses: _ColumnExpression[Any],
+ types: Optional[Sequence[_TypeEngineArgument[Any]]] = None,
+) -> Tuple:
"""Return a :class:`.Tuple`.
Main usage is to produce a composite IN construct using
@@ -1516,9 +1531,9 @@ def tuple_(*clauses: roles.ExpressionElementRole, types=None) -> "Tuple":
def type_coerce(
- expression: "ColumnElement",
- type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"],
-) -> "TypeCoerce[_T]":
+ expression: _ColumnExpression[Any],
+ type_: _TypeEngineArgument[_T],
+) -> TypeCoerce[_T]:
r"""Associate a SQL expression with a particular type, without rendering
``CAST``.
@@ -1597,8 +1612,8 @@ def type_coerce(
def within_group(
- element: "FunctionElement[_T]", *order_by: roles.OrderByRole
-) -> "WithinGroup[_T]":
+ element: FunctionElement[_T], *order_by: _ColumnExpression[Any]
+) -> WithinGroup[_T]:
r"""Produce a :class:`.WithinGroup` object against a function.
Used against so-called "ordered set aggregate" and "hypothetical
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 69e4645fa..389f7e8d0 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -1,14 +1,58 @@
from __future__ import annotations
+from typing import Any
from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import roles
+from .. import util
from ..inspection import Inspectable
+if TYPE_CHECKING:
+ from .elements import quoted_name
+ from .schema import DefaultGenerator
+ from .schema import Sequence
+ from .selectable import FromClause
+ from .selectable import NamedFromClause
+ from .selectable import TableClause
+ from .sqltypes import TupleType
+ from .type_api import TypeEngine
+ from ..util.typing import TypeGuard
+
+_T = TypeVar("_T", bound=Any)
+
_ColumnsClauseElement = Union[
- roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement]
+ roles.ColumnsClauseRole,
+ Type,
+ Inspectable[roles.HasColumnElementClauseElement],
]
_FromClauseElement = Union[
roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement]
]
+
+_ColumnExpression = Union[
+ roles.ExpressionElementRole[_T],
+ Inspectable[roles.HasColumnElementClauseElement],
+]
+
+_PropagateAttrsType = util.immutabledict[str, Any]
+
+_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
+
+
+def is_named_from_clause(t: FromClause) -> TypeGuard[NamedFromClause]:
+ return t.named_with_column
+
+
+def has_schema_attr(t: FromClause) -> TypeGuard[TableClause]:
+ return hasattr(t, "schema")
+
+
+def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
+ return hasattr(s, "quote")
+
+
+def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
+ return t._is_tuple_type
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 7afc2de97..f37ae9a60 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -18,11 +18,11 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
+from typing import FrozenSet
from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
-from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -53,7 +53,9 @@ class SupportsAnnotations(ExternallyTraversible):
__slots__ = ()
_annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS
- proxy_set: Set[SupportsAnnotations]
+
+ proxy_set: util.generic_fn_descriptor[FrozenSet[Any]]
+
_is_immutable: bool
def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations:
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index a408a010a..29f9028c8 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -13,16 +13,20 @@
from __future__ import annotations
import collections.abc as collections_abc
+from enum import Enum
from functools import reduce
import itertools
from itertools import zip_longest
import operator
import re
import typing
+from typing import Any
+from typing import Iterable
+from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Sequence
-from typing import Set
+from typing import Tuple
from typing import TypeVar
from . import roles
@@ -38,23 +42,33 @@ from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
+from ..util.typing import Self
if typing.TYPE_CHECKING:
+ from .elements import BindParameter
from .elements import ColumnElement
from ..engine import Connection
from ..engine import Result
+ from ..engine.base import _CompiledCacheType
from ..engine.interfaces import _CoreMultiExecuteParams
from ..engine.interfaces import _ExecuteOptions
from ..engine.interfaces import _ExecuteOptionsParameter
from ..engine.interfaces import _ImmutableExecuteOptions
+ from ..engine.interfaces import _SchemaTranslateMapType
from ..engine.interfaces import CacheStats
-
+ from ..engine.interfaces import Compiled
+ from ..engine.interfaces import Dialect
coercions = None
elements = None
type_api = None
-NO_ARG = util.symbol("NO_ARG")
+
+class _NoArg(Enum):
+ NO_ARG = 0
+
+
+NO_ARG = _NoArg.NO_ARG
# if I use sqlalchemy.util.typing, which has the exact same
# symbols, mypy reports: "error: _Fn? not callable"
@@ -74,10 +88,12 @@ class Immutable:
def params(self, *optionaldict, **kwargs):
raise NotImplementedError("Immutable objects do not support copying")
- def _clone(self, **kw):
+ def _clone(self: Self, **kw: Any) -> Self:
return self
- def _copy_internals(self, **kw):
+ def _copy_internals(
+ self, omit_attrs: Iterable[str] = (), **kw: Any
+ ) -> None:
pass
@@ -88,8 +104,6 @@ class SingletonConstant(Immutable):
_singleton: SingletonConstant
- proxy_set: Set[ColumnElement]
-
def __new__(cls, *arg, **kw):
return cls._singleton
@@ -877,12 +891,15 @@ class Executable(roles.StatementRole, Generative):
def _compile_w_cache(
self,
dialect: Dialect,
- compiled_cache: Optional[_CompiledCacheType] = None,
- column_keys: Optional[Sequence[str]] = None,
+ *,
+ compiled_cache: Optional[_CompiledCacheType],
+ column_keys: List[str],
for_executemany: bool = False,
schema_translate_map: Optional[_SchemaTranslateMapType] = None,
**kw: Any,
- ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]:
+ ) -> Tuple[
+ Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats
+ ]:
...
def _execute_on_connection(
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index fca58f98e..19a232c56 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -11,15 +11,14 @@ import enum
from itertools import zip_longest
import typing
from typing import Any
-from typing import cast
from typing import Dict
from typing import Iterator
from typing import List
+from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
-from typing import Type
from typing import Union
from .visitors import anon_map
@@ -91,7 +90,7 @@ class HasCacheKey:
__slots__ = ()
_cache_key_traversal: Union[
- _TraverseInternalsType, Literal[CacheConst.NO_CACHE]
+ _TraverseInternalsType, Literal[CacheConst.NO_CACHE], Literal[None]
] = NO_CACHE
_is_has_cache_key = True
@@ -147,11 +146,8 @@ class HasCacheKey:
_cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
if _cache_key_traversal is None:
try:
- # check for _traverse_internals, which is part of
- # HasTraverseInternals
- _cache_key_traversal = cast(
- "Type[HasTraverseInternals]", cls
- )._traverse_internals
+ assert issubclass(cls, HasTraverseInternals)
+ _cache_key_traversal = cls._traverse_internals
except AttributeError:
cls._generated_cache_key_traversal = NO_CACHE
return NO_CACHE
@@ -417,7 +413,7 @@ class CacheKey(NamedTuple):
def to_offline_string(
self,
- statement_cache: _CompiledCacheType,
+ statement_cache: MutableMapping[Any, str],
statement: ClauseElement,
parameters: _CoreSingleExecuteParams,
) -> str:
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 834bfb75d..ea17b8e03 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -13,10 +13,12 @@ import re
import typing
from typing import Any
from typing import Any as TODO_Any
+from typing import Callable
from typing import Dict
from typing import List
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Type
from typing import TypeVar
@@ -46,9 +48,14 @@ if typing.TYPE_CHECKING:
from . import traversals
from .elements import ClauseElement
from .elements import ColumnClause
+ from .elements import ColumnElement
+ from .elements import SQLCoreOperations
+
_SR = TypeVar("_SR", bound=roles.SQLRole)
+_F = TypeVar("_F", bound=Callable[..., Any])
_StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
+_T = TypeVar("_T", bound=Any)
def _is_literal(element):
@@ -104,7 +111,9 @@ def _deep_is_literal(element):
)
-def _document_text_coercion(paramname, meth_rst, param_rst):
+def _document_text_coercion(
+ paramname: str, meth_rst: str, param_rst: str
+) -> Callable[[_F], _F]:
return util.add_parameter_text(
paramname,
(
@@ -132,15 +141,50 @@ def _expression_collection_was_a_list(attrname, fnname, args):
return args
-# TODO; would like to have overloads here, however mypy is being extremely
-# pedantic about them. not sure why pylance is OK with them.
+@overload
+def expect(
+ role: Type[roles.TruncatedLabelRole],
+ element: Any,
+ *,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
+ argname: Optional[str] = None,
+ post_inspect: bool = False,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.ExpressionElementRole[_T]],
+ element: Any,
+ *,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
+ argname: Optional[str] = None,
+ post_inspect: bool = False,
+ **kw: Any,
+) -> ColumnElement[_T]:
+ ...
+@overload
def expect(
role: Type[_SR],
element: Any,
*,
- apply_propagate_attrs: Optional["ClauseElement"] = None,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
+ argname: Optional[str] = None,
+ post_inspect: bool = False,
+ **kw: Any,
+) -> TODO_Any:
+ ...
+
+
+def expect(
+ role: Type[_SR],
+ element: Any,
+ *,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
argname: Optional[str] = None,
post_inspect: bool = False,
**kw: Any,
@@ -220,12 +264,16 @@ def expect(
resolved = element
else:
resolved = element
- if (
- apply_propagate_attrs is not None
- and not apply_propagate_attrs._propagate_attrs
- and resolved._propagate_attrs
- ):
- apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs
+
+ if apply_propagate_attrs is not None:
+ if typing.TYPE_CHECKING:
+ assert isinstance(resolved, (SQLCoreOperations, ClauseElement))
+
+ if (
+ not apply_propagate_attrs._propagate_attrs
+ and resolved._propagate_attrs
+ ):
+ apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs
if impl._role_class in resolved.__class__.__mro__:
if impl._post_coercion:
@@ -620,8 +668,8 @@ class InElementImpl(RoleImpl):
element, str
):
non_literal_expressions: Dict[
- Optional[operators.ColumnOperators[Any]],
- operators.ColumnOperators[Any],
+ Optional[operators.ColumnOperators],
+ operators.ColumnOperators,
] = {}
element = list(element)
for o in element:
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 001710d7b..91bb0a5c5 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -42,13 +42,14 @@ _T = typing.TypeVar("_T", bound=Any)
if typing.TYPE_CHECKING:
from .elements import ColumnElement
+ from .operators import custom_op
from .sqltypes import TypeEngine
def _boolean_compare(
- expr: "ColumnElement",
+ expr: ColumnElement[Any],
op: OperatorType,
- obj: roles.BinaryElementRole,
+ obj: Any,
*,
negate_op: Optional[OperatorType] = None,
reverse: bool = False,
@@ -59,7 +60,6 @@ def _boolean_compare(
] = None,
**kwargs: Any,
) -> BinaryExpression[bool]:
-
if result_type is None:
result_type = type_api.BOOLEANTYPE
@@ -143,7 +143,14 @@ def _boolean_compare(
)
-def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
+def _custom_op_operate(
+ expr: ColumnElement[Any],
+ op: custom_op[Any],
+ obj: Any,
+ reverse: bool = False,
+ result_type: Optional[TypeEngine[Any]] = None,
+ **kw: Any,
+) -> ColumnElement[Any]:
if result_type is None:
if op.return_type:
result_type = op.return_type
@@ -156,11 +163,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
def _binary_operate(
- expr: "ColumnElement",
+ expr: ColumnElement[Any],
op: OperatorType,
obj: roles.BinaryElementRole,
*,
- reverse=False,
+ reverse: bool = False,
result_type: Optional[
Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
] = None,
@@ -184,7 +191,9 @@ def _binary_operate(
return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
-def _conjunction_operate(expr, op, other, **kw) -> "ColumnElement":
+def _conjunction_operate(
+ expr: ColumnElement[Any], op: OperatorType, other, **kw
+) -> ColumnElement[Any]:
if op is operators.and_:
return and_(expr, other)
elif op is operators.or_:
@@ -193,11 +202,19 @@ def _conjunction_operate(expr, op, other, **kw) -> "ColumnElement":
raise NotImplementedError()
-def _scalar(expr, op, fn, **kw) -> "ColumnElement":
+def _scalar(
+ expr: ColumnElement[Any], op: OperatorType, fn, **kw
+) -> ColumnElement[Any]:
return fn(expr)
-def _in_impl(expr, op, seq_or_selectable, negate_op, **kw) -> "ColumnElement":
+def _in_impl(
+ expr: ColumnElement[Any],
+ op: OperatorType,
+ seq_or_selectable,
+ negate_op: OperatorType,
+ **kw,
+) -> ColumnElement[Any]:
seq_or_selectable = coercions.expect(
roles.InElementRole, seq_or_selectable, expr=expr, operator=op
)
@@ -209,7 +226,9 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw) -> "ColumnElement":
)
-def _getitem_impl(expr, op, other, **kw) -> "ColumnElement":
+def _getitem_impl(
+ expr: ColumnElement[Any], op: OperatorType, other, **kw
+) -> ColumnElement[Any]:
if isinstance(expr.type, type_api.INDEXABLE):
other = coercions.expect(
roles.BinaryElementRole, other, expr=expr, operator=op
@@ -219,13 +238,17 @@ def _getitem_impl(expr, op, other, **kw) -> "ColumnElement":
_unsupported_impl(expr, op, other, **kw)
-def _unsupported_impl(expr, op, *arg, **kw) -> NoReturn:
+def _unsupported_impl(
+ expr: ColumnElement[Any], op: OperatorType, *arg, **kw
+) -> NoReturn:
raise NotImplementedError(
"Operator '%s' is not supported on " "this expression" % op.__name__
)
-def _inv_impl(expr, op, **kw) -> "ColumnElement":
+def _inv_impl(
+ expr: ColumnElement[Any], op: OperatorType, **kw
+) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__inv__`."""
# undocumented element currently used by the ORM for
@@ -236,12 +259,16 @@ def _inv_impl(expr, op, **kw) -> "ColumnElement":
return expr._negate()
-def _neg_impl(expr, op, **kw) -> "ColumnElement":
+def _neg_impl(
+ expr: ColumnElement[Any], op: OperatorType, **kw
+) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__neg__`."""
return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
-def _match_impl(expr, op, other, **kw) -> "ColumnElement":
+def _match_impl(
+ expr: ColumnElement[Any], op: OperatorType, other, **kw
+) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.match`."""
return _boolean_compare(
@@ -261,14 +288,18 @@ def _match_impl(expr, op, other, **kw) -> "ColumnElement":
)
-def _distinct_impl(expr, op, **kw) -> "ColumnElement":
+def _distinct_impl(
+ expr: ColumnElement[Any], op: OperatorType, **kw
+) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.distinct`."""
return UnaryExpression(
expr, operator=operators.distinct_op, type_=expr.type
)
-def _between_impl(expr, op, cleft, cright, **kw) -> "ColumnElement":
+def _between_impl(
+ expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw
+) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
expr,
@@ -297,11 +328,15 @@ def _between_impl(expr, op, cleft, cright, **kw) -> "ColumnElement":
)
-def _collate_impl(expr, op, collation, **kw) -> "ColumnElement":
+def _collate_impl(
+ expr: ColumnElement[Any], op: OperatorType, collation, **kw
+) -> ColumnElement[Any]:
return CollationClause._create_collation_expression(expr, collation)
-def _regexp_match_impl(expr, op, pattern, flags, **kw) -> "ColumnElement":
+def _regexp_match_impl(
+ expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw
+) -> ColumnElement[Any]:
if flags is not None:
flags = coercions.expect(
roles.BinaryElementRole,
@@ -322,8 +357,13 @@ def _regexp_match_impl(expr, op, pattern, flags, **kw) -> "ColumnElement":
def _regexp_replace_impl(
- expr, op, pattern, replacement, flags, **kw
-) -> "ColumnElement":
+ expr: ColumnElement[Any],
+ op: OperatorType,
+ pattern,
+ replacement,
+ flags,
+ **kw,
+) -> ColumnElement[Any]:
replacement = coercions.expect(
roles.BinaryElementRole,
replacement,
@@ -345,7 +385,7 @@ def _regexp_replace_impl(
# a mapping of operators with the method they use, along with
# additional keyword arguments to be passed
operator_lookup: Dict[
- str, Tuple[Callable[..., "ColumnElement"], util.immutabledict]
+ str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict]
] = {
"and_": (_conjunction_operate, util.EMPTY_DICT),
"or_": (_conjunction_operate, util.EMPTY_DICT),
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 08d632afd..fdb3fc8bb 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -12,20 +12,28 @@
from __future__ import annotations
+from decimal import Decimal
+from enum import IntEnum
import itertools
import operator
import re
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import FrozenSet
from typing import Generic
+from typing import Iterable
from typing import List
+from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
-from typing import Text as typing_Text
+from typing import Set
+from typing import Tuple as typing_Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -34,10 +42,15 @@ from . import operators
from . import roles
from . import traversals
from . import type_api
+from ._typing import has_schema_attr
+from ._typing import is_named_from_clause
+from ._typing import is_quoted_name
+from ._typing import is_tuple_type
from .annotation import Annotated
from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
+from .base import _NoArg
from .base import Executable
from .base import HasMemoized
from .base import Immutable
@@ -57,30 +70,47 @@ from .. import exc
from .. import inspection
from .. import util
from ..util.langhelpers import TypingOnly
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
- from decimal import Decimal
-
+ from ._typing import _ColumnExpression
+ from ._typing import _PropagateAttrsType
+ from ._typing import _TypeEngineArgument
+ from .cache_key import CacheKey
from .compiler import Compiled
from .compiler import SQLCompiler
+ from .functions import FunctionElement
from .operators import OperatorType
+ from .schema import Column
+ from .schema import DefaultGenerator
+ from .schema import ForeignKey
from .selectable import FromClause
+ from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
from .selectable import Select
- from .sqltypes import Boolean # noqa
+ from .selectable import TableClause
+ from .sqltypes import Boolean
+ from .sqltypes import TupleType
from .type_api import TypeEngine
+ from .visitors import _TraverseInternalsType
from ..engine import Connection
from ..engine import Dialect
from ..engine import Engine
from ..engine.base import _CompiledCacheType
- from ..engine.base import _SchemaTranslateMapType
-
+ from ..engine.interfaces import _CoreMultiExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _SchemaTranslateMapType
+ from ..engine.interfaces import CacheStats
+ from ..engine.result import Result
-_NUMERIC = Union[complex, "Decimal"]
+_NUMERIC = Union[complex, Decimal]
+_NUMBER = Union[complex, int, Decimal]
_T = TypeVar("_T", bound="Any")
_OPT = TypeVar("_OPT", bound="Any")
_NT = TypeVar("_NT", bound="_NUMERIC")
-_ST = TypeVar("_ST", bound="typing_Text")
+
+_NMT = TypeVar("_NMT", bound="_NUMBER")
def literal(value, type_=None):
@@ -210,28 +240,27 @@ class CompilerElement(Visitable):
"""
- if not dialect:
+ if dialect is None:
if bind:
dialect = bind.dialect
+ elif self.stringify_dialect == "default":
+ default = util.preloaded.engine_default
+ dialect = default.StrCompileDialect()
else:
- if self.stringify_dialect == "default":
- default = util.preloaded.engine_default
- dialect = default.StrCompileDialect()
- else:
- url = util.preloaded.engine_url
- dialect = url.URL.create(
- self.stringify_dialect
- ).get_dialect()()
+ url = util.preloaded.engine_url
+ dialect = url.URL.create(
+ self.stringify_dialect
+ ).get_dialect()()
return self._compiler(dialect, **kw)
- def _compiler(self, dialect, **kw):
+ def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled:
"""Return a compiler appropriate for this ClauseElement, given a
Dialect."""
return dialect.statement_compiler(dialect, self, **kw)
- def __str__(self):
+ def __str__(self) -> str:
return str(self.compile())
@@ -253,16 +282,17 @@ class ClauseElement(
__visit_name__ = "clause"
- _propagate_attrs = util.immutabledict()
+ _propagate_attrs: _PropagateAttrsType = util.immutabledict()
"""like annotations, however these propagate outwards liberally
as SQL constructs are built, and are set up at construction time.
"""
- _from_objects = []
- bind = None
- description = None
- _is_clone_of = None
+ @util.memoized_property
+ def description(self) -> Optional[str]:
+ return None
+
+ _is_clone_of: Optional[ClauseElement] = None
is_clause_element = True
is_selectable = False
@@ -281,10 +311,25 @@ class ClauseElement(
_is_singleton_constant = False
_is_immutable = False
- _order_by_label_element = None
+ @property
+ def _order_by_label_element(self) -> Optional[Label[Any]]:
+ return None
_cache_key_traversal = None
+ negation_clause: ClauseElement
+
+ if typing.TYPE_CHECKING:
+
+ def get_children(
+ self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any
+ ) -> Iterable[ClauseElement]:
+ ...
+
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ return []
+
def _set_propagate_attrs(self, values):
# usually, self._propagate_attrs is empty here. one case where it's
# not is a subquery against ORM select, that is then pulled as a
@@ -295,7 +340,7 @@ class ClauseElement(
self._propagate_attrs = util.immutabledict(values)
return self
- def _clone(self: SelfClauseElement, **kw) -> SelfClauseElement:
+ def _clone(self: SelfClauseElement, **kw: Any) -> SelfClauseElement:
"""Create a shallow copy of this ClauseElement.
This method may be used by a generative API. Its also used as
@@ -357,7 +402,7 @@ class ClauseElement(
"""
s = util.column_set()
- f = self
+ f: Optional[ClauseElement] = self
# note this creates a cycle, asserted in test_memusage. however,
# turning this into a plain @property adds tends of thousands of method
@@ -383,16 +428,26 @@ class ClauseElement(
return d
def _execute_on_connection(
- self, connection, distilled_params, execution_options, _force=False
- ):
+ self,
+ connection: Connection,
+ distilled_params: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptions,
+ _force: bool = False,
+ ) -> Result:
if _force or self.supports_execution:
+ if TYPE_CHECKING:
+ assert isinstance(self, Executable)
return connection._execute_clauseelement(
self, distilled_params, execution_options
)
else:
raise exc.ObjectNotExecutableError(self)
- def unique_params(self, *optionaldict, **kwargs):
+ def unique_params(
+ self: SelfClauseElement,
+ __optionaldict: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> SelfClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
replaced.
@@ -402,11 +457,13 @@ class ClauseElement(
used.
"""
- return self._replace_params(True, optionaldict, kwargs)
+ return self._replace_params(True, __optionaldict, kwargs)
def params(
- self, *optionaldict: Dict[str, Any], **kwargs: Any
- ) -> ClauseElement:
+ self: SelfClauseElement,
+ __optionaldict: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> SelfClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
replaced.
@@ -421,33 +478,32 @@ class ClauseElement(
{'foo':7}
"""
- return self._replace_params(False, optionaldict, kwargs)
+ return self._replace_params(False, __optionaldict, kwargs)
def _replace_params(
- self,
+ self: SelfClauseElement,
unique: bool,
optionaldict: Optional[Dict[str, Any]],
kwargs: Dict[str, Any],
- ) -> ClauseElement:
+ ) -> SelfClauseElement:
- if len(optionaldict) == 1:
- kwargs.update(optionaldict[0])
- elif len(optionaldict) > 1:
- raise exc.ArgumentError(
- "params() takes zero or one positional dictionary argument"
- )
+ if optionaldict:
+ kwargs.update(optionaldict)
- def visit_bindparam(bind):
+ def visit_bindparam(bind: BindParameter[Any]) -> None:
if bind.key in kwargs:
bind.value = kwargs[bind.key]
bind.required = False
if unique:
bind._convert_to_unique()
- return cloned_traverse(
- self,
- {"maintain_key": True, "detect_subquery_cols": True},
- {"bindparam": visit_bindparam},
+ return cast(
+ SelfClauseElement,
+ cloned_traverse(
+ self,
+ {"maintain_key": True, "detect_subquery_cols": True},
+ {"bindparam": visit_bindparam},
+ ),
)
def compare(self, other, **kw):
@@ -501,18 +557,26 @@ class ClauseElement(
def _compile_w_cache(
self,
dialect: Dialect,
- compiled_cache: Optional[_CompiledCacheType] = None,
- column_keys: Optional[List[str]] = None,
+ *,
+ compiled_cache: Optional[_CompiledCacheType],
+ column_keys: List[str],
for_executemany: bool = False,
schema_translate_map: Optional[_SchemaTranslateMapType] = None,
**kw: Any,
- ):
+ ) -> typing_Tuple[
+ Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats
+ ]:
+ elem_cache_key: Optional[CacheKey]
+
if compiled_cache is not None and dialect._supports_statement_cache:
elem_cache_key = self._generate_cache_key()
else:
elem_cache_key = None
- if elem_cache_key:
+ if elem_cache_key is not None:
+ if TYPE_CHECKING:
+ assert compiled_cache is not None
+
cache_key, extracted_params = elem_cache_key
key = (
dialect,
@@ -564,7 +628,7 @@ class ClauseElement(
else:
return self._negate()
- def _negate(self):
+ def _negate(self) -> ClauseElement:
return UnaryExpression(
self.self_group(against=operators.inv), operator=operators.inv
)
@@ -605,6 +669,9 @@ class DQLDMLClauseElement(ClauseElement):
) -> SQLCompiler:
...
+ def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler:
+ ...
+
class CompilerColumnElement(
roles.DMLColumnRole,
@@ -621,9 +688,7 @@ class CompilerColumnElement(
__slots__ = ()
-class SQLCoreOperations(
- Generic[_T], ColumnOperators["SQLCoreOperations"], TypingOnly
-):
+class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
__slots__ = ()
# annotations for comparison methods
@@ -631,173 +696,186 @@ class SQLCoreOperations(
# redefined with the specific types returned by ColumnElement hierarchies
if typing.TYPE_CHECKING:
+ _propagate_attrs: _PropagateAttrsType
+
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
- ) -> ColumnElement:
+ ) -> ColumnElement[Any]:
...
def reverse_operate(
self, op: OperatorType, other: Any, **kwargs: Any
- ) -> ColumnElement:
+ ) -> ColumnElement[Any]:
...
def op(
self,
- opstring: Any,
+ opstring: str,
precedence: int = 0,
is_comparison: bool = False,
- return_type: Optional[
- Union[Type["TypeEngine[_OPT]"], "TypeEngine[_OPT]"]
- ] = None,
- python_impl=None,
- ) -> Callable[[Any], "BinaryExpression[_OPT]"]:
+ return_type: Optional[_TypeEngineArgument[_OPT]] = None,
+ python_impl: Optional[Callable[..., Any]] = None,
+ ) -> Callable[[Any], BinaryExpression[_OPT]]:
...
def bool_op(
- self, opstring: Any, precedence: int = 0, python_impl=None
- ) -> Callable[[Any], "BinaryExpression[bool]"]:
+ self,
+ opstring: str,
+ precedence: int = 0,
+ python_impl: Optional[Callable[..., Any]] = None,
+ ) -> Callable[[Any], BinaryExpression[bool]]:
...
- def __and__(self, other: Any) -> "BooleanClauseList":
+ def __and__(self, other: Any) -> BooleanClauseList:
...
- def __or__(self, other: Any) -> "BooleanClauseList":
+ def __or__(self, other: Any) -> BooleanClauseList:
...
- def __invert__(self) -> "UnaryExpression[_T]":
+ def __invert__(self) -> ColumnElement[_T]:
...
- def __lt__(self, other: Any) -> "ColumnElement[bool]":
+ def __lt__(self, other: Any) -> ColumnElement[bool]:
...
- def __le__(self, other: Any) -> "ColumnElement[bool]":
+ def __le__(self, other: Any) -> ColumnElement[bool]:
...
- def __eq__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501
+ def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
...
- def __ne__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
...
- def is_distinct_from(self, other: Any) -> "ColumnElement[bool]":
+ def is_distinct_from(self, other: Any) -> ColumnElement[bool]:
...
- def is_not_distinct_from(self, other: Any) -> "ColumnElement[bool]":
+ def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]:
...
- def __gt__(self, other: Any) -> "ColumnElement[bool]":
+ def __gt__(self, other: Any) -> ColumnElement[bool]:
...
- def __ge__(self, other: Any) -> "ColumnElement[bool]":
+ def __ge__(self, other: Any) -> ColumnElement[bool]:
...
- def __neg__(self) -> "UnaryExpression[_T]":
+ def __neg__(self) -> UnaryExpression[_T]:
...
- def __contains__(self, other: Any) -> "ColumnElement[bool]":
+ def __contains__(self, other: Any) -> ColumnElement[bool]:
...
- def __getitem__(self, index: Any) -> "ColumnElement":
+ def __getitem__(self, index: Any) -> ColumnElement[Any]:
...
@overload
- def concat(
- self: "SQLCoreOperations[_ST]", other: Any
- ) -> "ColumnElement[_ST]":
+ def concat(self: _SQO[str], other: Any) -> ColumnElement[str]:
...
@overload
- def concat(self, other: Any) -> "ColumnElement":
+ def concat(self, other: Any) -> ColumnElement[Any]:
...
- def concat(self, other: Any) -> "ColumnElement":
+ def concat(self, other: Any) -> ColumnElement[Any]:
...
- def like(self, other: Any, escape=None) -> "BinaryExpression[bool]":
+ def like(
+ self, other: Any, escape: Optional[str] = None
+ ) -> BinaryExpression[bool]:
...
- def ilike(self, other: Any, escape=None) -> "BinaryExpression[bool]":
+ def ilike(
+ self, other: Any, escape: Optional[str] = None
+ ) -> BinaryExpression[bool]:
...
def in_(
self,
- other: Union[Sequence[Any], "BindParameter", "Select"],
- ) -> "BinaryExpression[bool]":
+ other: Union[Sequence[Any], BindParameter[Any], Select],
+ ) -> BinaryExpression[bool]:
...
def not_in(
self,
- other: Union[Sequence[Any], "BindParameter", "Select"],
- ) -> "BinaryExpression[bool]":
+ other: Union[Sequence[Any], BindParameter[Any], Select],
+ ) -> BinaryExpression[bool]:
...
def not_like(
- self, other: Any, escape=None
- ) -> "BinaryExpression[bool]":
+ self, other: Any, escape: Optional[str] = None
+ ) -> BinaryExpression[bool]:
...
def not_ilike(
- self, other: Any, escape=None
- ) -> "BinaryExpression[bool]":
+ self, other: Any, escape: Optional[str] = None
+ ) -> BinaryExpression[bool]:
...
- def is_(self, other: Any) -> "BinaryExpression[bool]":
+ def is_(self, other: Any) -> BinaryExpression[bool]:
...
- def is_not(self, other: Any) -> "BinaryExpression[bool]":
+ def is_not(self, other: Any) -> BinaryExpression[bool]:
...
def startswith(
- self, other: Any, escape=None, autoescape=False
- ) -> "ColumnElement[bool]":
+ self,
+ other: Any,
+ escape: Optional[str] = None,
+ autoescape: bool = False,
+ ) -> ColumnElement[bool]:
...
def endswith(
- self, other: Any, escape=None, autoescape=False
- ) -> "ColumnElement[bool]":
+ self,
+ other: Any,
+ escape: Optional[str] = None,
+ autoescape: bool = False,
+ ) -> ColumnElement[bool]:
...
- def contains(self, other: Any, **kw: Any) -> "ColumnElement[bool]":
+ def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]:
...
- def match(self, other: Any, **kwargs) -> "ColumnElement[bool]":
+ def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]:
...
- def regexp_match(self, pattern, flags=None) -> "ColumnElement[bool]":
+ def regexp_match(
+ self, pattern: Any, flags: Optional[str] = None
+ ) -> ColumnElement[bool]:
...
def regexp_replace(
- self, pattern, replacement, flags=None
- ) -> "ColumnElement":
+ self, pattern: Any, replacement: Any, flags: Optional[str] = None
+ ) -> ColumnElement[str]:
...
- def desc(self) -> "UnaryExpression[_T]":
+ def desc(self) -> UnaryExpression[_T]:
...
- def asc(self) -> "UnaryExpression[_T]":
+ def asc(self) -> UnaryExpression[_T]:
...
- def nulls_first(self) -> "UnaryExpression[_T]":
+ def nulls_first(self) -> UnaryExpression[_T]:
...
- def nulls_last(self) -> "UnaryExpression[_T]":
+ def nulls_last(self) -> UnaryExpression[_T]:
...
- def collate(self, collation) -> "CollationClause":
+ def collate(self, collation: str) -> CollationClause:
...
def between(
- self, cleft, cright, symmetric=False
- ) -> "ColumnElement[bool]":
+ self, cleft: Any, cright: Any, symmetric: bool = False
+ ) -> BinaryExpression[bool]:
...
- def distinct(self: "SQLCoreOperations[_T]") -> "UnaryExpression[_T]":
+ def distinct(self: _SQO[_T]) -> UnaryExpression[_T]:
...
- def any_(self) -> "CollectionAggregate":
+ def any_(self) -> CollectionAggregate[Any]:
...
- def all_(self) -> "CollectionAggregate":
+ def all_(self) -> CollectionAggregate[Any]:
...
# numeric overloads. These need more tweaking
@@ -807,179 +885,173 @@ class SQLCoreOperations(
@overload
def __add__(
- self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]",
- other: "Union[_SQO[Optional[_NT]], _SQO[_NT], _NT]",
- ) -> "ColumnElement[_NT]":
+ self: _SQO[_NMT],
+ other: Any,
+ ) -> ColumnElement[_NMT]:
...
@overload
def __add__(
- self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]",
+ self: _SQO[str],
other: Any,
- ) -> "ColumnElement[_NUMERIC]":
+ ) -> ColumnElement[str]:
...
- @overload
- def __add__(
- self: "Union[_SQO[_ST], _SQO[Optional[_ST]]]",
- other: Any,
- ) -> "ColumnElement[_ST]":
+ def __add__(self, other: Any) -> ColumnElement[Any]:
...
- def __add__(self, other: Any) -> "ColumnElement":
+ @overload
+ def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
...
@overload
- def __radd__(self, other: Any) -> "ColumnElement[_NUMERIC]":
+ def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]:
...
@overload
- def __radd__(self, other: Any) -> "ColumnElement":
+ def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]:
...
- def __radd__(self, other: Any) -> "ColumnElement":
+ def __radd__(self, other: Any) -> ColumnElement[Any]:
...
@overload
def __sub__(
- self: "SQLCoreOperations[_NT]",
- other: "Union[SQLCoreOperations[_NT], _NT]",
- ) -> "ColumnElement[_NT]":
+ self: _SQO[_NMT],
+ other: Any,
+ ) -> ColumnElement[_NMT]:
...
@overload
- def __sub__(self, other: Any) -> "ColumnElement":
+ def __sub__(self, other: Any) -> ColumnElement[Any]:
...
- def __sub__(self, other: Any) -> "ColumnElement":
+ def __sub__(self, other: Any) -> ColumnElement[Any]:
...
@overload
def __rsub__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ self: _SQO[_NMT],
+ other: Any,
+ ) -> ColumnElement[_NMT]:
...
@overload
- def __rsub__(self, other: Any) -> "ColumnElement":
+ def __rsub__(self, other: Any) -> ColumnElement[Any]:
...
- def __rsub__(self, other: Any) -> "ColumnElement":
+ def __rsub__(self, other: Any) -> ColumnElement[Any]:
...
@overload
def __mul__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ self: _SQO[_NMT],
+ other: Any,
+ ) -> ColumnElement[_NMT]:
...
@overload
- def __mul__(self, other: Any) -> "ColumnElement":
+ def __mul__(self, other: Any) -> ColumnElement[Any]:
...
- def __mul__(self, other: Any) -> "ColumnElement":
+ def __mul__(self, other: Any) -> ColumnElement[Any]:
...
@overload
def __rmul__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ self: _SQO[_NMT],
+ other: Any,
+ ) -> ColumnElement[_NMT]:
...
@overload
- def __rmul__(self, other: Any) -> "ColumnElement":
+ def __rmul__(self, other: Any) -> ColumnElement[Any]:
...
- def __rmul__(self, other: Any) -> "ColumnElement":
+ def __rmul__(self, other: Any) -> ColumnElement[Any]:
...
@overload
- def __mod__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]:
...
@overload
- def __mod__(self, other: Any) -> "ColumnElement":
+ def __mod__(self, other: Any) -> ColumnElement[Any]:
...
- def __mod__(self, other: Any) -> "ColumnElement":
+ def __mod__(self, other: Any) -> ColumnElement[Any]:
...
@overload
- def __rmod__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]:
...
@overload
- def __rmod__(self, other: Any) -> "ColumnElement":
+ def __rmod__(self, other: Any) -> ColumnElement[Any]:
...
- def __rmod__(self, other: Any) -> "ColumnElement":
+ def __rmod__(self, other: Any) -> ColumnElement[Any]:
...
@overload
def __truediv__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ self: _SQO[_NMT], other: Any
+ ) -> ColumnElement[_NUMERIC]:
...
@overload
- def __truediv__(self, other: Any) -> "ColumnElement":
+ def __truediv__(self, other: Any) -> ColumnElement[Any]:
...
- def __truediv__(self, other: Any) -> "ColumnElement":
+ def __truediv__(self, other: Any) -> ColumnElement[Any]:
...
@overload
def __rtruediv__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ self: _SQO[_NMT], other: Any
+ ) -> ColumnElement[_NUMERIC]:
...
@overload
- def __rtruediv__(self, other: Any) -> "ColumnElement":
+ def __rtruediv__(self, other: Any) -> ColumnElement[Any]:
...
- def __rtruediv__(self, other: Any) -> "ColumnElement":
+ def __rtruediv__(self, other: Any) -> ColumnElement[Any]:
...
@overload
- def __floordiv__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ def __floordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]:
...
@overload
- def __floordiv__(self, other: Any) -> "ColumnElement":
+ def __floordiv__(self, other: Any) -> ColumnElement[Any]:
...
- def __floordiv__(self, other: Any) -> "ColumnElement":
+ def __floordiv__(self, other: Any) -> ColumnElement[Any]:
...
@overload
- def __rfloordiv__(
- self: "SQLCoreOperations[_NT]", other: Any
- ) -> "ColumnElement[_NUMERIC]":
+ def __rfloordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]:
...
@overload
- def __rfloordiv__(self, other: Any) -> "ColumnElement":
+ def __rfloordiv__(self, other: Any) -> ColumnElement[Any]:
...
- def __rfloordiv__(self, other: Any) -> "ColumnElement":
+ def __rfloordiv__(self, other: Any) -> ColumnElement[Any]:
...
_SQO = SQLCoreOperations
+SelfColumnElement = TypeVar("SelfColumnElement", bound="ColumnElement[Any]")
+
class ColumnElement(
roles.ColumnArgumentOrKeyRole,
roles.StatementOptionRole,
roles.WhereHavingRole,
- roles.BinaryElementRole,
+ roles.BinaryElementRole[_T],
roles.OrderByRole,
roles.ColumnsClauseRole,
roles.LimitOffsetRole,
@@ -987,7 +1059,6 @@ class ColumnElement(
roles.DDLConstraintColumnRole,
roles.DDLExpressionRole,
SQLCoreOperations[_T],
- operators.ColumnOperators[SQLCoreOperations],
DQLDMLClauseElement,
):
"""Represent a column-oriented SQL expression suitable for usage in the
@@ -1069,28 +1140,37 @@ class ColumnElement(
__visit_name__ = "column_element"
- primary_key = False
- foreign_keys = []
- _proxies = ()
+ primary_key: bool = False
+ _is_clone_of: Optional[ColumnElement[_T]]
- _tq_label = None
- """The named label that can be used to target
- this column in a result set in a "table qualified" context.
+ @util.memoized_property
+ def foreign_keys(self) -> Iterable[ForeignKey]:
+ return []
- This label is almost always the label used when
- rendering <expr> AS <label> in a SELECT statement when using
- the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the legacy
- ORM ``Query`` object uses as well.
+ @util.memoized_property
+ def _proxies(self) -> List[ColumnElement[Any]]:
+ return []
- For a regular Column bound to a Table, this is typically the label
- <tablename>_<columnname>. For other constructs, different rules
- may apply, such as anonymized labels and others.
+ @util.non_memoized_property
+ def _tq_label(self) -> Optional[str]:
+ """The named label that can be used to target
+ this column in a result set in a "table qualified" context.
- .. versionchanged:: 1.4.21 renamed from ``._label``
+ This label is almost always the label used when
+ rendering <expr> AS <label> in a SELECT statement when using
+ the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the
+ legacy ORM ``Query`` object uses as well.
- """
+ For a regular Column bound to a Table, this is typically the label
+ <tablename>_<columnname>. For other constructs, different rules
+ may apply, such as anonymized labels and others.
+
+ .. versionchanged:: 1.4.21 renamed from ``._label``
+
+ """
+ return None
- key = None
+ key: Optional[str] = None
"""The 'key' that in some circumstances refers to this object in a
Python namespace.
@@ -1101,7 +1181,7 @@ class ColumnElement(
"""
@HasMemoized.memoized_attribute
- def _tq_key_label(self):
+ def _tq_key_label(self) -> Optional[str]:
"""A label-based version of 'key' that in some circumstances refers
to this object in a Python namespace.
@@ -1119,17 +1199,17 @@ class ColumnElement(
return self._proxy_key
@property
- def _key_label(self):
+ def _key_label(self) -> Optional[str]:
"""legacy; renamed to _tq_key_label"""
return self._tq_key_label
@property
- def _label(self):
+ def _label(self) -> Optional[str]:
"""legacy; renamed to _tq_label"""
return self._tq_label
@property
- def _non_anon_label(self):
+ def _non_anon_label(self) -> Optional[str]:
"""the 'name' that naturally applies this element when rendered in
SQL.
@@ -1184,9 +1264,23 @@ class ColumnElement(
_is_implicitly_boolean = False
- _alt_names = ()
+ _alt_names: Sequence[str] = ()
- def self_group(self, against=None):
+ @overload
+ def self_group(
+ self: ColumnElement[bool], against: Optional[OperatorType] = None
+ ) -> ColumnElement[bool]:
+ ...
+
+ @overload
+ def self_group(
+ self: ColumnElement[_T], against: Optional[OperatorType] = None
+ ) -> ColumnElement[_T]:
+ ...
+
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> ColumnElement[Any]:
if (
against in (operators.and_, operators.or_, operators._asbool)
and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity
@@ -1197,18 +1291,32 @@ class ColumnElement(
else:
return self
- def _negate(self):
+ @overload
+ def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]:
+ ...
+
+ @overload
+ def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]:
+ ...
+
+ def _negate(self) -> ColumnElement[Any]:
if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return AsBoolean(self, operators.is_false, operators.is_true)
else:
- return super(ColumnElement, self)._negate()
+ return cast("UnaryExpression[_T]", super()._negate())
- @util.memoized_property
- def type(self) -> "TypeEngine[_T]":
- return type_api.NULLTYPE
+ type: TypeEngine[_T]
+
+ if not TYPE_CHECKING:
+
+ @util.memoized_property
+ def type(self) -> TypeEngine[_T]: # noqa: A001
+ # used for delayed setup of
+ # type_api
+ return type_api.NULLTYPE
@HasMemoized.memoized_attribute
- def comparator(self) -> "TypeEngine.Comparator[_T]":
+ def comparator(self) -> TypeEngine.Comparator[_T]:
try:
comparator_factory = self.type.comparator_factory
except AttributeError as err:
@@ -1219,7 +1327,7 @@ class ColumnElement(
else:
return comparator_factory(self)
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
try:
return getattr(self.comparator, key)
except AttributeError as err:
@@ -1236,16 +1344,22 @@ class ColumnElement(
self,
op: operators.OperatorType,
*other: Any,
- **kwargs,
- ) -> "ColumnElement":
- return op(self.comparator, *other, **kwargs)
+ **kwargs: Any,
+ ) -> ColumnElement[Any]:
+ return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501
def reverse_operate(
- self, op: operators.OperatorType, other: Any, **kwargs
- ) -> "ColumnElement":
- return op(other, self.comparator, **kwargs)
+ self, op: operators.OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501
- def _bind_param(self, operator, obj, type_=None, expanding=False):
+ def _bind_param(
+ self,
+ operator: operators.OperatorType,
+ obj: Any,
+ type_: Optional[TypeEngine[_T]] = None,
+ expanding: bool = False,
+ ) -> BindParameter[_T]:
return BindParameter(
None,
obj,
@@ -1257,7 +1371,7 @@ class ColumnElement(
)
@property
- def expression(self):
+ def expression(self) -> ColumnElement[Any]:
"""Return a column expression.
Part of the inspection interface; returns self.
@@ -1266,39 +1380,39 @@ class ColumnElement(
return self
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> Iterable[ColumnElement[Any]]:
return (self,)
@util.memoized_property
- def base_columns(self):
- return util.column_set(c for c in self.proxy_set if not c._proxies)
+ def base_columns(self) -> FrozenSet[ColumnElement[Any]]:
+ return frozenset(c for c in self.proxy_set if not c._proxies)
@util.memoized_property
- def proxy_set(self):
- s = util.column_set([self])
- for c in self._proxies:
- s.update(c.proxy_set)
- return s
+ def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
+ return frozenset([self]).union(
+ itertools.chain.from_iterable(c.proxy_set for c in self._proxies)
+ )
- def _uncached_proxy_set(self):
+ def _uncached_proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
"""An 'uncached' version of proxy set.
This is so that we can read annotations from the list of columns
without breaking the caching of the above proxy_set.
"""
- s = util.column_set([self])
- for c in self._proxies:
- s.update(c._uncached_proxy_set())
- return s
+ return frozenset([self]).union(
+ itertools.chain.from_iterable(
+ c._uncached_proxy_set() for c in self._proxies
+ )
+ )
- def shares_lineage(self, othercolumn):
+ def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool:
"""Return True if the given :class:`_expression.ColumnElement`
has a common ancestor to this :class:`_expression.ColumnElement`."""
return bool(self.proxy_set.intersection(othercolumn.proxy_set))
- def _compare_name_for_result(self, other):
+ def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool:
"""Return True if the given column element compares to this one
when targeting within a result row."""
@@ -1309,9 +1423,9 @@ class ColumnElement(
)
@HasMemoized.memoized_attribute
- def _proxy_key(self):
+ def _proxy_key(self) -> Optional[str]:
if self._annotations and "proxy_key" in self._annotations:
- return self._annotations["proxy_key"]
+ return cast(str, self._annotations["proxy_key"])
name = self.key
if not name:
@@ -1327,7 +1441,7 @@ class ColumnElement(
return name
@HasMemoized.memoized_attribute
- def _expression_label(self):
+ def _expression_label(self) -> Optional[str]:
"""a suggested label to use in the case that the column has no name,
which should be used if possible as the explicit 'AS <label>'
where this expression would normally have an anon label.
@@ -1340,18 +1454,18 @@ class ColumnElement(
if getattr(self, "name", None) is not None:
return None
elif self._annotations and "proxy_key" in self._annotations:
- return self._annotations["proxy_key"]
+ return cast(str, self._annotations["proxy_key"])
else:
return None
def _make_proxy(
self,
- selectable,
+ selectable: FromClause,
name: Optional[str] = None,
- key=None,
- name_is_truncatable=False,
- **kw,
- ):
+ key: Optional[str] = None,
+ name_is_truncatable: bool = False,
+ **kw: Any,
+ ) -> typing_Tuple[str, ColumnClause[_T]]:
"""Create a new :class:`_expression.ColumnElement` representing this
:class:`_expression.ColumnElement` as it appears in the select list of
a descending selectable.
@@ -1364,7 +1478,7 @@ class ColumnElement(
else:
key = name
- co = ColumnClause(
+ co: ColumnClause[_T] = ColumnClause(
coercions.expect(roles.TruncatedLabelRole, name)
if name_is_truncatable
else name,
@@ -1376,9 +1490,10 @@ class ColumnElement(
co._proxies = [self]
if selectable._is_clone_of is not None:
co._is_clone_of = selectable._is_clone_of.columns.get(key)
+ assert key is not None
return key, co
- def cast(self, type_):
+ def cast(self, type_: TypeEngine[_T]) -> Cast[_T]:
"""Produce a type cast, i.e. ``CAST(<expression> AS <type>)``.
This is a shortcut to the :func:`_expression.cast` function.
@@ -1406,7 +1521,9 @@ class ColumnElement(
"""
return Label(name, self, self.type)
- def _anon_label(self, seed, add_hash=None) -> "_anonymous_label":
+ def _anon_label(
+ self, seed: Optional[str], add_hash: Optional[int] = None
+ ) -> _anonymous_label:
while self._is_clone_of is not None:
self = self._is_clone_of
@@ -1441,7 +1558,7 @@ class ColumnElement(
return _anonymous_label.safe_construct(hash_value, seed or "anon")
@util.memoized_property
- def _anon_name_label(self) -> "_anonymous_label":
+ def _anon_name_label(self) -> str:
"""Provides a constant 'anonymous label' for this ColumnElement.
This is a label() expression which will be named at compile time.
@@ -1462,7 +1579,7 @@ class ColumnElement(
return self._anon_label(name)
@util.memoized_property
- def _anon_key_label(self):
+ def _anon_key_label(self) -> _anonymous_label:
"""Provides a constant 'anonymous key label' for this ColumnElement.
Compare to ``anon_label``, except that the "key" of the column,
@@ -1478,25 +1595,23 @@ class ColumnElement(
"""
return self._anon_label(self._proxy_key)
- @property
- @util.deprecated(
+ @util.deprecated_property(
"1.4",
"The :attr:`_expression.ColumnElement.anon_label` attribute is now "
"private, and the public accessor is deprecated.",
)
- def anon_label(self):
+ def anon_label(self) -> str:
return self._anon_name_label
- @property
- @util.deprecated(
+ @util.deprecated_property(
"1.4",
"The :attr:`_expression.ColumnElement.anon_key_label` attribute is "
"now private, and the public accessor is deprecated.",
)
- def anon_key_label(self):
+ def anon_key_label(self) -> str:
return self._anon_key_label
- def _dedupe_anon_label_idx(self, idx):
+ def _dedupe_anon_label_idx(self, idx: int) -> str:
"""label to apply to a column that is anon labeled, but repeated
in the SELECT, so that we have to make an "extra anon" label that
disambiguates it from the previous appearance.
@@ -1520,20 +1635,20 @@ class ColumnElement(
return self._anon_label(label, add_hash=idx)
@util.memoized_property
- def _anon_tq_label(self):
+ def _anon_tq_label(self) -> _anonymous_label:
return self._anon_label(getattr(self, "_tq_label", None))
@util.memoized_property
- def _anon_tq_key_label(self):
+ def _anon_tq_key_label(self) -> _anonymous_label:
return self._anon_label(getattr(self, "_tq_key_label", None))
- def _dedupe_anon_tq_label_idx(self, idx):
+ def _dedupe_anon_tq_label_idx(self, idx: int) -> _anonymous_label:
label = getattr(self, "_tq_label", None) or "anon"
return self._anon_label(label, add_hash=idx)
-class WrapsColumnExpression:
+class WrapsColumnExpression(ColumnElement[_T]):
"""Mixin that defines a :class:`_expression.ColumnElement`
as a wrapper with special
labeling behavior for an expression that already has a name.
@@ -1548,25 +1663,27 @@ class WrapsColumnExpression:
"""
@property
- def wrapped_column_expression(self):
+ def wrapped_column_expression(self) -> ColumnElement[_T]:
raise NotImplementedError()
- @property
- def _tq_label(self):
+ @util.non_memoized_property
+ def _tq_label(self) -> Optional[str]:
wce = self.wrapped_column_expression
if hasattr(wce, "_tq_label"):
return wce._tq_label
else:
return None
- _label = _tq_label
+ @property
+ def _label(self) -> Optional[str]:
+ return self._tq_label
@property
- def _non_anon_label(self):
+ def _non_anon_label(self) -> Optional[str]:
return None
- @property
- def _anon_name_label(self):
+ @util.non_memoized_property
+ def _anon_name_label(self) -> str:
wce = self.wrapped_column_expression
# this logic tries to get the WrappedColumnExpression to render
@@ -1578,9 +1695,9 @@ class WrapsColumnExpression:
return nal
elif hasattr(wce, "_anon_name_label"):
return wce._anon_name_label
- return super(WrapsColumnExpression, self)._anon_name_label
+ return super()._anon_name_label
- def _dedupe_anon_label_idx(self, idx):
+ def _dedupe_anon_label_idx(self, idx: int) -> str:
wce = self.wrapped_column_expression
nal = wce._non_anon_label
if nal:
@@ -1589,7 +1706,7 @@ class WrapsColumnExpression:
return self._dedupe_anon_tq_label_idx(idx)
-SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter")
+SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]")
class BindParameter(roles.InElementRole, ColumnElement[_T]):
@@ -1614,7 +1731,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
__visit_name__ = "bindparam"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("key", InternalTraversal.dp_anon_name),
("type", InternalTraversal.dp_type),
("callable", InternalTraversal.dp_plain_dict),
@@ -1622,7 +1739,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
]
key: str
- type: TypeEngine
+ type: TypeEngine[_T]
_is_crud = False
_is_bind_parameter = True
@@ -1634,23 +1751,23 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
def __init__(
self,
- key,
- value=NO_ARG,
- type_=None,
- unique=False,
- required=NO_ARG,
- quote=None,
- callable_=None,
- expanding=False,
- isoutparam=False,
- literal_execute=False,
- _compared_to_operator=None,
- _compared_to_type=None,
- _is_crud=False,
+ key: Optional[str],
+ value: Any = _NoArg.NO_ARG,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ unique: bool = False,
+ required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG,
+ quote: Optional[bool] = None,
+ callable_: Optional[Callable[[], Any]] = None,
+ expanding: bool = False,
+ isoutparam: bool = False,
+ literal_execute: bool = False,
+ _compared_to_operator: Optional[OperatorType] = None,
+ _compared_to_type: Optional[TypeEngine[Any]] = None,
+ _is_crud: bool = False,
):
- if required is NO_ARG:
- required = value is NO_ARG and callable_ is None
- if value is NO_ARG:
+ if required is _NoArg.NO_ARG:
+ required = value is _NoArg.NO_ARG and callable_ is None
+ if value is _NoArg.NO_ARG:
value = None
if quote is not None:
@@ -1713,12 +1830,19 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
self.type = type_api._resolve_value_to_type(check_value)
elif isinstance(type_, type):
self.type = type_()
- elif type_._is_tuple_type and value:
- if expanding:
- check_value = value[0]
+ elif is_tuple_type(type_):
+ if value:
+ if expanding:
+ check_value = value[0]
+ else:
+ check_value = value
+ cast(
+ "BindParameter[typing_Tuple[Any, ...]]", self
+ ).type = type_._resolve_values_to_types(check_value)
else:
- check_value = value
- self.type = type_._resolve_values_to_types(check_value)
+ cast(
+ "BindParameter[typing_Tuple[Any, ...]]", self
+ ).type = type_
else:
self.type = type_
@@ -1791,7 +1915,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
return c
def _clone(
- self: SelfBindParameter, maintain_key=False, **kw
+ self: SelfBindParameter, maintain_key: bool = False, **kw: Any
) -> SelfBindParameter:
c = ClauseElement._clone(self, **kw)
if not maintain_key and self.unique:
@@ -1865,7 +1989,9 @@ class TypeClause(DQLDMLClauseElement):
__visit_name__ = "typeclause"
- _traverse_internals = [("type", InternalTraversal.dp_type)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("type", InternalTraversal.dp_type)
+ ]
def __init__(self, type_):
self.type = type_
@@ -1882,7 +2008,7 @@ class TextClause(
roles.OrderByRole,
roles.FromClauseRole,
roles.SelectStatementRole,
- roles.BinaryElementRole,
+ roles.BinaryElementRole[Any],
roles.InElementRole,
Executable,
DQLDMLClauseElement,
@@ -1909,7 +2035,7 @@ class TextClause(
__visit_name__ = "textclause"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
("text", InternalTraversal.dp_string),
]
@@ -1923,7 +2049,9 @@ class TextClause(
_render_label_in_columns_clause = False
- _hide_froms = ()
+ @property
+ def _hide_froms(self) -> Iterable[FromClause]:
+ return ()
def __and__(self, other):
# support use in select.where(), query.filter()
@@ -1935,12 +2063,13 @@ class TextClause(
# help in those cases where text() is
# interpreted in a column expression situation
- key = _label = None
+ key: Optional[str] = None
+ _label: Optional[str] = None
_allow_label_resolve = False
- def __init__(self, text):
- self._bindparams = {}
+ def __init__(self, text: str):
+ self._bindparams: Dict[str, BindParameter[Any]] = {}
def repl(m):
self._bindparams[m.group(1)] = BindParameter(m.group(1))
@@ -1952,7 +2081,9 @@ class TextClause(
@_generative
def bindparams(
- self: SelfTextClause, *binds, **names_to_values
+ self: SelfTextClause,
+ *binds: BindParameter[Any],
+ **names_to_values: Any,
) -> SelfTextClause:
"""Establish the values and/or types of bound parameters within
this :class:`_expression.TextClause` construct.
@@ -2205,7 +2336,7 @@ class TextClause(
else col
for col in cols
]
- keyed_input_cols = [
+ keyed_input_cols: List[ColumnClause[Any]] = [
ColumnClause(key, type_) for key, type_ in types.items()
]
@@ -2230,7 +2361,7 @@ class TextClause(
return self
-class Null(SingletonConstant, roles.ConstExprRole, ColumnElement):
+class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]):
"""Represent the NULL keyword in a SQL statement.
:class:`.Null` is accessed as a constant via the
@@ -2240,23 +2371,26 @@ class Null(SingletonConstant, roles.ConstExprRole, ColumnElement):
__visit_name__ = "null"
- _traverse_internals = []
+ _traverse_internals: _TraverseInternalsType = []
+ _singleton: Null
@util.memoized_property
def type(self):
return type_api.NULLTYPE
@classmethod
- def _instance(cls):
+ def _instance(cls) -> Null:
"""Return a constant :class:`.Null` construct."""
- return Null()
+ return Null._singleton
Null._create_singleton()
-class False_(SingletonConstant, roles.ConstExprRole, ColumnElement):
+class False_(
+ SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]
+):
"""Represent the ``false`` keyword, or equivalent, in a SQL statement.
:class:`.False_` is accessed as a constant via the
@@ -2265,24 +2399,25 @@ class False_(SingletonConstant, roles.ConstExprRole, ColumnElement):
"""
__visit_name__ = "false"
- _traverse_internals = []
+ _traverse_internals: _TraverseInternalsType = []
+ _singleton: False_
@util.memoized_property
def type(self):
return type_api.BOOLEANTYPE
- def _negate(self):
- return True_()
+ def _negate(self) -> True_:
+ return True_._singleton
@classmethod
- def _instance(cls):
- return False_()
+ def _instance(cls) -> False_:
+ return False_._singleton
False_._create_singleton()
-class True_(SingletonConstant, roles.ConstExprRole, ColumnElement):
+class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]):
"""Represent the ``true`` keyword, or equivalent, in a SQL statement.
:class:`.True_` is accessed as a constant via the
@@ -2292,14 +2427,15 @@ class True_(SingletonConstant, roles.ConstExprRole, ColumnElement):
__visit_name__ = "true"
- _traverse_internals = []
+ _traverse_internals: _TraverseInternalsType = []
+ _singleton: True_
@util.memoized_property
def type(self):
return type_api.BOOLEANTYPE
- def _negate(self):
- return False_()
+ def _negate(self) -> False_:
+ return False_._singleton
@classmethod
def _ifnone(cls, other):
@@ -2309,8 +2445,8 @@ class True_(SingletonConstant, roles.ConstExprRole, ColumnElement):
return other
@classmethod
- def _instance(cls):
- return True_()
+ def _instance(cls) -> True_:
+ return True_._singleton
True_._create_singleton()
@@ -2333,18 +2469,18 @@ class ClauseList(
_is_clause_list = True
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("clauses", InternalTraversal.dp_clauseelement_list),
("operator", InternalTraversal.dp_operator),
]
def __init__(
self,
- *clauses,
- operator=operators.comma_op,
- group=True,
- group_contents=True,
- _flatten_sub_clauses=False,
+ *clauses: _ColumnExpression[Any],
+ operator: OperatorType = operators.comma_op,
+ group: bool = True,
+ group_contents: bool = True,
+ _flatten_sub_clauses: bool = False,
_literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole,
):
self.operator = operator
@@ -2405,8 +2541,8 @@ class ClauseList(
coercions.expect(self._text_converter_role, clause)
)
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return list(itertools.chain(*[c._from_objects for c in self.clauses]))
def self_group(self, against=None):
@@ -2465,7 +2601,14 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
return lcc, [c.self_group(against=against) for c in convert_clauses]
@classmethod
- def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
+ def _construct(
+ cls,
+ operator: OperatorType,
+ continue_on: Any,
+ skip_on: Any,
+ *clauses: _ColumnExpression[Any],
+ **kw: Any,
+ ) -> BooleanClauseList:
lcc, convert_clauses = cls._process_clauses_for_boolean(
operator,
continue_on,
@@ -2479,11 +2622,11 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
if lcc > 1:
# multiple elements. Return regular BooleanClauseList
# which will link elements against the operator.
- return cls._construct_raw(operator, convert_clauses)
+ return cls._construct_raw(operator, convert_clauses) # type: ignore[no-any-return] # noqa E501
elif lcc == 1:
# just one element. return it as a single boolean element,
# not a list and discard the operator.
- return convert_clauses[0]
+ return convert_clauses[0] # type: ignore[no-any-return] # noqa E501
else:
# no elements period. deprecated use case. return an empty
# ClauseList construct that generates nothing unless it has
@@ -2500,7 +2643,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
},
version="1.4",
)
- return cls._construct_raw(operator)
+ return cls._construct_raw(operator) # type: ignore[no-any-return] # noqa E501
@classmethod
def _construct_for_whereclause(cls, clauses):
@@ -2540,7 +2683,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
return self
@classmethod
- def and_(cls, *clauses):
+ def and_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList:
r"""Produce a conjunction of expressions joined by ``AND``.
See :func:`_sql.and_` for full documentation.
@@ -2550,7 +2693,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
)
@classmethod
- def or_(cls, *clauses):
+ def or_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList:
"""Produce a conjunction of expressions joined by ``OR``.
See :func:`_sql.or_` for full documentation.
@@ -2577,19 +2720,27 @@ and_ = BooleanClauseList.and_
or_ = BooleanClauseList.or_
-class Tuple(ClauseList, ColumnElement):
+class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
"""Represent a SQL tuple."""
__visit_name__ = "tuple"
- _traverse_internals = ClauseList._traverse_internals + []
+ _traverse_internals: _TraverseInternalsType = (
+ ClauseList._traverse_internals + []
+ )
+
+ type: TupleType
@util.preload_module("sqlalchemy.sql.sqltypes")
- def __init__(self, *clauses, types=None):
+ def __init__(
+ self,
+ *clauses: _ColumnExpression[Any],
+ types: Optional[Sequence[_TypeEngineArgument[Any]]] = None,
+ ):
sqltypes = util.preloaded.sql_sqltypes
if types is None:
- clauses = [
+ init_clauses = [
coercions.expect(roles.ExpressionElementRole, c)
for c in clauses
]
@@ -2599,7 +2750,7 @@ class Tuple(ClauseList, ColumnElement):
"Wrong number of elements for %d-tuple: %r "
% (len(types), clauses)
)
- clauses = [
+ init_clauses = [
coercions.expect(
roles.ExpressionElementRole,
c,
@@ -2608,8 +2759,8 @@ class Tuple(ClauseList, ColumnElement):
for typ, c in zip(types, clauses)
]
- self.type = sqltypes.TupleType(*[arg.type for arg in clauses])
- super(Tuple, self).__init__(*clauses)
+ self.type = sqltypes.TupleType(*[arg.type for arg in init_clauses])
+ super(Tuple, self).__init__(*init_clauses)
@property
def _select_iterable(self):
@@ -2672,7 +2823,7 @@ class Case(ColumnElement[_T]):
__visit_name__ = "case"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("value", InternalTraversal.dp_clauseelement),
("whens", InternalTraversal.dp_clauseelement_tuples),
("else_", InternalTraversal.dp_clauseelement),
@@ -2681,13 +2832,24 @@ class Case(ColumnElement[_T]):
# for case(), the type is derived from the whens. so for the moment
# users would have to cast() the case to get a specific type
- def __init__(self, *whens, value=None, else_=None):
+ whens: List[typing_Tuple[ColumnElement[bool], ColumnElement[_T]]]
+ else_: Optional[ColumnElement[_T]]
+ value: Optional[ColumnElement[Any]]
- whens = coercions._expression_collection_was_a_list(
+ def __init__(
+ self,
+ *whens: Union[
+ typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any]
+ ],
+ value: Optional[Any] = None,
+ else_: Optional[Any] = None,
+ ):
+
+ new_whens: Iterable[Any] = coercions._expression_collection_was_a_list(
"whens", "case", whens
)
try:
- whens = util.dictlike_iteritems(whens)
+ new_whens = util.dictlike_iteritems(new_whens)
except TypeError:
pass
@@ -2700,7 +2862,7 @@ class Case(ColumnElement[_T]):
).self_group(),
coercions.expect(roles.ExpressionElementRole, r),
)
- for (c, r) in whens
+ for (c, r) in new_whens
]
if whenlist:
@@ -2713,7 +2875,7 @@ class Case(ColumnElement[_T]):
else:
self.value = coercions.expect(roles.ExpressionElementRole, value)
- self.type = type_
+ self.type = cast(_T, type_)
self.whens = whenlist
if else_ is not None:
@@ -2721,14 +2883,14 @@ class Case(ColumnElement[_T]):
else:
self.else_ = None
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(*[x._from_objects for x in self.get_children()])
)
-class Cast(WrapsColumnExpression, ColumnElement[_T]):
+class Cast(WrapsColumnExpression[_T]):
"""Represent a ``CAST`` expression.
:class:`.Cast` is produced using the :func:`.cast` factory function,
@@ -2754,12 +2916,20 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]):
__visit_name__ = "cast"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("clause", InternalTraversal.dp_clauseelement),
("typeclause", InternalTraversal.dp_clauseelement),
]
- def __init__(self, expression, type_):
+ clause: ColumnElement[Any]
+ type: TypeEngine[_T]
+ typeclause: TypeClause
+
+ def __init__(
+ self,
+ expression: _ColumnExpression[Any],
+ type_: _TypeEngineArgument[_T],
+ ):
self.type = type_api.to_instance(type_)
self.clause = coercions.expect(
roles.ExpressionElementRole,
@@ -2769,8 +2939,8 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]):
)
self.typeclause = TypeClause(self.type)
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.clause._from_objects
@property
@@ -2778,7 +2948,7 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]):
return self.clause
-class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]):
+class TypeCoerce(WrapsColumnExpression[_T]):
"""Represent a Python-side type-coercion wrapper.
:class:`.TypeCoerce` supplies the :func:`_expression.type_coerce`
@@ -2798,12 +2968,19 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]):
__visit_name__ = "type_coerce"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("clause", InternalTraversal.dp_clauseelement),
("type", InternalTraversal.dp_type),
]
- def __init__(self, expression, type_):
+ clause: ColumnElement[Any]
+ type: TypeEngine[_T]
+
+ def __init__(
+ self,
+ expression: _ColumnExpression[Any],
+ type_: _TypeEngineArgument[_T],
+ ):
self.type = type_api.to_instance(type_)
self.clause = coercions.expect(
roles.ExpressionElementRole,
@@ -2812,8 +2989,8 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]):
apply_propagate_attrs=self,
)
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.clause._from_objects
@HasMemoized.memoized_attribute
@@ -2837,27 +3014,30 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]):
return self
-class Extract(ColumnElement[_T]):
+class Extract(ColumnElement[int]):
"""Represent a SQL EXTRACT clause, ``extract(field FROM expr)``."""
__visit_name__ = "extract"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("expr", InternalTraversal.dp_clauseelement),
("field", InternalTraversal.dp_string),
]
- def __init__(self, field, expr):
+ expr: ColumnElement[Any]
+ field: str
+
+ def __init__(self, field: str, expr: _ColumnExpression[Any]):
self.type = type_api.INTEGERTYPE
self.field = field
self.expr = coercions.expect(roles.ExpressionElementRole, expr)
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.expr._from_objects
-class _label_reference(ColumnElement):
+class _label_reference(ColumnElement[_T]):
"""Wrap a column expression as it appears in a 'reference' context.
This expression is any that includes an _order_by_label_element,
@@ -2872,26 +3052,30 @@ class _label_reference(ColumnElement):
__visit_name__ = "label_reference"
- _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_clauseelement)
+ ]
- def __init__(self, element):
+ def __init__(self, element: ColumnElement[_T]):
self.element = element
- @property
- def _from_objects(self):
- return ()
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ return []
-class _textual_label_reference(ColumnElement):
+class _textual_label_reference(ColumnElement[Any]):
__visit_name__ = "textual_label_reference"
- _traverse_internals = [("element", InternalTraversal.dp_string)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_string)
+ ]
- def __init__(self, element):
+ def __init__(self, element: str):
self.element = element
@util.memoized_property
- def _text_clause(self):
+ def _text_clause(self) -> TextClause:
return TextClause(self.element)
@@ -2911,7 +3095,7 @@ class UnaryExpression(ColumnElement[_T]):
__visit_name__ = "unary"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("operator", InternalTraversal.dp_operator),
("modifier", InternalTraversal.dp_operator),
@@ -2919,11 +3103,11 @@ class UnaryExpression(ColumnElement[_T]):
def __init__(
self,
- element,
- operator=None,
- modifier=None,
- type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] = None,
- wraps_column_expression=False,
+ element: ColumnElement[Any],
+ operator: Optional[OperatorType] = None,
+ modifier: Optional[OperatorType] = None,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ wraps_column_expression: bool = False,
):
self.operator = operator
self.modifier = modifier
@@ -2935,7 +3119,10 @@ class UnaryExpression(ColumnElement[_T]):
self.wraps_column_expression = wraps_column_expression
@classmethod
- def _create_nulls_first(cls, column):
+ def _create_nulls_first(
+ cls,
+ column: _ColumnExpression[_T],
+ ) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
modifier=operators.nulls_first_op,
@@ -2943,7 +3130,10 @@ class UnaryExpression(ColumnElement[_T]):
)
@classmethod
- def _create_nulls_last(cls, column):
+ def _create_nulls_last(
+ cls,
+ column: _ColumnExpression[_T],
+ ) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
modifier=operators.nulls_last_op,
@@ -2951,7 +3141,9 @@ class UnaryExpression(ColumnElement[_T]):
)
@classmethod
- def _create_desc(cls, column):
+ def _create_desc(
+ cls, column: _ColumnExpression[_T]
+ ) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
modifier=operators.desc_op,
@@ -2959,7 +3151,10 @@ class UnaryExpression(ColumnElement[_T]):
)
@classmethod
- def _create_asc(cls, column):
+ def _create_asc(
+ cls,
+ column: _ColumnExpression[_T],
+ ) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
modifier=operators.asc_op,
@@ -2967,24 +3162,27 @@ class UnaryExpression(ColumnElement[_T]):
)
@classmethod
- def _create_distinct(cls, expr):
- expr = coercions.expect(roles.ExpressionElementRole, expr)
+ def _create_distinct(
+ cls,
+ expr: _ColumnExpression[_T],
+ ) -> UnaryExpression[_T]:
+ col_expr = coercions.expect(roles.ExpressionElementRole, expr)
return UnaryExpression(
- expr,
+ col_expr,
operator=operators.distinct_op,
- type_=expr.type,
+ type_=col_expr.type,
wraps_column_expression=False,
)
@property
- def _order_by_label_element(self):
+ def _order_by_label_element(self) -> Optional[Label[Any]]:
if self.modifier in (operators.desc_op, operators.asc_op):
return self.element._order_by_label_element
else:
return None
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
def _negate(self):
@@ -3005,7 +3203,7 @@ class UnaryExpression(ColumnElement[_T]):
return self
-class CollectionAggregate(UnaryExpression):
+class CollectionAggregate(UnaryExpression[_T]):
"""Forms the basis for right-hand collection operator modifiers
ANY and ALL.
@@ -3018,7 +3216,9 @@ class CollectionAggregate(UnaryExpression):
inherit_cache = True
@classmethod
- def _create_any(cls, expr):
+ def _create_any(
+ cls, expr: _ColumnExpression[_T]
+ ) -> CollectionAggregate[_T]:
expr = coercions.expect(roles.ExpressionElementRole, expr)
expr = expr.self_group()
@@ -3030,7 +3230,9 @@ class CollectionAggregate(UnaryExpression):
)
@classmethod
- def _create_all(cls, expr):
+ def _create_all(
+ cls, expr: _ColumnExpression[_T]
+ ) -> CollectionAggregate[_T]:
expr = coercions.expect(roles.ExpressionElementRole, expr)
expr = expr.self_group()
return CollectionAggregate(
@@ -3059,7 +3261,7 @@ class CollectionAggregate(UnaryExpression):
)
-class AsBoolean(WrapsColumnExpression, UnaryExpression):
+class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]):
inherit_cache = True
def __init__(self, element, operator, negate):
@@ -3101,7 +3303,7 @@ class BinaryExpression(ColumnElement[_T]):
__visit_name__ = "binary"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("left", InternalTraversal.dp_clauseelement),
("right", InternalTraversal.dp_clauseelement),
("operator", InternalTraversal.dp_operator),
@@ -3119,16 +3321,16 @@ class BinaryExpression(ColumnElement[_T]):
"""
+ modifiers: Optional[Mapping[str, Any]]
+
def __init__(
self,
- left: ColumnElement,
- right: Union[ColumnElement, ClauseList],
- operator,
- type_: Optional[
- Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
- ] = None,
- negate=None,
- modifiers=None,
+ left: ColumnElement[Any],
+ right: Union[ColumnElement[Any], ClauseList],
+ operator: OperatorType,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ negate: Optional[OperatorType] = None,
+ modifiers: Optional[Mapping[str, Any]] = None,
):
# allow compatibility with libraries that
# refer to BinaryExpression directly and pass strings
@@ -3149,8 +3351,40 @@ class BinaryExpression(ColumnElement[_T]):
self.modifiers = modifiers
def __bool__(self):
- if self.operator in (operator.eq, operator.ne):
- return self.operator(*self._orig)
+ """Implement Python-side "bool" for BinaryExpression as a
+ simple "identity" check for the left and right attributes,
+ if the operator is "eq" or "ne". Otherwise the expression
+ continues to not support "bool" like all other column expressions.
+
+ The rationale here is so that ColumnElement objects can be hashable.
+ What? Well, suppose you do this::
+
+ c1, c2 = column('x'), column('y')
+ s1 = set([c1, c2])
+
+ We do that **a lot**, columns inside of sets is an extremely basic
+ thing all over the ORM for example.
+
+ So what happens if we do this? ::
+
+ c1 in s1
+
+ Hashing means it will normally use ``__hash__()`` of the object,
+ but in case of hash collision, it's going to also do ``c1 == c1``
+ and/or ``c1 == c2`` inside. Those operations need to return a
+ True/False value. But because we override ``==`` and ``!=``, they're
+ going to get a BinaryExpression. Hence we implement ``__bool__`` here
+ so that these comparisons behave in this particular context mostly
+ like regular object comparisons. Thankfully Python is OK with
+ that! Otherwise we'd have to use special set classes for columns
+ (which we used to do, decades ago).
+
+ """
+ if self.operator in (operators.eq, operators.ne):
+ # this is using the eq/ne operator given int hash values,
+ # rather than Operator, so that "bool" can be based on
+ # identity
+ return self.operator(*self._orig) # type: ignore
else:
raise TypeError("Boolean value of this clause is not defined")
@@ -3167,8 +3401,8 @@ class BinaryExpression(ColumnElement[_T]):
def is_comparison(self):
return operators.is_comparison(self.operator)
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.left._from_objects + self.right._from_objects
def self_group(self, against=None):
@@ -3192,7 +3426,7 @@ class BinaryExpression(ColumnElement[_T]):
return super(BinaryExpression, self)._negate()
-class Slice(ColumnElement):
+class Slice(ColumnElement[Any]):
"""Represent SQL for a Python array-slice object.
This is not a specific SQL construct at this level, but
@@ -3202,7 +3436,7 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("start", InternalTraversal.dp_clauseelement),
("stop", InternalTraversal.dp_clauseelement),
("step", InternalTraversal.dp_clauseelement),
@@ -3234,7 +3468,7 @@ class Slice(ColumnElement):
return self
-class IndexExpression(BinaryExpression):
+class IndexExpression(BinaryExpression[Any]):
"""Represent the class of expressions that are like an "index"
operation."""
@@ -3246,6 +3480,8 @@ class GroupedElement(DQLDMLClauseElement):
__visit_name__ = "grouping"
+ element: ClauseElement
+
def self_group(self, against=None):
return self
@@ -3253,15 +3489,19 @@ class GroupedElement(DQLDMLClauseElement):
return self.element._ungroup()
-class Grouping(GroupedElement, ColumnElement):
+class Grouping(GroupedElement, ColumnElement[_T]):
"""Represent a grouping within a column expression"""
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("type", InternalTraversal.dp_type),
]
- def __init__(self, element):
+ element: Union[TextClause, ClauseList, ColumnElement[_T]]
+
+ def __init__(
+ self, element: Union[TextClause, ClauseList, ColumnElement[_T]]
+ ):
self.element = element
self.type = getattr(element, "type", type_api.NULLTYPE)
@@ -3272,21 +3512,21 @@ class Grouping(GroupedElement, ColumnElement):
def _is_implicitly_boolean(self):
return self.element._is_implicitly_boolean
- @property
- def _tq_label(self):
+ @util.non_memoized_property
+ def _tq_label(self) -> Optional[str]:
return (
getattr(self.element, "_tq_label", None) or self._anon_name_label
)
- @property
- def _proxies(self):
+ @util.non_memoized_property
+ def _proxies(self) -> List[ColumnElement[Any]]:
if isinstance(self.element, ColumnElement):
return [self.element]
else:
return []
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
def __getattr__(self, attr):
@@ -3300,8 +3540,13 @@ class Grouping(GroupedElement, ColumnElement):
self.type = state["type"]
-RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
-RANGE_CURRENT = util.symbol("RANGE_CURRENT")
+class _OverRange(IntEnum):
+ RANGE_UNBOUNDED = 0
+ RANGE_CURRENT = 1
+
+
+RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED
+RANGE_CURRENT = _OverRange.RANGE_CURRENT
class Over(ColumnElement[_T]):
@@ -3316,7 +3561,7 @@ class Over(ColumnElement[_T]):
__visit_name__ = "over"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("order_by", InternalTraversal.dp_clauseelement),
("partition_by", InternalTraversal.dp_clauseelement),
@@ -3324,15 +3569,26 @@ class Over(ColumnElement[_T]):
("rows", InternalTraversal.dp_plain_obj),
]
- order_by = None
- partition_by = None
+ order_by: Optional[ClauseList] = None
+ partition_by: Optional[ClauseList] = None
- element = None
+ element: ColumnElement[_T]
"""The underlying expression object to which this :class:`.Over`
object refers towards."""
+ range_: Optional[typing_Tuple[int, int]]
+
def __init__(
- self, element, partition_by=None, order_by=None, range_=None, rows=None
+ self,
+ element: ColumnElement[_T],
+ partition_by: Optional[
+ Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ ] = None,
+ order_by: Optional[
+ Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ ] = None,
+ range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+ rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
):
self.element = element
if order_by is not None:
@@ -3368,10 +3624,15 @@ class Over(ColumnElement[_T]):
self.rows,
)
- def _interpret_range(self, range_):
+ def _interpret_range(
+ self, range_: typing_Tuple[Optional[int], Optional[int]]
+ ) -> typing_Tuple[int, int]:
if not isinstance(range_, tuple) or len(range_) != 2:
raise exc.ArgumentError("2-tuple expected for range/rows")
+ lower: int
+ upper: int
+
if range_[0] is None:
lower = RANGE_UNBOUNDED
else:
@@ -3404,8 +3665,8 @@ class Over(ColumnElement[_T]):
def type(self):
return self.element.type
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(
*[
@@ -3436,14 +3697,16 @@ class WithinGroup(ColumnElement[_T]):
__visit_name__ = "withingroup"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("order_by", InternalTraversal.dp_clauseelement),
]
- order_by = None
+ order_by: Optional[ClauseList] = None
- def __init__(self, element, *order_by):
+ def __init__(
+ self, element: FunctionElement[_T], *order_by: _ColumnExpression[Any]
+ ):
self.element = element
if order_by is not None:
self.order_by = ClauseList(
@@ -3451,7 +3714,9 @@ class WithinGroup(ColumnElement[_T]):
)
def __reduce__(self):
- return self.__class__, (self.element,) + tuple(self.order_by)
+ return self.__class__, (self.element,) + (
+ tuple(self.order_by) if self.order_by is not None else ()
+ )
def over(self, partition_by=None, order_by=None, range_=None, rows=None):
"""Produce an OVER clause against this :class:`.WithinGroup`
@@ -3477,8 +3742,8 @@ class WithinGroup(ColumnElement[_T]):
else:
return self.element.type
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(
*[
@@ -3490,7 +3755,7 @@ class WithinGroup(ColumnElement[_T]):
)
-class FunctionFilter(ColumnElement):
+class FunctionFilter(ColumnElement[_T]):
"""Represent a function FILTER clause.
This is a special operator against aggregate and window functions,
@@ -3512,14 +3777,16 @@ class FunctionFilter(ColumnElement):
__visit_name__ = "funcfilter"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("func", InternalTraversal.dp_clauseelement),
("criterion", InternalTraversal.dp_clauseelement),
]
- criterion = None
+ criterion: Optional[ColumnElement[bool]] = None
- def __init__(self, func, *criterion):
+ def __init__(
+ self, func: FunctionElement[_T], *criterion: _ColumnExpression[bool]
+ ):
self.func = func
self.filter(*criterion)
@@ -3535,17 +3802,27 @@ class FunctionFilter(ColumnElement):
"""
- for criterion in list(criterion):
- criterion = coercions.expect(roles.WhereHavingRole, criterion)
+ for crit in list(criterion):
+ crit = coercions.expect(roles.WhereHavingRole, crit)
if self.criterion is not None:
- self.criterion = self.criterion & criterion
+ self.criterion = self.criterion & crit
else:
- self.criterion = criterion
+ self.criterion = crit
return self
- def over(self, partition_by=None, order_by=None, range_=None, rows=None):
+ def over(
+ self,
+ partition_by: Optional[
+ Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ ] = None,
+ order_by: Optional[
+ Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ ] = None,
+ range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+ rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+ ) -> Over[_T]:
"""Produce an OVER clause against this filtered function.
Used against aggregate or so-called "window" functions,
@@ -3581,8 +3858,8 @@ class FunctionFilter(ColumnElement):
def type(self):
return self.func.type
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(
*[
@@ -3594,7 +3871,7 @@ class FunctionFilter(ColumnElement):
)
-class Label(roles.LabeledColumnExprRole, ColumnElement[_T]):
+class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
"""Represents a column label (AS).
Represent a label, as typically applied to any column-level
@@ -3604,13 +3881,21 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]):
__visit_name__ = "label"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("name", InternalTraversal.dp_anon_name),
- ("_type", InternalTraversal.dp_type),
+ ("type", InternalTraversal.dp_type),
("_element", InternalTraversal.dp_clauseelement),
]
- def __init__(self, name, element, type_=None):
+ _element: ColumnElement[_T]
+ name: str
+
+ def __init__(
+ self,
+ name: Optional[str],
+ element: _ColumnExpression[_T],
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ ):
orig_element = element
element = coercions.expect(
roles.ExpressionElementRole,
@@ -3635,11 +3920,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]):
self.key = self._tq_label = self._tq_key_label = self.name
self._element = element
- self._type = type_
+ # self._type = type_
+ self.type = type_api.to_instance(
+ type_ or getattr(self._element, "type", None)
+ )
self._proxies = [element]
def __reduce__(self):
- return self.__class__, (self.name, self._element, self._type)
+ return self.__class__, (self.name, self._element, self.type)
@util.memoized_property
def _is_implicitly_boolean(self):
@@ -3653,14 +3941,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]):
def _order_by_label_element(self):
return self
- @util.memoized_property
- def type(self):
- return type_api.to_instance(
- self._type or getattr(self._element, "type", None)
- )
-
@HasMemoized.memoized_attribute
- def element(self):
+ def element(self) -> ColumnElement[_T]:
return self._element.self_group(against=operators.as_)
def self_group(self, against=None):
@@ -3672,7 +3954,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]):
def _apply_to_inner(self, fn, *arg, **kw):
sub_element = fn(*arg, **kw)
if sub_element is not self._element:
- return Label(self.name, sub_element, type_=self._type)
+ return Label(self.name, sub_element, type_=self.type)
else:
return self
@@ -3693,8 +3975,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]):
)
self.key = self._tq_label = self._tq_key_label = self.name
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
def _make_proxy(self, selectable, name=None, **kw):
@@ -3724,15 +4006,16 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]):
e._propagate_attrs = selectable._propagate_attrs
e._proxies.append(self)
- if self._type is not None:
- e.type = self._type
+ if self.type is not None:
+ e.type = self.type
return self.key, e
class NamedColumn(ColumnElement[_T]):
is_literal = False
- table = None
+ table: Optional[FromClause] = None
+ name: str
def _compare_name_for_result(self, other):
return (hasattr(other, "name") and self.name == other.name) or (
@@ -3740,7 +4023,7 @@ class NamedColumn(ColumnElement[_T]):
)
@util.memoized_property
- def description(self):
+ def description(self) -> str:
return self.name
@HasMemoized.memoized_attribute
@@ -3759,7 +4042,7 @@ class NamedColumn(ColumnElement[_T]):
return self._tq_label
@HasMemoized.memoized_attribute
- def _tq_label(self):
+ def _tq_label(self) -> Optional[str]:
"""table qualified label based on column name.
for table-bound columns this is <tablename>_<columnname>; all other
@@ -3776,7 +4059,9 @@ class NamedColumn(ColumnElement[_T]):
def _non_anon_label(self):
return self.name
- def _gen_tq_label(self, name, dedupe_on_key=True):
+ def _gen_tq_label(
+ self, name: str, dedupe_on_key: bool = True
+ ) -> Optional[str]:
return name
def _bind_param(self, operator, obj, type_=None, expanding=False):
@@ -3817,7 +4102,7 @@ class NamedColumn(ColumnElement[_T]):
class ColumnClause(
roles.DDLReferredColumnRole,
- roles.LabeledColumnExprRole,
+ roles.LabeledColumnExprRole[_T],
roles.StrAsPlainColumnRole,
Immutable,
NamedColumn[_T],
@@ -3859,30 +4144,31 @@ class ColumnClause(
"""
- table = None
- is_literal = False
+ table: Optional[FromClause]
+ is_literal: bool
__visit_name__ = "column"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("name", InternalTraversal.dp_anon_name),
("type", InternalTraversal.dp_type),
("table", InternalTraversal.dp_clauseelement),
("is_literal", InternalTraversal.dp_boolean),
]
- onupdate = default = server_default = server_onupdate = None
+ onupdate: Optional[DefaultGenerator] = None
+ default: Optional[DefaultGenerator] = None
+ server_default: Optional[DefaultGenerator] = None
+ server_onupdate: Optional[DefaultGenerator] = None
_is_multiparam_column = False
def __init__(
self,
text: str,
- type_: Optional[
- Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
- ] = None,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
is_literal: bool = False,
- _selectable: Optional["FromClause"] = None,
+ _selectable: Optional[FromClause] = None,
):
self.key = self.name = text
self.table = _selectable
@@ -3916,7 +4202,7 @@ class ColumnClause(
return super(ColumnClause, self)._clone(**kw)
@HasMemoized.memoized_attribute
- def _from_objects(self):
+ def _from_objects(self) -> List[FromClause]:
t = self.table
if t is not None:
return [t]
@@ -3953,7 +4239,9 @@ class ColumnClause(
else:
return other.proxy_set.intersection(self.proxy_set)
- def _gen_tq_label(self, name, dedupe_on_key=True):
+ def _gen_tq_label(
+ self, name: str, dedupe_on_key: bool = True
+ ) -> Optional[str]:
"""generate table-qualified label
for a table-bound column this is <tablename>_<columnname>.
@@ -3962,22 +4250,24 @@ class ColumnClause(
as well as the .columns collection on a Join object.
"""
+ label: str
t = self.table
if self.is_literal:
return None
- elif t is not None and t.named_with_column:
- if getattr(t, "schema", None):
+ elif t is not None and is_named_from_clause(t):
+ if has_schema_attr(t) and t.schema:
label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
else:
+ assert not TYPE_CHECKING or isinstance(t, NamedFromClause)
label = t.name + "_" + name
# propagate name quoting rules for labels.
- if getattr(name, "quote", None) is not None:
- if isinstance(label, quoted_name):
+ if is_quoted_name(name) and name.quote is not None:
+ if is_quoted_name(label):
label.quote = name.quote
else:
label = quoted_name(label, name.quote)
- elif getattr(t.name, "quote", None) is not None:
+ elif is_quoted_name(t.name) and t.name.quote is not None:
# can't get this situation to occur, so let's
# assert false on it for now
assert not isinstance(label, quoted_name)
@@ -4046,16 +4336,16 @@ class ColumnClause(
return c.key, c
-class TableValuedColumn(NamedColumn):
+class TableValuedColumn(NamedColumn[_T]):
__visit_name__ = "table_valued_column"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("name", InternalTraversal.dp_anon_name),
("type", InternalTraversal.dp_type),
("scalar_alias", InternalTraversal.dp_clauseelement),
]
- def __init__(self, scalar_alias, type_):
+ def __init__(self, scalar_alias: NamedFromClause, type_: TypeEngine[_T]):
self.scalar_alias = scalar_alias
self.key = self.name = scalar_alias.name
self.type = type_
@@ -4064,24 +4354,28 @@ class TableValuedColumn(NamedColumn):
self.scalar_alias = clone(self.scalar_alias, **kw)
self.key = self.name = self.scalar_alias.name
- @property
- def _from_objects(self):
+ @util.non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return [self.scalar_alias]
-class CollationClause(ColumnElement):
+class CollationClause(ColumnElement[str]):
__visit_name__ = "collation"
- _traverse_internals = [("collation", InternalTraversal.dp_string)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("collation", InternalTraversal.dp_string)
+ ]
@classmethod
- def _create_collation_expression(cls, expression, collation):
+ def _create_collation_expression(
+ cls, expression: _ColumnExpression[str], collation: str
+ ) -> BinaryExpression[str]:
expr = coercions.expect(roles.ExpressionElementRole, expression)
return BinaryExpression(
expr,
CollationClause(collation),
operators.collate,
- type_=expression.type,
+ type_=expr.type,
)
def __init__(self, collation):
@@ -4163,6 +4457,8 @@ class quoted_name(util.MemoizedSlots, str):
__slots__ = "quote", "lower", "upper"
+ quote: Optional[bool]
+
def __new__(cls, value, quote):
if value is None:
return None
@@ -4196,10 +4492,10 @@ class quoted_name(util.MemoizedSlots, str):
return str(self).upper()
-def _find_columns(clause):
+def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]:
"""locate Column objects within the given expression."""
- cols = util.column_set()
+ cols: Set[ColumnClause[Any]] = set()
traverse(clause, {}, {"column": cols.add})
return cols
@@ -4226,6 +4522,8 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False):
class AnnotatedColumnElement(Annotated):
+ _Annotated__element: ColumnElement[Any]
+
def __init__(self, element, values):
Annotated.__init__(self, element, values)
for attr in (
@@ -4265,7 +4563,7 @@ class AnnotatedColumnElement(Annotated):
return self._Annotated__element.info
@util.memoized_property
- def _anon_name_label(self):
+ def _anon_name_label(self) -> str:
return self._Annotated__element._anon_name_label
@@ -4353,8 +4651,12 @@ class _anonymous_label(_truncated_label):
@classmethod
def safe_construct(
- cls, seed, body, enclosing_label=None, sanitize_key=False
- ) -> "_anonymous_label":
+ cls,
+ seed: int,
+ body: str,
+ enclosing_label: Optional[str] = None,
+ sanitize_key: bool = False,
+ ) -> _anonymous_label:
if sanitize_key:
body = re.sub(r"[%\(\) \$]+", "_", body).strip("_")
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index f08e71bcd..7db1631c8 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -12,6 +12,7 @@
from __future__ import annotations
+from enum import IntEnum
from operator import add as _uncast_add
from operator import and_ as _uncast_and_
from operator import contains as _uncast_contains
@@ -36,15 +37,17 @@ import typing
from typing import Any
from typing import Callable
from typing import cast
+from typing import Dict
from typing import Generic
from typing import Optional
-from typing import overload
+from typing import Set
from typing import Type
from typing import TypeVar
from typing import Union
from .. import exc
from .. import util
+from ..util.typing import Literal
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
@@ -52,8 +55,8 @@ if typing.TYPE_CHECKING:
from .elements import ColumnElement
from .type_api import TypeEngine
-_OP_RETURN = TypeVar("_OP_RETURN", bound=Any, covariant=True)
_T = TypeVar("_T", bound=Any)
+_FN = TypeVar("_FN", bound=Callable[..., Any])
class OperatorType(Protocol):
@@ -64,8 +67,12 @@ class OperatorType(Protocol):
__name__: str
def __call__(
- self, left: "Operators[_OP_RETURN]", *other: Any, **kwargs: Any
- ) -> "_OP_RETURN":
+ self,
+ left: "Operators",
+ right: Optional[Any] = None,
+ *other: Any,
+ **kwargs: Any,
+ ) -> "Operators":
...
@@ -91,7 +98,7 @@ sub = cast(OperatorType, _uncast_sub)
truediv = cast(OperatorType, _uncast_truediv)
-class Operators(Generic[_OP_RETURN]):
+class Operators:
"""Base of comparison and logical operators.
Implements base methods
@@ -108,7 +115,7 @@ class Operators(Generic[_OP_RETURN]):
__slots__ = ()
- def __and__(self, other: Any) -> "Operators":
+ def __and__(self, other: Any) -> Operators:
"""Implement the ``&`` operator.
When used with SQL expressions, results in an
@@ -132,7 +139,7 @@ class Operators(Generic[_OP_RETURN]):
"""
return self.operate(and_, other)
- def __or__(self, other: Any) -> "Operators":
+ def __or__(self, other: Any) -> Operators:
"""Implement the ``|`` operator.
When used with SQL expressions, results in an
@@ -156,7 +163,7 @@ class Operators(Generic[_OP_RETURN]):
"""
return self.operate(or_, other)
- def __invert__(self) -> "Operators":
+ def __invert__(self) -> Operators:
"""Implement the ``~`` operator.
When used with SQL expressions, results in a
@@ -175,14 +182,14 @@ class Operators(Generic[_OP_RETURN]):
def op(
self,
- opstring: Any,
+ opstring: str,
precedence: int = 0,
is_comparison: bool = False,
return_type: Optional[
Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"]
] = None,
- python_impl=None,
- ) -> Callable[[Any], Any]:
+ python_impl: Optional[Callable[..., Any]] = None,
+ ) -> Callable[[Any], Operators]:
"""Produce a generic operator function.
e.g.::
@@ -200,7 +207,7 @@ class Operators(Generic[_OP_RETURN]):
is a bitwise AND of the value in ``somecolumn``.
- :param operator: a string which will be output as the infix operator
+ :param opstring: a string which will be output as the infix operator
between this element and the expression passed to the
generated function.
@@ -263,14 +270,17 @@ class Operators(Generic[_OP_RETURN]):
python_impl=python_impl,
)
- def against(other: Any) -> _OP_RETURN:
- return operator(self, other)
+ def against(other: Any) -> Operators:
+ return operator(self, other) # type: ignore
return against
def bool_op(
- self, opstring: Any, precedence: int = 0, python_impl=None
- ) -> Callable[[Any], Any]:
+ self,
+ opstring: str,
+ precedence: int = 0,
+ python_impl: Optional[Callable[..., Any]] = None,
+ ) -> Callable[[Any], Operators]:
"""Return a custom boolean operator.
This method is shorthand for calling
@@ -292,7 +302,7 @@ class Operators(Generic[_OP_RETURN]):
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
- ) -> _OP_RETURN:
+ ) -> Operators:
r"""Operate on an argument.
This is the lowest level of operation, raises
@@ -322,7 +332,7 @@ class Operators(Generic[_OP_RETURN]):
def reverse_operate(
self, op: OperatorType, other: Any, **kwargs: Any
- ) -> _OP_RETURN:
+ ) -> Operators:
"""Reverse operate on an argument.
Usage is the same as :meth:`operate`.
@@ -379,7 +389,7 @@ class custom_op(OperatorType, Generic[_T]):
] = None,
natural_self_precedent: bool = False,
eager_grouping: bool = False,
- python_impl=None,
+ python_impl: Optional[Callable[..., Any]] = None,
):
self.opstring = opstring
self.precedence = precedence
@@ -397,25 +407,17 @@ class custom_op(OperatorType, Generic[_T]):
def __hash__(self) -> int:
return id(self)
- @overload
def __call__(
- self, left: "ColumnElement", right: Any, **kw
- ) -> "BinaryExpression[_T]":
- ...
-
- @overload
- def __call__(
- self, left: "Operators[_OP_RETURN]", right: Any, **kw
- ) -> _OP_RETURN:
- ...
-
- def __call__(
- self, left: "Operators[_OP_RETURN]", right: Any, **kw
- ) -> _OP_RETURN:
+ self,
+ left: Operators,
+ right: Optional[Any] = None,
+ *other: Any,
+ **kwargs: Any,
+ ) -> Operators:
if hasattr(left, "__sa_operate__"):
- return left.operate(self, right, **kw)
+ return left.operate(self, right, *other, **kwargs)
elif self.python_impl:
- return self.python_impl(left, right, **kw)
+ return self.python_impl(left, right, *other, **kwargs) # type: ignore # noqa E501
else:
raise exc.InvalidRequestError(
f"Custom operator {self.opstring!r} can't be used with "
@@ -424,7 +426,7 @@ class custom_op(OperatorType, Generic[_T]):
)
-class ColumnOperators(Operators[_OP_RETURN]):
+class ColumnOperators(Operators):
"""Defines boolean, comparison, and other operators for
:class:`_expression.ColumnElement` expressions.
@@ -464,22 +466,22 @@ class ColumnOperators(Operators[_OP_RETURN]):
__slots__ = ()
- timetuple = None
+ timetuple: Literal[None] = None
"""Hack, allows datetime objects to be compared on the LHS."""
if typing.TYPE_CHECKING:
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
- ) -> "ColumnOperators":
+ ) -> ColumnOperators:
...
def reverse_operate(
self, op: OperatorType, other: Any, **kwargs: Any
- ) -> "ColumnOperators":
+ ) -> ColumnOperators:
...
- def __lt__(self, other: Any) -> "ColumnOperators":
+ def __lt__(self, other: Any) -> ColumnOperators:
"""Implement the ``<`` operator.
In a column context, produces the clause ``a < b``.
@@ -487,7 +489,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(lt, other)
- def __le__(self, other: Any) -> "ColumnOperators":
+ def __le__(self, other: Any) -> ColumnOperators:
"""Implement the ``<=`` operator.
In a column context, produces the clause ``a <= b``.
@@ -498,7 +500,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
# TODO: not sure why we have this
__hash__ = Operators.__hash__ # type: ignore
- def __eq__(self, other: Any) -> "ColumnOperators":
+ def __eq__(self, other: Any) -> ColumnOperators: # type: ignore[override]
"""Implement the ``==`` operator.
In a column context, produces the clause ``a = b``.
@@ -507,7 +509,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(eq, other)
- def __ne__(self, other: Any) -> "ColumnOperators":
+ def __ne__(self, other: Any) -> ColumnOperators: # type: ignore[override]
"""Implement the ``!=`` operator.
In a column context, produces the clause ``a != b``.
@@ -516,7 +518,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(ne, other)
- def is_distinct_from(self, other: Any) -> "ColumnOperators":
+ def is_distinct_from(self, other: Any) -> ColumnOperators:
"""Implement the ``IS DISTINCT FROM`` operator.
Renders "a IS DISTINCT FROM b" on most platforms;
@@ -527,7 +529,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(is_distinct_from, other)
- def is_not_distinct_from(self, other: Any) -> "ColumnOperators":
+ def is_not_distinct_from(self, other: Any) -> ColumnOperators:
"""Implement the ``IS NOT DISTINCT FROM`` operator.
Renders "a IS NOT DISTINCT FROM b" on most platforms;
@@ -545,7 +547,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
# deprecated 1.4; see #5435
isnot_distinct_from = is_not_distinct_from
- def __gt__(self, other: Any) -> "ColumnOperators":
+ def __gt__(self, other: Any) -> ColumnOperators:
"""Implement the ``>`` operator.
In a column context, produces the clause ``a > b``.
@@ -553,7 +555,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(gt, other)
- def __ge__(self, other: Any) -> "ColumnOperators":
+ def __ge__(self, other: Any) -> ColumnOperators:
"""Implement the ``>=`` operator.
In a column context, produces the clause ``a >= b``.
@@ -561,7 +563,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(ge, other)
- def __neg__(self) -> "ColumnOperators":
+ def __neg__(self) -> ColumnOperators:
"""Implement the ``-`` operator.
In a column context, produces the clause ``-a``.
@@ -569,10 +571,10 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(neg)
- def __contains__(self, other: Any) -> "ColumnOperators":
+ def __contains__(self, other: Any) -> ColumnOperators:
return self.operate(contains, other)
- def __getitem__(self, index: Any) -> "ColumnOperators":
+ def __getitem__(self, index: Any) -> ColumnOperators:
"""Implement the [] operator.
This can be used by some database-specific types
@@ -581,7 +583,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(getitem, index)
- def __lshift__(self, other: Any) -> "ColumnOperators":
+ def __lshift__(self, other: Any) -> ColumnOperators:
"""implement the << operator.
Not used by SQLAlchemy core, this is provided
@@ -590,7 +592,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(lshift, other)
- def __rshift__(self, other: Any) -> "ColumnOperators":
+ def __rshift__(self, other: Any) -> ColumnOperators:
"""implement the >> operator.
Not used by SQLAlchemy core, this is provided
@@ -599,7 +601,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(rshift, other)
- def concat(self, other: Any) -> "ColumnOperators":
+ def concat(self, other: Any) -> ColumnOperators:
"""Implement the 'concat' operator.
In a column context, produces the clause ``a || b``,
@@ -608,7 +610,9 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(concat_op, other)
- def like(self, other: Any, escape=None) -> "ColumnOperators":
+ def like(
+ self, other: Any, escape: Optional[str] = None
+ ) -> ColumnOperators:
r"""Implement the ``like`` operator.
In a column context, produces the expression::
@@ -633,7 +637,9 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(like_op, other, escape=escape)
- def ilike(self, other: Any, escape=None) -> "ColumnOperators":
+ def ilike(
+ self, other: Any, escape: Optional[str] = None
+ ) -> ColumnOperators:
r"""Implement the ``ilike`` operator, e.g. case insensitive LIKE.
In a column context, produces an expression either of the form::
@@ -662,7 +668,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(ilike_op, other, escape=escape)
- def in_(self, other: Any) -> "ColumnOperators":
+ def in_(self, other: Any) -> ColumnOperators:
"""Implement the ``in`` operator.
In a column context, produces the clause ``column IN <other>``.
@@ -751,7 +757,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(in_op, other)
- def not_in(self, other: Any) -> "ColumnOperators":
+ def not_in(self, other: Any) -> ColumnOperators:
"""implement the ``NOT IN`` operator.
This is equivalent to using negation with
@@ -782,7 +788,9 @@ class ColumnOperators(Operators[_OP_RETURN]):
# deprecated 1.4; see #5429
notin_ = not_in
- def not_like(self, other: Any, escape=None) -> "ColumnOperators":
+ def not_like(
+ self, other: Any, escape: Optional[str] = None
+ ) -> ColumnOperators:
"""implement the ``NOT LIKE`` operator.
This is equivalent to using negation with
@@ -802,7 +810,9 @@ class ColumnOperators(Operators[_OP_RETURN]):
# deprecated 1.4; see #5435
notlike = not_like
- def not_ilike(self, other: Any, escape=None) -> "ColumnOperators":
+ def not_ilike(
+ self, other: Any, escape: Optional[str] = None
+ ) -> ColumnOperators:
"""implement the ``NOT ILIKE`` operator.
This is equivalent to using negation with
@@ -822,7 +832,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
# deprecated 1.4; see #5435
notilike = not_ilike
- def is_(self, other: Any) -> "ColumnOperators":
+ def is_(self, other: Any) -> ColumnOperators:
"""Implement the ``IS`` operator.
Normally, ``IS`` is generated automatically when comparing to a
@@ -835,7 +845,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(is_, other)
- def is_not(self, other: Any) -> "ColumnOperators":
+ def is_not(self, other: Any) -> ColumnOperators:
"""Implement the ``IS NOT`` operator.
Normally, ``IS NOT`` is generated automatically when comparing to a
@@ -856,8 +866,11 @@ class ColumnOperators(Operators[_OP_RETURN]):
isnot = is_not
def startswith(
- self, other: Any, escape=None, autoescape=False
- ) -> "ColumnOperators":
+ self,
+ other: Any,
+ escape: Optional[str] = None,
+ autoescape: bool = False,
+ ) -> ColumnOperators:
r"""Implement the ``startswith`` operator.
Produces a LIKE expression that tests against a match for the start
@@ -939,8 +952,11 @@ class ColumnOperators(Operators[_OP_RETURN]):
)
def endswith(
- self, other: Any, escape=None, autoescape=False
- ) -> "ColumnOperators":
+ self,
+ other: Any,
+ escape: Optional[str] = None,
+ autoescape: bool = False,
+ ) -> ColumnOperators:
r"""Implement the 'endswith' operator.
Produces a LIKE expression that tests against a match for the end
@@ -1021,7 +1037,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
endswith_op, other, escape=escape, autoescape=autoescape
)
- def contains(self, other: Any, **kw: Any) -> "ColumnOperators":
+ def contains(self, other: Any, **kw: Any) -> ColumnOperators:
r"""Implement the 'contains' operator.
Produces a LIKE expression that tests against a match for the middle
@@ -1101,7 +1117,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(contains_op, other, **kw)
- def match(self, other: Any, **kwargs) -> "ColumnOperators":
+ def match(self, other: Any, **kwargs: Any) -> ColumnOperators:
"""Implements a database-specific 'match' operator.
:meth:`_sql.ColumnOperators.match` attempts to resolve to
@@ -1125,7 +1141,9 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(match_op, other, **kwargs)
- def regexp_match(self, pattern, flags=None) -> "ColumnOperators":
+ def regexp_match(
+ self, pattern: Any, flags: Optional[str] = None
+ ) -> ColumnOperators:
"""Implements a database-specific 'regexp match' operator.
E.g.::
@@ -1174,8 +1192,8 @@ class ColumnOperators(Operators[_OP_RETURN]):
return self.operate(regexp_match_op, pattern, flags=flags)
def regexp_replace(
- self, pattern, replacement, flags=None
- ) -> "ColumnOperators":
+ self, pattern: Any, replacement: Any, flags: Optional[str] = None
+ ) -> ColumnOperators:
"""Implements a database-specific 'regexp replace' operator.
E.g.::
@@ -1220,17 +1238,17 @@ class ColumnOperators(Operators[_OP_RETURN]):
flags=flags,
)
- def desc(self) -> "ColumnOperators":
+ def desc(self) -> ColumnOperators:
"""Produce a :func:`_expression.desc` clause against the
parent object."""
return self.operate(desc_op)
- def asc(self) -> "ColumnOperators":
+ def asc(self) -> ColumnOperators:
"""Produce a :func:`_expression.asc` clause against the
parent object."""
return self.operate(asc_op)
- def nulls_first(self) -> "ColumnOperators":
+ def nulls_first(self) -> ColumnOperators:
"""Produce a :func:`_expression.nulls_first` clause against the
parent object.
@@ -1243,7 +1261,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
# deprecated 1.4; see #5435
nullsfirst = nulls_first
- def nulls_last(self) -> "ColumnOperators":
+ def nulls_last(self) -> ColumnOperators:
"""Produce a :func:`_expression.nulls_last` clause against the
parent object.
@@ -1256,7 +1274,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
# deprecated 1.4; see #5429
nullslast = nulls_last
- def collate(self, collation) -> "ColumnOperators":
+ def collate(self, collation: str) -> ColumnOperators:
"""Produce a :func:`_expression.collate` clause against
the parent object, given the collation string.
@@ -1267,7 +1285,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(collate, collation)
- def __radd__(self, other: Any) -> "ColumnOperators":
+ def __radd__(self, other: Any) -> ColumnOperators:
"""Implement the ``+`` operator in reverse.
See :meth:`.ColumnOperators.__add__`.
@@ -1275,7 +1293,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.reverse_operate(add, other)
- def __rsub__(self, other: Any) -> "ColumnOperators":
+ def __rsub__(self, other: Any) -> ColumnOperators:
"""Implement the ``-`` operator in reverse.
See :meth:`.ColumnOperators.__sub__`.
@@ -1283,7 +1301,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.reverse_operate(sub, other)
- def __rmul__(self, other: Any) -> "ColumnOperators":
+ def __rmul__(self, other: Any) -> ColumnOperators:
"""Implement the ``*`` operator in reverse.
See :meth:`.ColumnOperators.__mul__`.
@@ -1291,7 +1309,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.reverse_operate(mul, other)
- def __rmod__(self, other: Any) -> "ColumnOperators":
+ def __rmod__(self, other: Any) -> ColumnOperators:
"""Implement the ``%`` operator in reverse.
See :meth:`.ColumnOperators.__mod__`.
@@ -1299,21 +1317,23 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.reverse_operate(mod, other)
- def between(self, cleft, cright, symmetric=False) -> "ColumnOperators":
+ def between(
+ self, cleft: Any, cright: Any, symmetric: bool = False
+ ) -> ColumnOperators:
"""Produce a :func:`_expression.between` clause against
the parent object, given the lower and upper range.
"""
return self.operate(between_op, cleft, cright, symmetric=symmetric)
- def distinct(self) -> "ColumnOperators":
+ def distinct(self) -> ColumnOperators:
"""Produce a :func:`_expression.distinct` clause against the
parent object.
"""
return self.operate(distinct_op)
- def any_(self) -> "ColumnOperators":
+ def any_(self) -> ColumnOperators:
"""Produce an :func:`_expression.any_` clause against the
parent object.
@@ -1330,7 +1350,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(any_op)
- def all_(self) -> "ColumnOperators":
+ def all_(self) -> ColumnOperators:
"""Produce an :func:`_expression.all_` clause against the
parent object.
@@ -1348,7 +1368,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(all_op)
- def __add__(self, other: Any) -> "ColumnOperators":
+ def __add__(self, other: Any) -> ColumnOperators:
"""Implement the ``+`` operator.
In a column context, produces the clause ``a + b``
@@ -1360,7 +1380,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(add, other)
- def __sub__(self, other: Any) -> "ColumnOperators":
+ def __sub__(self, other: Any) -> ColumnOperators:
"""Implement the ``-`` operator.
In a column context, produces the clause ``a - b``.
@@ -1368,7 +1388,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(sub, other)
- def __mul__(self, other: Any) -> "ColumnOperators":
+ def __mul__(self, other: Any) -> ColumnOperators:
"""Implement the ``*`` operator.
In a column context, produces the clause ``a * b``.
@@ -1376,7 +1396,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(mul, other)
- def __mod__(self, other: Any) -> "ColumnOperators":
+ def __mod__(self, other: Any) -> ColumnOperators:
"""Implement the ``%`` operator.
In a column context, produces the clause ``a % b``.
@@ -1384,7 +1404,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(mod, other)
- def __truediv__(self, other: Any) -> "ColumnOperators":
+ def __truediv__(self, other: Any) -> ColumnOperators:
"""Implement the ``/`` operator.
In a column context, produces the clause ``a / b``, and
@@ -1397,7 +1417,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(truediv, other)
- def __rtruediv__(self, other: Any) -> "ColumnOperators":
+ def __rtruediv__(self, other: Any) -> ColumnOperators:
"""Implement the ``/`` operator in reverse.
See :meth:`.ColumnOperators.__truediv__`.
@@ -1405,7 +1425,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.reverse_operate(truediv, other)
- def __floordiv__(self, other: Any) -> "ColumnOperators":
+ def __floordiv__(self, other: Any) -> ColumnOperators:
"""Implement the ``//`` operator.
In a column context, produces the clause ``a / b``,
@@ -1417,7 +1437,7 @@ class ColumnOperators(Operators[_OP_RETURN]):
"""
return self.operate(floordiv, other)
- def __rfloordiv__(self, other: Any) -> "ColumnOperators":
+ def __rfloordiv__(self, other: Any) -> ColumnOperators:
"""Implement the ``//`` operator in reverse.
See :meth:`.ColumnOperators.__floordiv__`.
@@ -1426,43 +1446,47 @@ class ColumnOperators(Operators[_OP_RETURN]):
return self.reverse_operate(floordiv, other)
-_commutative = {eq, ne, add, mul}
-_comparison = {eq, ne, lt, gt, ge, le}
+_commutative: Set[Any] = {eq, ne, add, mul}
+_comparison: Set[Any] = {eq, ne, lt, gt, ge, le}
-def _operator_fn(fn):
+def _operator_fn(fn: Callable[..., Any]) -> OperatorType:
return cast(OperatorType, fn)
-def commutative_op(fn):
+def commutative_op(fn: _FN) -> _FN:
_commutative.add(fn)
return fn
-def comparison_op(fn):
+def comparison_op(fn: _FN) -> _FN:
_comparison.add(fn)
return fn
-def from_():
+@_operator_fn
+def from_() -> Any:
raise NotImplementedError()
+@_operator_fn
@comparison_op
-def function_as_comparison_op():
+def function_as_comparison_op() -> Any:
raise NotImplementedError()
-def as_():
+@_operator_fn
+def as_() -> Any:
raise NotImplementedError()
-def exists():
+@_operator_fn
+def exists() -> Any:
raise NotImplementedError()
@_operator_fn
-def is_true(a):
+def is_true(a: Any) -> Any:
raise NotImplementedError()
@@ -1471,7 +1495,7 @@ istrue = is_true
@_operator_fn
-def is_false(a):
+def is_false(a: Any) -> Any:
raise NotImplementedError()
@@ -1481,13 +1505,13 @@ isfalse = is_false
@comparison_op
@_operator_fn
-def is_distinct_from(a, b):
+def is_distinct_from(a: Any, b: Any) -> Any:
return a.is_distinct_from(b)
@comparison_op
@_operator_fn
-def is_not_distinct_from(a, b):
+def is_not_distinct_from(a: Any, b: Any) -> Any:
return a.is_not_distinct_from(b)
@@ -1497,13 +1521,13 @@ isnot_distinct_from = is_not_distinct_from
@comparison_op
@_operator_fn
-def is_(a, b):
+def is_(a: Any, b: Any) -> Any:
return a.is_(b)
@comparison_op
@_operator_fn
-def is_not(a, b):
+def is_not(a: Any, b: Any) -> Any:
return a.is_not(b)
@@ -1512,24 +1536,24 @@ isnot = is_not
@_operator_fn
-def collate(a, b):
+def collate(a: Any, b: Any) -> Any:
return a.collate(b)
@_operator_fn
-def op(a, opstring, b):
+def op(a: Any, opstring: str, b: Any) -> Any:
return a.op(opstring)(b)
@comparison_op
@_operator_fn
-def like_op(a, b, escape=None):
+def like_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
return a.like(b, escape=escape)
@comparison_op
@_operator_fn
-def not_like_op(a, b, escape=None):
+def not_like_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
return a.notlike(b, escape=escape)
@@ -1539,13 +1563,13 @@ notlike_op = not_like_op
@comparison_op
@_operator_fn
-def ilike_op(a, b, escape=None):
+def ilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
return a.ilike(b, escape=escape)
@comparison_op
@_operator_fn
-def not_ilike_op(a, b, escape=None):
+def not_ilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
return a.not_ilike(b, escape=escape)
@@ -1555,13 +1579,13 @@ notilike_op = not_ilike_op
@comparison_op
@_operator_fn
-def between_op(a, b, c, symmetric=False):
+def between_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any:
return a.between(b, c, symmetric=symmetric)
@comparison_op
@_operator_fn
-def not_between_op(a, b, c, symmetric=False):
+def not_between_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any:
return ~a.between(b, c, symmetric=symmetric)
@@ -1571,13 +1595,13 @@ notbetween_op = not_between_op
@comparison_op
@_operator_fn
-def in_op(a, b):
+def in_op(a: Any, b: Any) -> Any:
return a.in_(b)
@comparison_op
@_operator_fn
-def not_in_op(a, b):
+def not_in_op(a: Any, b: Any) -> Any:
return a.not_in(b)
@@ -1586,21 +1610,23 @@ notin_op = not_in_op
@_operator_fn
-def distinct_op(a):
+def distinct_op(a: Any) -> Any:
return a.distinct()
@_operator_fn
-def any_op(a):
+def any_op(a: Any) -> Any:
return a.any_()
@_operator_fn
-def all_op(a):
+def all_op(a: Any) -> Any:
return a.all_()
-def _escaped_like_impl(fn, other: Any, escape, autoescape):
+def _escaped_like_impl(
+ fn: Callable[..., Any], other: Any, escape: Optional[str], autoescape: bool
+) -> Any:
if autoescape:
if autoescape is not True:
util.warn(
@@ -1622,13 +1648,17 @@ def _escaped_like_impl(fn, other: Any, escape, autoescape):
@comparison_op
@_operator_fn
-def startswith_op(a, b, escape=None, autoescape=False):
+def startswith_op(
+ a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
+) -> Any:
return _escaped_like_impl(a.startswith, b, escape, autoescape)
@comparison_op
@_operator_fn
-def not_startswith_op(a, b, escape=None, autoescape=False):
+def not_startswith_op(
+ a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
+) -> Any:
return ~_escaped_like_impl(a.startswith, b, escape, autoescape)
@@ -1638,13 +1668,17 @@ notstartswith_op = not_startswith_op
@comparison_op
@_operator_fn
-def endswith_op(a, b, escape=None, autoescape=False):
+def endswith_op(
+ a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
+) -> Any:
return _escaped_like_impl(a.endswith, b, escape, autoescape)
@comparison_op
@_operator_fn
-def not_endswith_op(a, b, escape=None, autoescape=False):
+def not_endswith_op(
+ a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
+) -> Any:
return ~_escaped_like_impl(a.endswith, b, escape, autoescape)
@@ -1654,13 +1688,17 @@ notendswith_op = not_endswith_op
@comparison_op
@_operator_fn
-def contains_op(a, b, escape=None, autoescape=False):
+def contains_op(
+ a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
+) -> Any:
return _escaped_like_impl(a.contains, b, escape, autoescape)
@comparison_op
@_operator_fn
-def not_contains_op(a, b, escape=None, autoescape=False):
+def not_contains_op(
+ a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
+) -> Any:
return ~_escaped_like_impl(a.contains, b, escape, autoescape)
@@ -1670,30 +1708,32 @@ notcontains_op = not_contains_op
@comparison_op
@_operator_fn
-def match_op(a, b, **kw):
+def match_op(a: Any, b: Any, **kw: Any) -> Any:
return a.match(b, **kw)
@comparison_op
@_operator_fn
-def regexp_match_op(a, b, flags=None):
+def regexp_match_op(a: Any, b: Any, flags: Optional[str] = None) -> Any:
return a.regexp_match(b, flags=flags)
@comparison_op
@_operator_fn
-def not_regexp_match_op(a, b, flags=None):
+def not_regexp_match_op(a: Any, b: Any, flags: Optional[str] = None) -> Any:
return ~a.regexp_match(b, flags=flags)
@_operator_fn
-def regexp_replace_op(a, b, replacement, flags=None):
+def regexp_replace_op(
+ a: Any, b: Any, replacement: Any, flags: Optional[str] = None
+) -> Any:
return a.regexp_replace(b, replacement=replacement, flags=flags)
@comparison_op
@_operator_fn
-def not_match_op(a, b, **kw):
+def not_match_op(a: Any, b: Any, **kw: Any) -> Any:
return ~a.match(b, **kw)
@@ -1702,32 +1742,32 @@ notmatch_op = not_match_op
@_operator_fn
-def comma_op(a, b):
+def comma_op(a: Any, b: Any) -> Any:
raise NotImplementedError()
@_operator_fn
-def filter_op(a, b):
+def filter_op(a: Any, b: Any) -> Any:
raise NotImplementedError()
@_operator_fn
-def concat_op(a, b):
+def concat_op(a: Any, b: Any) -> Any:
return a.concat(b)
@_operator_fn
-def desc_op(a):
+def desc_op(a: Any) -> Any:
return a.desc()
@_operator_fn
-def asc_op(a):
+def asc_op(a: Any) -> Any:
return a.asc()
@_operator_fn
-def nulls_first_op(a):
+def nulls_first_op(a: Any) -> Any:
return a.nulls_first()
@@ -1736,7 +1776,7 @@ nullsfirst_op = nulls_first_op
@_operator_fn
-def nulls_last_op(a):
+def nulls_last_op(a: Any) -> Any:
return a.nulls_last()
@@ -1745,28 +1785,28 @@ nullslast_op = nulls_last_op
@_operator_fn
-def json_getitem_op(a, b):
+def json_getitem_op(a: Any, b: Any) -> Any:
raise NotImplementedError()
@_operator_fn
-def json_path_getitem_op(a, b):
+def json_path_getitem_op(a: Any, b: Any) -> Any:
raise NotImplementedError()
-def is_comparison(op):
+def is_comparison(op: OperatorType) -> bool:
return op in _comparison or isinstance(op, custom_op) and op.is_comparison
-def is_commutative(op):
+def is_commutative(op: OperatorType) -> bool:
return op in _commutative
-def is_ordering_modifier(op):
+def is_ordering_modifier(op: OperatorType) -> bool:
return op in (asc_op, desc_op, nulls_first_op, nulls_last_op)
-def is_natural_self_precedent(op):
+def is_natural_self_precedent(op: OperatorType) -> bool:
return (
op in _natural_self_precedent
or isinstance(op, custom_op)
@@ -1777,14 +1817,14 @@ def is_natural_self_precedent(op):
_booleans = (inv, is_true, is_false, and_, or_)
-def is_boolean(op):
+def is_boolean(op: OperatorType) -> bool:
return is_comparison(op) or op in _booleans
_mirror = {gt: lt, ge: le, lt: gt, le: ge}
-def mirror(op):
+def mirror(op: OperatorType) -> OperatorType:
"""rotate a comparison operator 180 degrees.
Note this is not the same as negation.
@@ -1796,7 +1836,7 @@ def mirror(op):
_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
-def is_associative(op):
+def is_associative(op: OperatorType) -> bool:
return op in _associative
@@ -1809,11 +1849,17 @@ parenthesize (a op b).
"""
-_asbool = util.symbol("_asbool", canonical=-10)
-_smallest = util.symbol("_smallest", canonical=-100)
-_largest = util.symbol("_largest", canonical=100)
+@_operator_fn
+def _asbool(a: Any) -> Any:
+ raise NotImplementedError()
-_PRECEDENCE = {
+
+class _OpLimit(IntEnum):
+ _smallest = -100
+ _largest = 100
+
+
+_PRECEDENCE: Dict[OperatorType, int] = {
from_: 15,
function_as_comparison_op: 15,
any_op: 15,
@@ -1866,15 +1912,18 @@ _PRECEDENCE = {
as_: -1,
exists: 0,
_asbool: -10,
- _smallest: _smallest,
- _largest: _largest,
}
-def is_precedent(operator, against):
+def is_precedent(operator: OperatorType, against: OperatorType) -> bool:
if operator is against and is_natural_self_precedent(operator):
return False
else:
- return _PRECEDENCE.get(
- operator, getattr(operator, "precedence", _smallest)
- ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest))
+ return bool(
+ _PRECEDENCE.get(
+ operator, getattr(operator, "precedence", _OpLimit._smallest)
+ )
+ <= _PRECEDENCE.get(
+ against, getattr(against, "precedence", _OpLimit._largest)
+ )
+ )
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 1a7a5f4d4..4c4f49aa8 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -8,23 +8,28 @@ from __future__ import annotations
import typing
from typing import Any
+from typing import Generic
from typing import Iterable
-from typing import Mapping
from typing import Optional
-from typing import Sequence
+from typing import TypeVar
from .. import util
from ..util import TypingOnly
from ..util.typing import Literal
if typing.TYPE_CHECKING:
+ from ._typing import _PropagateAttrsType
from .base import ColumnCollection
from .elements import ClauseElement
+ from .elements import ColumnElement
from .elements import Label
from .selectable import FromClause
from .selectable import Subquery
+_T = TypeVar("_T", bound=Any)
+
+
class SQLRole:
"""Define a "role" within a SQL statement structure.
@@ -104,7 +109,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
_role_name = "Column expression or FROM clause"
@property
- def _select_iterable(self) -> Sequence[ColumnsClauseRole]:
+ def _select_iterable(self) -> Iterable[ColumnsClauseRole]:
raise NotImplementedError()
@@ -154,24 +159,24 @@ class WhereHavingRole(OnClauseRole):
_role_name = "SQL expression for WHERE/HAVING role"
-class ExpressionElementRole(SQLRole):
+class ExpressionElementRole(Generic[_T], SQLRole):
__slots__ = ()
_role_name = "SQL expression element"
- def label(self, name: Optional[str]) -> Label[Any]:
+ def label(self, name: Optional[str]) -> Label[_T]:
raise NotImplementedError()
-class ConstExprRole(ExpressionElementRole):
+class ConstExprRole(ExpressionElementRole[_T]):
__slots__ = ()
_role_name = "Constant True/False/None expression"
-class LabeledColumnExprRole(ExpressionElementRole):
+class LabeledColumnExprRole(ExpressionElementRole[_T]):
__slots__ = ()
-class BinaryElementRole(ExpressionElementRole):
+class BinaryElementRole(ExpressionElementRole[_T]):
__slots__ = ()
_role_name = "SQL expression element or literal value"
@@ -235,7 +240,7 @@ class StatementRole(SQLRole):
__slots__ = ()
_role_name = "Executable SQL or text() construct"
- _propagate_attrs: Mapping[str, Any] = util.immutabledict()
+ _propagate_attrs: _PropagateAttrsType = util.immutabledict()
class SelectStatementRole(StatementRole, ReturnsRowsRole):
@@ -317,7 +322,18 @@ class HasClauseElement(TypingOnly):
if typing.TYPE_CHECKING:
- def __clause_element__(self) -> "ClauseElement":
+ def __clause_element__(self) -> ClauseElement:
+ ...
+
+
+class HasColumnElementClauseElement(TypingOnly):
+ """indicates a class that has a __clause_element__() method"""
+
+ __slots__ = ()
+
+ if typing.TYPE_CHECKING:
+
+ def __clause_element__(self) -> ColumnElement[Any]:
...
@@ -328,5 +344,5 @@ class HasFromClauseElement(HasClauseElement, TypingOnly):
if typing.TYPE_CHECKING:
- def __clause_element__(self) -> "FromClause":
+ def __clause_element__(self) -> FromClause:
...
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 33e300bf6..78d524127 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -2671,6 +2671,7 @@ class DefaultGenerator(Executable, SchemaItem):
is_sequence = False
is_server_default = False
+ is_scalar = False
column = None
def __init__(self, for_update=False):
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 09befb078..0692483e9 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -20,9 +20,11 @@ from operator import attrgetter
import typing
from typing import Any as TODO_Any
from typing import Any
+from typing import Iterable
from typing import NamedTuple
from typing import Optional
from typing import Tuple
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import cache_key
@@ -454,7 +456,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
__visit_name__ = "fromclause"
named_with_column = False
- _hide_froms = []
+
+ @property
+ def _hide_froms(self) -> Iterable[FromClause]:
+ return ()
+
+ _is_clone_of: Optional[FromClause]
schema = None
"""Define the 'schema' attribute for this :class:`_expression.FromClause`.
@@ -667,7 +674,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self._cloned_set.intersection(other._cloned_set)
@property
- def description(self):
+ def description(self) -> str:
"""A brief description of this :class:`_expression.FromClause`.
Used primarily for error message formatting.
@@ -703,7 +710,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.columns
@util.memoized_property
- def columns(self):
+ def columns(self) -> ColumnCollection:
"""A named-based collection of :class:`_expression.ColumnElement`
objects maintained by this :class:`_expression.FromClause`.
@@ -787,19 +794,24 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
for key in ["_columns", "columns", "primary_key", "foreign_keys"]:
self.__dict__.pop(key, None)
- c = property(
- attrgetter("columns"),
- doc="""
- A named-based collection of :class:`_expression.ColumnElement`
- objects maintained by this :class:`_expression.FromClause`.
+ # this is awkward. maybe there's a better way
+ if TYPE_CHECKING:
+ c: ColumnCollection
+ else:
+ c = property(
+ attrgetter("columns"),
+ doc="""
+ A named-based collection of :class:`_expression.ColumnElement`
+ objects maintained by this :class:`_expression.FromClause`.
- The :attr:`_sql.FromClause.c` attribute is an alias for the
- :attr:`_sql.FromClause.columns` attribute.
+ The :attr:`_sql.FromClause.c` attribute is an alias for the
+ :attr:`_sql.FromClause.columns` attribute.
- :return: a :class:`.ColumnCollection`
+ :return: a :class:`.ColumnCollection`
+
+ """,
+ )
- """,
- )
_select_iterable = property(attrgetter("columns"))
def _init_collections(self):
@@ -1015,7 +1027,7 @@ class Join(roles.DMLTableRole, FromClause):
self.full = full
@property
- def description(self):
+ def description(self) -> str:
return "Join object on %s(%d) and %s(%d)" % (
self.left.description,
id(self.left),
@@ -1289,7 +1301,7 @@ class Join(roles.DMLTableRole, FromClause):
)
@property
- def _hide_froms(self):
+ def _hide_froms(self) -> Iterable[FromClause]:
return itertools.chain(
*[_from_objects(x.left, x.right) for x in self._cloned_set]
)
@@ -1370,7 +1382,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
self.element._refresh_for_new_column(column)
@property
- def description(self):
+ def description(self) -> str:
name = self.name
if isinstance(name, _anonymous_label):
name = "anon_1"
@@ -2301,6 +2313,8 @@ class FromGrouping(GroupedElement, FromClause):
_traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+ element: FromClause
+
def __init__(self, element):
self.element = coercions.expect(roles.FromClauseRole, element)
@@ -2329,7 +2343,7 @@ class FromGrouping(GroupedElement, FromClause):
return FromGrouping(self.element._anonymous_fromclause(**kw))
@property
- def _hide_froms(self):
+ def _hide_froms(self) -> Iterable[FromClause]:
return self.element._hide_froms
@property
@@ -2425,7 +2439,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
pass
@util.memoized_property
- def description(self):
+ def description(self) -> str:
return self.name
def append_column(self, c, **kw):
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index b2b1d9bc2..e64ec0843 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -18,7 +18,6 @@ import json
import pickle
from typing import Any
from typing import Sequence
-from typing import Text as typing_Text
from typing import Tuple
from typing import TypeVar
from typing import Union
@@ -132,7 +131,7 @@ class Indexable:
comparator_factory = Comparator
-class String(Concatenable, TypeEngine[typing_Text]):
+class String(Concatenable, TypeEngine[str]):
"""The base for all string and character types.
@@ -2793,11 +2792,13 @@ class ARRAY(
self.item_type._set_parent_with_dispatch(parent)
-class TupleType(TypeEngine[Tuple[Any]]):
+class TupleType(TypeEngine[Tuple[Any, ...]]):
"""represent the composite type of a Tuple."""
_is_tuple_type = True
+ types: List[TypeEngine[Any]]
+
def __init__(self, *types):
self._fully_typed = NULLTYPE not in types
self.types = [
@@ -2805,7 +2806,7 @@ class TupleType(TypeEngine[Tuple[Any]]):
for item_type in types
]
- def _resolve_values_to_types(self, value):
+ def _resolve_values_to_types(self, value: Any) -> TupleType:
if self._fully_typed:
return self
else:
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index cf9487f93..1f3d50876 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -17,6 +17,7 @@ from typing import Any
from typing import Callable
from typing import Deque
from typing import Dict
+from typing import Iterable
from typing import Set
from typing import Tuple
from typing import Type
@@ -226,7 +227,9 @@ class HasCopyInternals(HasTraverseInternals):
def _clone(self, **kw):
raise NotImplementedError()
- def _copy_internals(self, omit_attrs=(), **kw):
+ def _copy_internals(
+ self, omit_attrs: Iterable[str] = (), **kw: Any
+ ) -> None:
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index f76b4e462..55997556a 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -29,26 +29,19 @@ from .. import exc
from .. import util
# these are back-assigned by sqltypes.
-if not typing.TYPE_CHECKING:
- BOOLEANTYPE = None
- INTEGERTYPE = None
- NULLTYPE = None
- STRINGTYPE = None
- MATCHTYPE = None
- INDEXABLE = None
- TABLEVALUE = None
- _resolve_value_to_type = None
-
if typing.TYPE_CHECKING:
from .elements import ColumnElement
from .operators import OperatorType
- from .sqltypes import _resolve_value_to_type
- from .sqltypes import Boolean as BOOLEANTYPE # noqa
- from .sqltypes import Indexable as INDEXABLE # noqa
- from .sqltypes import MatchType as MATCHTYPE # noqa
- from .sqltypes import NULLTYPE
+ from .sqltypes import _resolve_value_to_type as _resolve_value_to_type
+ from .sqltypes import BOOLEANTYPE as BOOLEANTYPE
+ from .sqltypes import Indexable as INDEXABLE
+ from .sqltypes import INTEGERTYPE as INTEGERTYPE
+ from .sqltypes import MATCHTYPE as MATCHTYPE
+ from .sqltypes import NULLTYPE as NULLTYPE
+
_T = TypeVar("_T", bound=Any)
+_TE = TypeVar("_TE", bound="TypeEngine[Any]")
_CT = TypeVar("_CT", bound=Any)
# replace with pep-673 when applicable
@@ -95,7 +88,7 @@ class TypeEngine(Visitable, Generic[_T]):
"""
class Comparator(
- ColumnOperators["ColumnElement"],
+ ColumnOperators,
Generic[_CT],
):
"""Base class for custom comparison operations defined at the
@@ -539,7 +532,7 @@ class TypeEngine(Visitable, Generic[_T]):
return util.method_is_overridden(self, TypeEngine.bind_expression)
@staticmethod
- def _to_instance(cls_or_self):
+ def _to_instance(cls_or_self: Union[Type[_TE], _TE]) -> _TE:
return to_instance(cls_or_self)
def compare_values(self, x, y):
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 0c41e440e..903aae648 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -866,7 +866,7 @@ def traverse(
def cloned_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
- visitors: Mapping[str, _TraverseTransformCallableType],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
"""Clone the given expression structure, allowing modifications by
visitors.