diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/_elements_constructors.py | 165 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 46 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/cache_key.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 82 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 1284 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 363 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 38 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 50 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 2 |
16 files changed, 1371 insertions, 830 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 9d15cdcc3..770fbe40c 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -9,17 +9,19 @@ from __future__ import annotations import typing from typing import Any -from typing import cast as _typing_cast +from typing import Callable +from typing import Iterable +from typing import Mapping from typing import Optional from typing import overload -from typing import Type +from typing import Sequence +from typing import Tuple as typing_Tuple from typing import TypeVar from typing import Union from . import coercions -from . import operators from . import roles -from .base import NO_ARG +from .base import _NoArg from .coercions import _document_text_coercion from .elements import BindParameter from .elements import BooleanClauseList @@ -35,18 +37,20 @@ from .elements import FunctionFilter from .elements import Label from .elements import Null from .elements import Over -from .elements import SQLCoreOperations from .elements import TextClause from .elements import True_ from .elements import Tuple from .elements import TypeCoerce from .elements import UnaryExpression from .elements import WithinGroup +from .functions import FunctionElement +from ..util.typing import Literal if typing.TYPE_CHECKING: - from elements import BinaryExpression - from . import sqltypes + from ._typing import _ColumnExpression + from ._typing import _TypeEngineArgument + from .elements import BinaryExpression from .functions import FunctionElement from .selectable import FromClause from .type_api import TypeEngine @@ -54,7 +58,7 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T") -def all_(expr): +def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]: """Produce an ALL expression. For dialects such as that of PostgreSQL, this operator applies @@ -108,7 +112,7 @@ def all_(expr): return CollectionAggregate._create_all(expr) -def and_(*clauses): +def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: r"""Produce a conjunction of expressions joined by ``AND``. E.g.:: @@ -169,7 +173,7 @@ def and_(*clauses): return BooleanClauseList.and_(*clauses) -def any_(expr): +def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]: """Produce an ANY expression. For dialects such as that of PostgreSQL, this operator applies @@ -223,7 +227,7 @@ def any_(expr): return CollectionAggregate._create_any(expr) -def asc(column): +def asc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce an ascending ``ORDER BY`` clause element. e.g.:: @@ -261,7 +265,9 @@ def asc(column): return UnaryExpression._create_asc(column) -def collate(expression, collation): +def collate( + expression: _ColumnExpression[str], collation: str +) -> BinaryExpression[str]: """Return the clause ``expression COLLATE collation``. e.g.:: @@ -282,7 +288,12 @@ def collate(expression, collation): return CollationClause._create_collation_expression(expression, collation) -def between(expr, lower_bound, upper_bound, symmetric=False): +def between( + expr: _ColumnExpression[_T], + lower_bound: Any, + upper_bound: Any, + symmetric: bool = False, +) -> BinaryExpression[bool]: """Produce a ``BETWEEN`` predicate clause. E.g.:: @@ -338,7 +349,9 @@ def between(expr, lower_bound, upper_bound, symmetric=False): return expr.between(lower_bound, upper_bound, symmetric=symmetric) -def outparam(key, type_=None): +def outparam( + key: str, type_: Optional[TypeEngine[_T]] = None +) -> BindParameter[_T]: """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. @@ -352,16 +365,16 @@ def outparam(key, type_=None): @overload -def not_(clause: "BinaryExpression[_T]") -> "BinaryExpression[_T]": +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: ... @overload -def not_(clause: "ColumnElement[_T]") -> "UnaryExpression[_T]": +def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]: ... -def not_(clause: "ColumnElement[_T]") -> "ColumnElement[_T]": +def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]: """Return a negation of the given clause, i.e. ``NOT(clause)``. The ``~`` operator is also overloaded on all @@ -370,29 +383,21 @@ def not_(clause: "ColumnElement[_T]") -> "ColumnElement[_T]": """ - return operators.inv( - _typing_cast( - "ColumnElement[_T]", - coercions.expect(roles.ExpressionElementRole, clause), - ) - ) + return coercions.expect(roles.ExpressionElementRole, clause).__invert__() def bindparam( - key, - value=NO_ARG, - type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None, - unique=False, - required=NO_ARG, - quote=None, - callable_=None, - expanding=False, - isoutparam=False, - literal_execute=False, - _compared_to_operator=None, - _compared_to_type=None, - _is_crud=False, -) -> "BindParameter[_T]": + key: str, + value: Any = _NoArg.NO_ARG, + type_: Optional[TypeEngine[_T]] = None, + unique: bool = False, + required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG, + quote: Optional[bool] = None, + callable_: Optional[Callable[[], Any]] = None, + expanding: bool = False, + isoutparam: bool = False, + literal_execute: bool = False, +) -> BindParameter[_T]: r"""Produce a "bound expression". The return value is an instance of :class:`.BindParameter`; this @@ -636,13 +641,16 @@ def bindparam( expanding, isoutparam, literal_execute, - _compared_to_operator, - _compared_to_type, - _is_crud, ) -def case(*whens, value=None, else_=None) -> "Case[Any]": +def case( + *whens: Union[ + typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, +) -> Case[Any]: r"""Produce a ``CASE`` expression. The ``CASE`` construct in SQL is a conditional object that @@ -767,9 +775,9 @@ def case(*whens, value=None, else_=None) -> "Case[Any]": def cast( - expression: ColumnElement, - type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], -) -> "Cast[_T]": + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], +) -> Cast[_T]: r"""Produce a ``CAST`` expression. :func:`.cast` returns an instance of :class:`.Cast`. @@ -826,10 +834,10 @@ def cast( def column( text: str, - type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, is_literal: bool = False, - _selectable: Optional["FromClause"] = None, -) -> "ColumnClause[_T]": + _selectable: Optional[FromClause] = None, +) -> ColumnClause[_T]: """Produce a :class:`.ColumnClause` object. The :class:`.ColumnClause` is a lightweight analogue to the @@ -921,12 +929,10 @@ def column( :ref:`sqlexpression_literal_column` """ - self = ColumnClause.__new__(ColumnClause) - self.__init__(text, type_, is_literal, _selectable) - return self + return ColumnClause(text, type_, is_literal, _selectable) -def desc(column): +def desc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce a descending ``ORDER BY`` clause element. e.g.:: @@ -965,7 +971,7 @@ def desc(column): return UnaryExpression._create_desc(column) -def distinct(expr): +def distinct(expr: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce an column-expression-level unary ``DISTINCT`` clause. This applies the ``DISTINCT`` keyword to an individual column @@ -1004,7 +1010,7 @@ def distinct(expr): return UnaryExpression._create_distinct(expr) -def extract(field: str, expr: ColumnElement) -> "Extract[sqltypes.Integer]": +def extract(field: str, expr: _ColumnExpression[Any]) -> Extract: """Return a :class:`.Extract` construct. This is typically available as :func:`.extract` @@ -1045,7 +1051,7 @@ def extract(field: str, expr: ColumnElement) -> "Extract[sqltypes.Integer]": return Extract(field, expr) -def false(): +def false() -> False_: """Return a :class:`.False_` construct. E.g.:: @@ -1083,7 +1089,9 @@ def false(): return False_._instance() -def funcfilter(func, *criterion) -> "FunctionFilter": +def funcfilter( + func: FunctionElement[_T], *criterion: _ColumnExpression[bool] +) -> FunctionFilter[_T]: """Produce a :class:`.FunctionFilter` object against a function. Used against aggregate and window functions, @@ -1114,8 +1122,8 @@ def funcfilter(func, *criterion) -> "FunctionFilter": def label( name: str, - element: ColumnElement[_T], - type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None, + element: _ColumnExpression[_T], + type_: Optional[_TypeEngineArgument[_T]] = None, ) -> "Label[_T]": """Return a :class:`Label` object for the given :class:`_expression.ColumnElement`. @@ -1135,13 +1143,13 @@ def label( return Label(name, element, type_) -def null(): +def null() -> Null: """Return a constant :class:`.Null` construct.""" return Null._instance() -def nulls_first(column): +def nulls_first(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression. :func:`.nulls_first` is intended to modify the expression produced @@ -1185,7 +1193,7 @@ def nulls_first(column): return UnaryExpression._create_nulls_first(column) -def nulls_last(column): +def nulls_last(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression. :func:`.nulls_last` is intended to modify the expression produced @@ -1229,7 +1237,7 @@ def nulls_last(column): return UnaryExpression._create_nulls_last(column) -def or_(*clauses: SQLCoreOperations) -> BooleanClauseList: +def or_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: """Produce a conjunction of expressions joined by ``OR``. E.g.:: @@ -1281,12 +1289,16 @@ def or_(*clauses: SQLCoreOperations) -> BooleanClauseList: def over( - element: "FunctionElement[_T]", - partition_by=None, - order_by=None, - range_=None, - rows=None, -) -> "Over[_T]": + element: FunctionElement[_T], + partition_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + order_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, +) -> Over[_T]: r"""Produce an :class:`.Over` object against a function. Used against aggregate or so-called "window" functions, @@ -1373,7 +1385,7 @@ def over( @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") -def text(text): +def text(text: str) -> TextClause: r"""Construct a new :class:`_expression.TextClause` clause, representing a textual SQL string directly. @@ -1451,7 +1463,7 @@ def text(text): return TextClause(text) -def true(): +def true() -> True_: """Return a constant :class:`.True_` construct. E.g.:: @@ -1489,7 +1501,10 @@ def true(): return True_._instance() -def tuple_(*clauses: roles.ExpressionElementRole, types=None) -> "Tuple": +def tuple_( + *clauses: _ColumnExpression[Any], + types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, +) -> Tuple: """Return a :class:`.Tuple`. Main usage is to produce a composite IN construct using @@ -1516,9 +1531,9 @@ def tuple_(*clauses: roles.ExpressionElementRole, types=None) -> "Tuple": def type_coerce( - expression: "ColumnElement", - type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], -) -> "TypeCoerce[_T]": + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], +) -> TypeCoerce[_T]: r"""Associate a SQL expression with a particular type, without rendering ``CAST``. @@ -1597,8 +1612,8 @@ def type_coerce( def within_group( - element: "FunctionElement[_T]", *order_by: roles.OrderByRole -) -> "WithinGroup[_T]": + element: FunctionElement[_T], *order_by: _ColumnExpression[Any] +) -> WithinGroup[_T]: r"""Produce a :class:`.WithinGroup` object against a function. Used against so-called "ordered set aggregate" and "hypothetical diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 69e4645fa..389f7e8d0 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,14 +1,58 @@ from __future__ import annotations +from typing import Any from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import roles +from .. import util from ..inspection import Inspectable +if TYPE_CHECKING: + from .elements import quoted_name + from .schema import DefaultGenerator + from .schema import Sequence + from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import TableClause + from .sqltypes import TupleType + from .type_api import TypeEngine + from ..util.typing import TypeGuard + +_T = TypeVar("_T", bound=Any) + _ColumnsClauseElement = Union[ - roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] + roles.ColumnsClauseRole, + Type, + Inspectable[roles.HasColumnElementClauseElement], ] _FromClauseElement = Union[ roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement] ] + +_ColumnExpression = Union[ + roles.ExpressionElementRole[_T], + Inspectable[roles.HasColumnElementClauseElement], +] + +_PropagateAttrsType = util.immutabledict[str, Any] + +_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + + +def is_named_from_clause(t: FromClause) -> TypeGuard[NamedFromClause]: + return t.named_with_column + + +def has_schema_attr(t: FromClause) -> TypeGuard[TableClause]: + return hasattr(t, "schema") + + +def is_quoted_name(s: str) -> TypeGuard[quoted_name]: + return hasattr(s, "quote") + + +def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: + return t._is_tuple_type diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 7afc2de97..f37ae9a60 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -18,11 +18,11 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import FrozenSet from typing import Mapping from typing import Optional from typing import overload from typing import Sequence -from typing import Set from typing import Tuple from typing import Type from typing import TypeVar @@ -53,7 +53,9 @@ class SupportsAnnotations(ExternallyTraversible): __slots__ = () _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS - proxy_set: Set[SupportsAnnotations] + + proxy_set: util.generic_fn_descriptor[FrozenSet[Any]] + _is_immutable: bool def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations: diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a408a010a..29f9028c8 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -13,16 +13,20 @@ from __future__ import annotations import collections.abc as collections_abc +from enum import Enum from functools import reduce import itertools from itertools import zip_longest import operator import re import typing +from typing import Any +from typing import Iterable +from typing import List from typing import MutableMapping from typing import Optional from typing import Sequence -from typing import Set +from typing import Tuple from typing import TypeVar from . import roles @@ -38,23 +42,33 @@ from .. import util from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing +from ..util.typing import Self if typing.TYPE_CHECKING: + from .elements import BindParameter from .elements import ColumnElement from ..engine import Connection from ..engine import Result + from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import _SchemaTranslateMapType from ..engine.interfaces import CacheStats - + from ..engine.interfaces import Compiled + from ..engine.interfaces import Dialect coercions = None elements = None type_api = None -NO_ARG = util.symbol("NO_ARG") + +class _NoArg(Enum): + NO_ARG = 0 + + +NO_ARG = _NoArg.NO_ARG # if I use sqlalchemy.util.typing, which has the exact same # symbols, mypy reports: "error: _Fn? not callable" @@ -74,10 +88,12 @@ class Immutable: def params(self, *optionaldict, **kwargs): raise NotImplementedError("Immutable objects do not support copying") - def _clone(self, **kw): + def _clone(self: Self, **kw: Any) -> Self: return self - def _copy_internals(self, **kw): + def _copy_internals( + self, omit_attrs: Iterable[str] = (), **kw: Any + ) -> None: pass @@ -88,8 +104,6 @@ class SingletonConstant(Immutable): _singleton: SingletonConstant - proxy_set: Set[ColumnElement] - def __new__(cls, *arg, **kw): return cls._singleton @@ -877,12 +891,15 @@ class Executable(roles.StatementRole, Generative): def _compile_w_cache( self, dialect: Dialect, - compiled_cache: Optional[_CompiledCacheType] = None, - column_keys: Optional[Sequence[str]] = None, + *, + compiled_cache: Optional[_CompiledCacheType], + column_keys: List[str], for_executemany: bool = False, schema_translate_map: Optional[_SchemaTranslateMapType] = None, **kw: Any, - ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]: + ) -> Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: ... def _execute_on_connection( diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index fca58f98e..19a232c56 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -11,15 +11,14 @@ import enum from itertools import zip_longest import typing from typing import Any -from typing import cast from typing import Dict from typing import Iterator from typing import List +from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import Sequence from typing import Tuple -from typing import Type from typing import Union from .visitors import anon_map @@ -91,7 +90,7 @@ class HasCacheKey: __slots__ = () _cache_key_traversal: Union[ - _TraverseInternalsType, Literal[CacheConst.NO_CACHE] + _TraverseInternalsType, Literal[CacheConst.NO_CACHE], Literal[None] ] = NO_CACHE _is_has_cache_key = True @@ -147,11 +146,8 @@ class HasCacheKey: _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) if _cache_key_traversal is None: try: - # check for _traverse_internals, which is part of - # HasTraverseInternals - _cache_key_traversal = cast( - "Type[HasTraverseInternals]", cls - )._traverse_internals + assert issubclass(cls, HasTraverseInternals) + _cache_key_traversal = cls._traverse_internals except AttributeError: cls._generated_cache_key_traversal = NO_CACHE return NO_CACHE @@ -417,7 +413,7 @@ class CacheKey(NamedTuple): def to_offline_string( self, - statement_cache: _CompiledCacheType, + statement_cache: MutableMapping[Any, str], statement: ClauseElement, parameters: _CoreSingleExecuteParams, ) -> str: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 834bfb75d..ea17b8e03 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -13,10 +13,12 @@ import re import typing from typing import Any from typing import Any as TODO_Any +from typing import Callable from typing import Dict from typing import List from typing import NoReturn from typing import Optional +from typing import overload from typing import Type from typing import TypeVar @@ -46,9 +48,14 @@ if typing.TYPE_CHECKING: from . import traversals from .elements import ClauseElement from .elements import ColumnClause + from .elements import ColumnElement + from .elements import SQLCoreOperations + _SR = TypeVar("_SR", bound=roles.SQLRole) +_F = TypeVar("_F", bound=Callable[..., Any]) _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) +_T = TypeVar("_T", bound=Any) def _is_literal(element): @@ -104,7 +111,9 @@ def _deep_is_literal(element): ) -def _document_text_coercion(paramname, meth_rst, param_rst): +def _document_text_coercion( + paramname: str, meth_rst: str, param_rst: str +) -> Callable[[_F], _F]: return util.add_parameter_text( paramname, ( @@ -132,15 +141,50 @@ def _expression_collection_was_a_list(attrname, fnname, args): return args -# TODO; would like to have overloads here, however mypy is being extremely -# pedantic about them. not sure why pylance is OK with them. +@overload +def expect( + role: Type[roles.TruncatedLabelRole], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> str: + ... + + +@overload +def expect( + role: Type[roles.ExpressionElementRole[_T]], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> ColumnElement[_T]: + ... +@overload def expect( role: Type[_SR], element: Any, *, - apply_propagate_attrs: Optional["ClauseElement"] = None, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> TODO_Any: + ... + + +def expect( + role: Type[_SR], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, argname: Optional[str] = None, post_inspect: bool = False, **kw: Any, @@ -220,12 +264,16 @@ def expect( resolved = element else: resolved = element - if ( - apply_propagate_attrs is not None - and not apply_propagate_attrs._propagate_attrs - and resolved._propagate_attrs - ): - apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs + + if apply_propagate_attrs is not None: + if typing.TYPE_CHECKING: + assert isinstance(resolved, (SQLCoreOperations, ClauseElement)) + + if ( + not apply_propagate_attrs._propagate_attrs + and resolved._propagate_attrs + ): + apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: @@ -620,8 +668,8 @@ class InElementImpl(RoleImpl): element, str ): non_literal_expressions: Dict[ - Optional[operators.ColumnOperators[Any]], - operators.ColumnOperators[Any], + Optional[operators.ColumnOperators], + operators.ColumnOperators, ] = {} element = list(element) for o in element: diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 001710d7b..91bb0a5c5 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -42,13 +42,14 @@ _T = typing.TypeVar("_T", bound=Any) if typing.TYPE_CHECKING: from .elements import ColumnElement + from .operators import custom_op from .sqltypes import TypeEngine def _boolean_compare( - expr: "ColumnElement", + expr: ColumnElement[Any], op: OperatorType, - obj: roles.BinaryElementRole, + obj: Any, *, negate_op: Optional[OperatorType] = None, reverse: bool = False, @@ -59,7 +60,6 @@ def _boolean_compare( ] = None, **kwargs: Any, ) -> BinaryExpression[bool]: - if result_type is None: result_type = type_api.BOOLEANTYPE @@ -143,7 +143,14 @@ def _boolean_compare( ) -def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): +def _custom_op_operate( + expr: ColumnElement[Any], + op: custom_op[Any], + obj: Any, + reverse: bool = False, + result_type: Optional[TypeEngine[Any]] = None, + **kw: Any, +) -> ColumnElement[Any]: if result_type is None: if op.return_type: result_type = op.return_type @@ -156,11 +163,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): def _binary_operate( - expr: "ColumnElement", + expr: ColumnElement[Any], op: OperatorType, obj: roles.BinaryElementRole, *, - reverse=False, + reverse: bool = False, result_type: Optional[ Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] ] = None, @@ -184,7 +191,9 @@ def _binary_operate( return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) -def _conjunction_operate(expr, op, other, **kw) -> "ColumnElement": +def _conjunction_operate( + expr: ColumnElement[Any], op: OperatorType, other, **kw +) -> ColumnElement[Any]: if op is operators.and_: return and_(expr, other) elif op is operators.or_: @@ -193,11 +202,19 @@ def _conjunction_operate(expr, op, other, **kw) -> "ColumnElement": raise NotImplementedError() -def _scalar(expr, op, fn, **kw) -> "ColumnElement": +def _scalar( + expr: ColumnElement[Any], op: OperatorType, fn, **kw +) -> ColumnElement[Any]: return fn(expr) -def _in_impl(expr, op, seq_or_selectable, negate_op, **kw) -> "ColumnElement": +def _in_impl( + expr: ColumnElement[Any], + op: OperatorType, + seq_or_selectable, + negate_op: OperatorType, + **kw, +) -> ColumnElement[Any]: seq_or_selectable = coercions.expect( roles.InElementRole, seq_or_selectable, expr=expr, operator=op ) @@ -209,7 +226,9 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw) -> "ColumnElement": ) -def _getitem_impl(expr, op, other, **kw) -> "ColumnElement": +def _getitem_impl( + expr: ColumnElement[Any], op: OperatorType, other, **kw +) -> ColumnElement[Any]: if isinstance(expr.type, type_api.INDEXABLE): other = coercions.expect( roles.BinaryElementRole, other, expr=expr, operator=op @@ -219,13 +238,17 @@ def _getitem_impl(expr, op, other, **kw) -> "ColumnElement": _unsupported_impl(expr, op, other, **kw) -def _unsupported_impl(expr, op, *arg, **kw) -> NoReturn: +def _unsupported_impl( + expr: ColumnElement[Any], op: OperatorType, *arg, **kw +) -> NoReturn: raise NotImplementedError( "Operator '%s' is not supported on " "this expression" % op.__name__ ) -def _inv_impl(expr, op, **kw) -> "ColumnElement": +def _inv_impl( + expr: ColumnElement[Any], op: OperatorType, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.__inv__`.""" # undocumented element currently used by the ORM for @@ -236,12 +259,16 @@ def _inv_impl(expr, op, **kw) -> "ColumnElement": return expr._negate() -def _neg_impl(expr, op, **kw) -> "ColumnElement": +def _neg_impl( + expr: ColumnElement[Any], op: OperatorType, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.__neg__`.""" return UnaryExpression(expr, operator=operators.neg, type_=expr.type) -def _match_impl(expr, op, other, **kw) -> "ColumnElement": +def _match_impl( + expr: ColumnElement[Any], op: OperatorType, other, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.match`.""" return _boolean_compare( @@ -261,14 +288,18 @@ def _match_impl(expr, op, other, **kw) -> "ColumnElement": ) -def _distinct_impl(expr, op, **kw) -> "ColumnElement": +def _distinct_impl( + expr: ColumnElement[Any], op: OperatorType, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.distinct`.""" return UnaryExpression( expr, operator=operators.distinct_op, type_=expr.type ) -def _between_impl(expr, op, cleft, cright, **kw) -> "ColumnElement": +def _between_impl( + expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( expr, @@ -297,11 +328,15 @@ def _between_impl(expr, op, cleft, cright, **kw) -> "ColumnElement": ) -def _collate_impl(expr, op, collation, **kw) -> "ColumnElement": +def _collate_impl( + expr: ColumnElement[Any], op: OperatorType, collation, **kw +) -> ColumnElement[Any]: return CollationClause._create_collation_expression(expr, collation) -def _regexp_match_impl(expr, op, pattern, flags, **kw) -> "ColumnElement": +def _regexp_match_impl( + expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw +) -> ColumnElement[Any]: if flags is not None: flags = coercions.expect( roles.BinaryElementRole, @@ -322,8 +357,13 @@ def _regexp_match_impl(expr, op, pattern, flags, **kw) -> "ColumnElement": def _regexp_replace_impl( - expr, op, pattern, replacement, flags, **kw -) -> "ColumnElement": + expr: ColumnElement[Any], + op: OperatorType, + pattern, + replacement, + flags, + **kw, +) -> ColumnElement[Any]: replacement = coercions.expect( roles.BinaryElementRole, replacement, @@ -345,7 +385,7 @@ def _regexp_replace_impl( # a mapping of operators with the method they use, along with # additional keyword arguments to be passed operator_lookup: Dict[ - str, Tuple[Callable[..., "ColumnElement"], util.immutabledict] + str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict] ] = { "and_": (_conjunction_operate, util.EMPTY_DICT), "or_": (_conjunction_operate, util.EMPTY_DICT), diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 08d632afd..fdb3fc8bb 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -12,20 +12,28 @@ from __future__ import annotations +from decimal import Decimal +from enum import IntEnum import itertools import operator import re import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict +from typing import FrozenSet from typing import Generic +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence -from typing import Text as typing_Text +from typing import Set +from typing import Tuple as typing_Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -34,10 +42,15 @@ from . import operators from . import roles from . import traversals from . import type_api +from ._typing import has_schema_attr +from ._typing import is_named_from_clause +from ._typing import is_quoted_name +from ._typing import is_tuple_type from .annotation import Annotated from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative +from .base import _NoArg from .base import Executable from .base import HasMemoized from .base import Immutable @@ -57,30 +70,47 @@ from .. import exc from .. import inspection from .. import util from ..util.langhelpers import TypingOnly +from ..util.typing import Literal if typing.TYPE_CHECKING: - from decimal import Decimal - + from ._typing import _ColumnExpression + from ._typing import _PropagateAttrsType + from ._typing import _TypeEngineArgument + from .cache_key import CacheKey from .compiler import Compiled from .compiler import SQLCompiler + from .functions import FunctionElement from .operators import OperatorType + from .schema import Column + from .schema import DefaultGenerator + from .schema import ForeignKey from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows from .selectable import Select - from .sqltypes import Boolean # noqa + from .selectable import TableClause + from .sqltypes import Boolean + from .sqltypes import TupleType from .type_api import TypeEngine + from .visitors import _TraverseInternalsType from ..engine import Connection from ..engine import Dialect from ..engine import Engine from ..engine.base import _CompiledCacheType - from ..engine.base import _SchemaTranslateMapType - + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _SchemaTranslateMapType + from ..engine.interfaces import CacheStats + from ..engine.result import Result -_NUMERIC = Union[complex, "Decimal"] +_NUMERIC = Union[complex, Decimal] +_NUMBER = Union[complex, int, Decimal] _T = TypeVar("_T", bound="Any") _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") -_ST = TypeVar("_ST", bound="typing_Text") + +_NMT = TypeVar("_NMT", bound="_NUMBER") def literal(value, type_=None): @@ -210,28 +240,27 @@ class CompilerElement(Visitable): """ - if not dialect: + if dialect is None: if bind: dialect = bind.dialect + elif self.stringify_dialect == "default": + default = util.preloaded.engine_default + dialect = default.StrCompileDialect() else: - if self.stringify_dialect == "default": - default = util.preloaded.engine_default - dialect = default.StrCompileDialect() - else: - url = util.preloaded.engine_url - dialect = url.URL.create( - self.stringify_dialect - ).get_dialect()() + url = util.preloaded.engine_url + dialect = url.URL.create( + self.stringify_dialect + ).get_dialect()() return self._compiler(dialect, **kw) - def _compiler(self, dialect, **kw): + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: """Return a compiler appropriate for this ClauseElement, given a Dialect.""" return dialect.statement_compiler(dialect, self, **kw) - def __str__(self): + def __str__(self) -> str: return str(self.compile()) @@ -253,16 +282,17 @@ class ClauseElement( __visit_name__ = "clause" - _propagate_attrs = util.immutabledict() + _propagate_attrs: _PropagateAttrsType = util.immutabledict() """like annotations, however these propagate outwards liberally as SQL constructs are built, and are set up at construction time. """ - _from_objects = [] - bind = None - description = None - _is_clone_of = None + @util.memoized_property + def description(self) -> Optional[str]: + return None + + _is_clone_of: Optional[ClauseElement] = None is_clause_element = True is_selectable = False @@ -281,10 +311,25 @@ class ClauseElement( _is_singleton_constant = False _is_immutable = False - _order_by_label_element = None + @property + def _order_by_label_element(self) -> Optional[Label[Any]]: + return None _cache_key_traversal = None + negation_clause: ClauseElement + + if typing.TYPE_CHECKING: + + def get_children( + self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + ) -> Iterable[ClauseElement]: + ... + + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + def _set_propagate_attrs(self, values): # usually, self._propagate_attrs is empty here. one case where it's # not is a subquery against ORM select, that is then pulled as a @@ -295,7 +340,7 @@ class ClauseElement( self._propagate_attrs = util.immutabledict(values) return self - def _clone(self: SelfClauseElement, **kw) -> SelfClauseElement: + def _clone(self: SelfClauseElement, **kw: Any) -> SelfClauseElement: """Create a shallow copy of this ClauseElement. This method may be used by a generative API. Its also used as @@ -357,7 +402,7 @@ class ClauseElement( """ s = util.column_set() - f = self + f: Optional[ClauseElement] = self # note this creates a cycle, asserted in test_memusage. however, # turning this into a plain @property adds tends of thousands of method @@ -383,16 +428,26 @@ class ClauseElement( return d def _execute_on_connection( - self, connection, distilled_params, execution_options, _force=False - ): + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + _force: bool = False, + ) -> Result: if _force or self.supports_execution: + if TYPE_CHECKING: + assert isinstance(self, Executable) return connection._execute_clauseelement( self, distilled_params, execution_options ) else: raise exc.ObjectNotExecutableError(self) - def unique_params(self, *optionaldict, **kwargs): + def unique_params( + self: SelfClauseElement, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> SelfClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -402,11 +457,13 @@ class ClauseElement( used. """ - return self._replace_params(True, optionaldict, kwargs) + return self._replace_params(True, __optionaldict, kwargs) def params( - self, *optionaldict: Dict[str, Any], **kwargs: Any - ) -> ClauseElement: + self: SelfClauseElement, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> SelfClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -421,33 +478,32 @@ class ClauseElement( {'foo':7} """ - return self._replace_params(False, optionaldict, kwargs) + return self._replace_params(False, __optionaldict, kwargs) def _replace_params( - self, + self: SelfClauseElement, unique: bool, optionaldict: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - ) -> ClauseElement: + ) -> SelfClauseElement: - if len(optionaldict) == 1: - kwargs.update(optionaldict[0]) - elif len(optionaldict) > 1: - raise exc.ArgumentError( - "params() takes zero or one positional dictionary argument" - ) + if optionaldict: + kwargs.update(optionaldict) - def visit_bindparam(bind): + def visit_bindparam(bind: BindParameter[Any]) -> None: if bind.key in kwargs: bind.value = kwargs[bind.key] bind.required = False if unique: bind._convert_to_unique() - return cloned_traverse( - self, - {"maintain_key": True, "detect_subquery_cols": True}, - {"bindparam": visit_bindparam}, + return cast( + SelfClauseElement, + cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ), ) def compare(self, other, **kw): @@ -501,18 +557,26 @@ class ClauseElement( def _compile_w_cache( self, dialect: Dialect, - compiled_cache: Optional[_CompiledCacheType] = None, - column_keys: Optional[List[str]] = None, + *, + compiled_cache: Optional[_CompiledCacheType], + column_keys: List[str], for_executemany: bool = False, schema_translate_map: Optional[_SchemaTranslateMapType] = None, **kw: Any, - ): + ) -> typing_Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: + elem_cache_key: Optional[CacheKey] + if compiled_cache is not None and dialect._supports_statement_cache: elem_cache_key = self._generate_cache_key() else: elem_cache_key = None - if elem_cache_key: + if elem_cache_key is not None: + if TYPE_CHECKING: + assert compiled_cache is not None + cache_key, extracted_params = elem_cache_key key = ( dialect, @@ -564,7 +628,7 @@ class ClauseElement( else: return self._negate() - def _negate(self): + def _negate(self) -> ClauseElement: return UnaryExpression( self.self_group(against=operators.inv), operator=operators.inv ) @@ -605,6 +669,9 @@ class DQLDMLClauseElement(ClauseElement): ) -> SQLCompiler: ... + def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: + ... + class CompilerColumnElement( roles.DMLColumnRole, @@ -621,9 +688,7 @@ class CompilerColumnElement( __slots__ = () -class SQLCoreOperations( - Generic[_T], ColumnOperators["SQLCoreOperations"], TypingOnly -): +class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): __slots__ = () # annotations for comparison methods @@ -631,173 +696,186 @@ class SQLCoreOperations( # redefined with the specific types returned by ColumnElement hierarchies if typing.TYPE_CHECKING: + _propagate_attrs: _PropagateAttrsType + def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement: + ) -> ColumnElement[Any]: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement: + ) -> ColumnElement[Any]: ... def op( self, - opstring: Any, + opstring: str, precedence: int = 0, is_comparison: bool = False, - return_type: Optional[ - Union[Type["TypeEngine[_OPT]"], "TypeEngine[_OPT]"] - ] = None, - python_impl=None, - ) -> Callable[[Any], "BinaryExpression[_OPT]"]: + return_type: Optional[_TypeEngineArgument[_OPT]] = None, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[_OPT]]: ... def bool_op( - self, opstring: Any, precedence: int = 0, python_impl=None - ) -> Callable[[Any], "BinaryExpression[bool]"]: + self, + opstring: str, + precedence: int = 0, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[bool]]: ... - def __and__(self, other: Any) -> "BooleanClauseList": + def __and__(self, other: Any) -> BooleanClauseList: ... - def __or__(self, other: Any) -> "BooleanClauseList": + def __or__(self, other: Any) -> BooleanClauseList: ... - def __invert__(self) -> "UnaryExpression[_T]": + def __invert__(self) -> ColumnElement[_T]: ... - def __lt__(self, other: Any) -> "ColumnElement[bool]": + def __lt__(self, other: Any) -> ColumnElement[bool]: ... - def __le__(self, other: Any) -> "ColumnElement[bool]": + def __le__(self, other: Any) -> ColumnElement[bool]: ... - def __eq__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def __ne__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def is_distinct_from(self, other: Any) -> "ColumnElement[bool]": + def is_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def is_not_distinct_from(self, other: Any) -> "ColumnElement[bool]": + def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def __gt__(self, other: Any) -> "ColumnElement[bool]": + def __gt__(self, other: Any) -> ColumnElement[bool]: ... - def __ge__(self, other: Any) -> "ColumnElement[bool]": + def __ge__(self, other: Any) -> ColumnElement[bool]: ... - def __neg__(self) -> "UnaryExpression[_T]": + def __neg__(self) -> UnaryExpression[_T]: ... - def __contains__(self, other: Any) -> "ColumnElement[bool]": + def __contains__(self, other: Any) -> ColumnElement[bool]: ... - def __getitem__(self, index: Any) -> "ColumnElement": + def __getitem__(self, index: Any) -> ColumnElement[Any]: ... @overload - def concat( - self: "SQLCoreOperations[_ST]", other: Any - ) -> "ColumnElement[_ST]": + def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ... @overload - def concat(self, other: Any) -> "ColumnElement": + def concat(self, other: Any) -> ColumnElement[Any]: ... - def concat(self, other: Any) -> "ColumnElement": + def concat(self, other: Any) -> ColumnElement[Any]: ... - def like(self, other: Any, escape=None) -> "BinaryExpression[bool]": + def like( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... - def ilike(self, other: Any, escape=None) -> "BinaryExpression[bool]": + def ilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... def in_( self, - other: Union[Sequence[Any], "BindParameter", "Select"], - ) -> "BinaryExpression[bool]": + other: Union[Sequence[Any], BindParameter[Any], Select], + ) -> BinaryExpression[bool]: ... def not_in( self, - other: Union[Sequence[Any], "BindParameter", "Select"], - ) -> "BinaryExpression[bool]": + other: Union[Sequence[Any], BindParameter[Any], Select], + ) -> BinaryExpression[bool]: ... def not_like( - self, other: Any, escape=None - ) -> "BinaryExpression[bool]": + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... def not_ilike( - self, other: Any, escape=None - ) -> "BinaryExpression[bool]": + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... - def is_(self, other: Any) -> "BinaryExpression[bool]": + def is_(self, other: Any) -> BinaryExpression[bool]: ... - def is_not(self, other: Any) -> "BinaryExpression[bool]": + def is_not(self, other: Any) -> BinaryExpression[bool]: ... def startswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnElement[bool]": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... def endswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnElement[bool]": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... - def contains(self, other: Any, **kw: Any) -> "ColumnElement[bool]": + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... - def match(self, other: Any, **kwargs) -> "ColumnElement[bool]": + def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: ... - def regexp_match(self, pattern, flags=None) -> "ColumnElement[bool]": + def regexp_match( + self, pattern: Any, flags: Optional[str] = None + ) -> ColumnElement[bool]: ... def regexp_replace( - self, pattern, replacement, flags=None - ) -> "ColumnElement": + self, pattern: Any, replacement: Any, flags: Optional[str] = None + ) -> ColumnElement[str]: ... - def desc(self) -> "UnaryExpression[_T]": + def desc(self) -> UnaryExpression[_T]: ... - def asc(self) -> "UnaryExpression[_T]": + def asc(self) -> UnaryExpression[_T]: ... - def nulls_first(self) -> "UnaryExpression[_T]": + def nulls_first(self) -> UnaryExpression[_T]: ... - def nulls_last(self) -> "UnaryExpression[_T]": + def nulls_last(self) -> UnaryExpression[_T]: ... - def collate(self, collation) -> "CollationClause": + def collate(self, collation: str) -> CollationClause: ... def between( - self, cleft, cright, symmetric=False - ) -> "ColumnElement[bool]": + self, cleft: Any, cright: Any, symmetric: bool = False + ) -> BinaryExpression[bool]: ... - def distinct(self: "SQLCoreOperations[_T]") -> "UnaryExpression[_T]": + def distinct(self: _SQO[_T]) -> UnaryExpression[_T]: ... - def any_(self) -> "CollectionAggregate": + def any_(self) -> CollectionAggregate[Any]: ... - def all_(self) -> "CollectionAggregate": + def all_(self) -> CollectionAggregate[Any]: ... # numeric overloads. These need more tweaking @@ -807,179 +885,173 @@ class SQLCoreOperations( @overload def __add__( - self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", - other: "Union[_SQO[Optional[_NT]], _SQO[_NT], _NT]", - ) -> "ColumnElement[_NT]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload def __add__( - self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", + self: _SQO[str], other: Any, - ) -> "ColumnElement[_NUMERIC]": + ) -> ColumnElement[str]: ... - @overload - def __add__( - self: "Union[_SQO[_ST], _SQO[Optional[_ST]]]", - other: Any, - ) -> "ColumnElement[_ST]": + def __add__(self, other: Any) -> ColumnElement[Any]: ... - def __add__(self, other: Any) -> "ColumnElement": + @overload + def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: ... @overload - def __radd__(self, other: Any) -> "ColumnElement[_NUMERIC]": + def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @overload - def __radd__(self, other: Any) -> "ColumnElement": + def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: ... - def __radd__(self, other: Any) -> "ColumnElement": + def __radd__(self, other: Any) -> ColumnElement[Any]: ... @overload def __sub__( - self: "SQLCoreOperations[_NT]", - other: "Union[SQLCoreOperations[_NT], _NT]", - ) -> "ColumnElement[_NT]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __sub__(self, other: Any) -> "ColumnElement": + def __sub__(self, other: Any) -> ColumnElement[Any]: ... - def __sub__(self, other: Any) -> "ColumnElement": + def __sub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rsub__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __rsub__(self, other: Any) -> "ColumnElement": + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... - def __rsub__(self, other: Any) -> "ColumnElement": + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __mul__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __mul__(self, other: Any) -> "ColumnElement": + def __mul__(self, other: Any) -> ColumnElement[Any]: ... - def __mul__(self, other: Any) -> "ColumnElement": + def __mul__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rmul__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __rmul__(self, other: Any) -> "ColumnElement": + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... - def __rmul__(self, other: Any) -> "ColumnElement": + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __mod__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __mod__(self, other: Any) -> "ColumnElement": + def __mod__(self, other: Any) -> ColumnElement[Any]: ... - def __mod__(self, other: Any) -> "ColumnElement": + def __mod__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rmod__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rmod__(self, other: Any) -> "ColumnElement": + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... - def __rmod__(self, other: Any) -> "ColumnElement": + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... @overload def __truediv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: ... @overload - def __truediv__(self, other: Any) -> "ColumnElement": + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... - def __truediv__(self, other: Any) -> "ColumnElement": + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rtruediv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: ... @overload - def __rtruediv__(self, other: Any) -> "ColumnElement": + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... - def __rtruediv__(self, other: Any) -> "ColumnElement": + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __floordiv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __floordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __floordiv__(self, other: Any) -> "ColumnElement": + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __floordiv__(self, other: Any) -> "ColumnElement": + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rfloordiv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __rfloordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rfloordiv__(self, other: Any) -> "ColumnElement": + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __rfloordiv__(self, other: Any) -> "ColumnElement": + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... _SQO = SQLCoreOperations +SelfColumnElement = TypeVar("SelfColumnElement", bound="ColumnElement[Any]") + class ColumnElement( roles.ColumnArgumentOrKeyRole, roles.StatementOptionRole, roles.WhereHavingRole, - roles.BinaryElementRole, + roles.BinaryElementRole[_T], roles.OrderByRole, roles.ColumnsClauseRole, roles.LimitOffsetRole, @@ -987,7 +1059,6 @@ class ColumnElement( roles.DDLConstraintColumnRole, roles.DDLExpressionRole, SQLCoreOperations[_T], - operators.ColumnOperators[SQLCoreOperations], DQLDMLClauseElement, ): """Represent a column-oriented SQL expression suitable for usage in the @@ -1069,28 +1140,37 @@ class ColumnElement( __visit_name__ = "column_element" - primary_key = False - foreign_keys = [] - _proxies = () + primary_key: bool = False + _is_clone_of: Optional[ColumnElement[_T]] - _tq_label = None - """The named label that can be used to target - this column in a result set in a "table qualified" context. + @util.memoized_property + def foreign_keys(self) -> Iterable[ForeignKey]: + return [] - This label is almost always the label used when - rendering <expr> AS <label> in a SELECT statement when using - the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the legacy - ORM ``Query`` object uses as well. + @util.memoized_property + def _proxies(self) -> List[ColumnElement[Any]]: + return [] - For a regular Column bound to a Table, this is typically the label - <tablename>_<columnname>. For other constructs, different rules - may apply, such as anonymized labels and others. + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: + """The named label that can be used to target + this column in a result set in a "table qualified" context. - .. versionchanged:: 1.4.21 renamed from ``._label`` + This label is almost always the label used when + rendering <expr> AS <label> in a SELECT statement when using + the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the + legacy ORM ``Query`` object uses as well. - """ + For a regular Column bound to a Table, this is typically the label + <tablename>_<columnname>. For other constructs, different rules + may apply, such as anonymized labels and others. + + .. versionchanged:: 1.4.21 renamed from ``._label`` + + """ + return None - key = None + key: Optional[str] = None """The 'key' that in some circumstances refers to this object in a Python namespace. @@ -1101,7 +1181,7 @@ class ColumnElement( """ @HasMemoized.memoized_attribute - def _tq_key_label(self): + def _tq_key_label(self) -> Optional[str]: """A label-based version of 'key' that in some circumstances refers to this object in a Python namespace. @@ -1119,17 +1199,17 @@ class ColumnElement( return self._proxy_key @property - def _key_label(self): + def _key_label(self) -> Optional[str]: """legacy; renamed to _tq_key_label""" return self._tq_key_label @property - def _label(self): + def _label(self) -> Optional[str]: """legacy; renamed to _tq_label""" return self._tq_label @property - def _non_anon_label(self): + def _non_anon_label(self) -> Optional[str]: """the 'name' that naturally applies this element when rendered in SQL. @@ -1184,9 +1264,23 @@ class ColumnElement( _is_implicitly_boolean = False - _alt_names = () + _alt_names: Sequence[str] = () - def self_group(self, against=None): + @overload + def self_group( + self: ColumnElement[bool], against: Optional[OperatorType] = None + ) -> ColumnElement[bool]: + ... + + @overload + def self_group( + self: ColumnElement[_T], against: Optional[OperatorType] = None + ) -> ColumnElement[_T]: + ... + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: if ( against in (operators.and_, operators.or_, operators._asbool) and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity @@ -1197,18 +1291,32 @@ class ColumnElement( else: return self - def _negate(self): + @overload + def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: + ... + + @overload + def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: + ... + + def _negate(self) -> ColumnElement[Any]: if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return AsBoolean(self, operators.is_false, operators.is_true) else: - return super(ColumnElement, self)._negate() + return cast("UnaryExpression[_T]", super()._negate()) - @util.memoized_property - def type(self) -> "TypeEngine[_T]": - return type_api.NULLTYPE + type: TypeEngine[_T] + + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + # used for delayed setup of + # type_api + return type_api.NULLTYPE @HasMemoized.memoized_attribute - def comparator(self) -> "TypeEngine.Comparator[_T]": + def comparator(self) -> TypeEngine.Comparator[_T]: try: comparator_factory = self.type.comparator_factory except AttributeError as err: @@ -1219,7 +1327,7 @@ class ColumnElement( else: return comparator_factory(self) - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: try: return getattr(self.comparator, key) except AttributeError as err: @@ -1236,16 +1344,22 @@ class ColumnElement( self, op: operators.OperatorType, *other: Any, - **kwargs, - ) -> "ColumnElement": - return op(self.comparator, *other, **kwargs) + **kwargs: Any, + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 def reverse_operate( - self, op: operators.OperatorType, other: Any, **kwargs - ) -> "ColumnElement": - return op(other, self.comparator, **kwargs) + self, op: operators.OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 - def _bind_param(self, operator, obj, type_=None, expanding=False): + def _bind_param( + self, + operator: operators.OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + ) -> BindParameter[_T]: return BindParameter( None, obj, @@ -1257,7 +1371,7 @@ class ColumnElement( ) @property - def expression(self): + def expression(self) -> ColumnElement[Any]: """Return a column expression. Part of the inspection interface; returns self. @@ -1266,39 +1380,39 @@ class ColumnElement( return self @property - def _select_iterable(self): + def _select_iterable(self) -> Iterable[ColumnElement[Any]]: return (self,) @util.memoized_property - def base_columns(self): - return util.column_set(c for c in self.proxy_set if not c._proxies) + def base_columns(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset(c for c in self.proxy_set if not c._proxies) @util.memoized_property - def proxy_set(self): - s = util.column_set([self]) - for c in self._proxies: - s.update(c.proxy_set) - return s + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset([self]).union( + itertools.chain.from_iterable(c.proxy_set for c in self._proxies) + ) - def _uncached_proxy_set(self): + def _uncached_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: """An 'uncached' version of proxy set. This is so that we can read annotations from the list of columns without breaking the caching of the above proxy_set. """ - s = util.column_set([self]) - for c in self._proxies: - s.update(c._uncached_proxy_set()) - return s + return frozenset([self]).union( + itertools.chain.from_iterable( + c._uncached_proxy_set() for c in self._proxies + ) + ) - def shares_lineage(self, othercolumn): + def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool: """Return True if the given :class:`_expression.ColumnElement` has a common ancestor to this :class:`_expression.ColumnElement`.""" return bool(self.proxy_set.intersection(othercolumn.proxy_set)) - def _compare_name_for_result(self, other): + def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: """Return True if the given column element compares to this one when targeting within a result row.""" @@ -1309,9 +1423,9 @@ class ColumnElement( ) @HasMemoized.memoized_attribute - def _proxy_key(self): + def _proxy_key(self) -> Optional[str]: if self._annotations and "proxy_key" in self._annotations: - return self._annotations["proxy_key"] + return cast(str, self._annotations["proxy_key"]) name = self.key if not name: @@ -1327,7 +1441,7 @@ class ColumnElement( return name @HasMemoized.memoized_attribute - def _expression_label(self): + def _expression_label(self) -> Optional[str]: """a suggested label to use in the case that the column has no name, which should be used if possible as the explicit 'AS <label>' where this expression would normally have an anon label. @@ -1340,18 +1454,18 @@ class ColumnElement( if getattr(self, "name", None) is not None: return None elif self._annotations and "proxy_key" in self._annotations: - return self._annotations["proxy_key"] + return cast(str, self._annotations["proxy_key"]) else: return None def _make_proxy( self, - selectable, + selectable: FromClause, name: Optional[str] = None, - key=None, - name_is_truncatable=False, - **kw, - ): + key: Optional[str] = None, + name_is_truncatable: bool = False, + **kw: Any, + ) -> typing_Tuple[str, ColumnClause[_T]]: """Create a new :class:`_expression.ColumnElement` representing this :class:`_expression.ColumnElement` as it appears in the select list of a descending selectable. @@ -1364,7 +1478,7 @@ class ColumnElement( else: key = name - co = ColumnClause( + co: ColumnClause[_T] = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name) if name_is_truncatable else name, @@ -1376,9 +1490,10 @@ class ColumnElement( co._proxies = [self] if selectable._is_clone_of is not None: co._is_clone_of = selectable._is_clone_of.columns.get(key) + assert key is not None return key, co - def cast(self, type_): + def cast(self, type_: TypeEngine[_T]) -> Cast[_T]: """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``. This is a shortcut to the :func:`_expression.cast` function. @@ -1406,7 +1521,9 @@ class ColumnElement( """ return Label(name, self, self.type) - def _anon_label(self, seed, add_hash=None) -> "_anonymous_label": + def _anon_label( + self, seed: Optional[str], add_hash: Optional[int] = None + ) -> _anonymous_label: while self._is_clone_of is not None: self = self._is_clone_of @@ -1441,7 +1558,7 @@ class ColumnElement( return _anonymous_label.safe_construct(hash_value, seed or "anon") @util.memoized_property - def _anon_name_label(self) -> "_anonymous_label": + def _anon_name_label(self) -> str: """Provides a constant 'anonymous label' for this ColumnElement. This is a label() expression which will be named at compile time. @@ -1462,7 +1579,7 @@ class ColumnElement( return self._anon_label(name) @util.memoized_property - def _anon_key_label(self): + def _anon_key_label(self) -> _anonymous_label: """Provides a constant 'anonymous key label' for this ColumnElement. Compare to ``anon_label``, except that the "key" of the column, @@ -1478,25 +1595,23 @@ class ColumnElement( """ return self._anon_label(self._proxy_key) - @property - @util.deprecated( + @util.deprecated_property( "1.4", "The :attr:`_expression.ColumnElement.anon_label` attribute is now " "private, and the public accessor is deprecated.", ) - def anon_label(self): + def anon_label(self) -> str: return self._anon_name_label - @property - @util.deprecated( + @util.deprecated_property( "1.4", "The :attr:`_expression.ColumnElement.anon_key_label` attribute is " "now private, and the public accessor is deprecated.", ) - def anon_key_label(self): + def anon_key_label(self) -> str: return self._anon_key_label - def _dedupe_anon_label_idx(self, idx): + def _dedupe_anon_label_idx(self, idx: int) -> str: """label to apply to a column that is anon labeled, but repeated in the SELECT, so that we have to make an "extra anon" label that disambiguates it from the previous appearance. @@ -1520,20 +1635,20 @@ class ColumnElement( return self._anon_label(label, add_hash=idx) @util.memoized_property - def _anon_tq_label(self): + def _anon_tq_label(self) -> _anonymous_label: return self._anon_label(getattr(self, "_tq_label", None)) @util.memoized_property - def _anon_tq_key_label(self): + def _anon_tq_key_label(self) -> _anonymous_label: return self._anon_label(getattr(self, "_tq_key_label", None)) - def _dedupe_anon_tq_label_idx(self, idx): + def _dedupe_anon_tq_label_idx(self, idx: int) -> _anonymous_label: label = getattr(self, "_tq_label", None) or "anon" return self._anon_label(label, add_hash=idx) -class WrapsColumnExpression: +class WrapsColumnExpression(ColumnElement[_T]): """Mixin that defines a :class:`_expression.ColumnElement` as a wrapper with special labeling behavior for an expression that already has a name. @@ -1548,25 +1663,27 @@ class WrapsColumnExpression: """ @property - def wrapped_column_expression(self): + def wrapped_column_expression(self) -> ColumnElement[_T]: raise NotImplementedError() - @property - def _tq_label(self): + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: wce = self.wrapped_column_expression if hasattr(wce, "_tq_label"): return wce._tq_label else: return None - _label = _tq_label + @property + def _label(self) -> Optional[str]: + return self._tq_label @property - def _non_anon_label(self): + def _non_anon_label(self) -> Optional[str]: return None - @property - def _anon_name_label(self): + @util.non_memoized_property + def _anon_name_label(self) -> str: wce = self.wrapped_column_expression # this logic tries to get the WrappedColumnExpression to render @@ -1578,9 +1695,9 @@ class WrapsColumnExpression: return nal elif hasattr(wce, "_anon_name_label"): return wce._anon_name_label - return super(WrapsColumnExpression, self)._anon_name_label + return super()._anon_name_label - def _dedupe_anon_label_idx(self, idx): + def _dedupe_anon_label_idx(self, idx: int) -> str: wce = self.wrapped_column_expression nal = wce._non_anon_label if nal: @@ -1589,7 +1706,7 @@ class WrapsColumnExpression: return self._dedupe_anon_tq_label_idx(idx) -SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter") +SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]") class BindParameter(roles.InElementRole, ColumnElement[_T]): @@ -1614,7 +1731,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): __visit_name__ = "bindparam" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("key", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("callable", InternalTraversal.dp_plain_dict), @@ -1622,7 +1739,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): ] key: str - type: TypeEngine + type: TypeEngine[_T] _is_crud = False _is_bind_parameter = True @@ -1634,23 +1751,23 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): def __init__( self, - key, - value=NO_ARG, - type_=None, - unique=False, - required=NO_ARG, - quote=None, - callable_=None, - expanding=False, - isoutparam=False, - literal_execute=False, - _compared_to_operator=None, - _compared_to_type=None, - _is_crud=False, + key: Optional[str], + value: Any = _NoArg.NO_ARG, + type_: Optional[_TypeEngineArgument[_T]] = None, + unique: bool = False, + required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG, + quote: Optional[bool] = None, + callable_: Optional[Callable[[], Any]] = None, + expanding: bool = False, + isoutparam: bool = False, + literal_execute: bool = False, + _compared_to_operator: Optional[OperatorType] = None, + _compared_to_type: Optional[TypeEngine[Any]] = None, + _is_crud: bool = False, ): - if required is NO_ARG: - required = value is NO_ARG and callable_ is None - if value is NO_ARG: + if required is _NoArg.NO_ARG: + required = value is _NoArg.NO_ARG and callable_ is None + if value is _NoArg.NO_ARG: value = None if quote is not None: @@ -1713,12 +1830,19 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): self.type = type_api._resolve_value_to_type(check_value) elif isinstance(type_, type): self.type = type_() - elif type_._is_tuple_type and value: - if expanding: - check_value = value[0] + elif is_tuple_type(type_): + if value: + if expanding: + check_value = value[0] + else: + check_value = value + cast( + "BindParameter[typing_Tuple[Any, ...]]", self + ).type = type_._resolve_values_to_types(check_value) else: - check_value = value - self.type = type_._resolve_values_to_types(check_value) + cast( + "BindParameter[typing_Tuple[Any, ...]]", self + ).type = type_ else: self.type = type_ @@ -1791,7 +1915,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): return c def _clone( - self: SelfBindParameter, maintain_key=False, **kw + self: SelfBindParameter, maintain_key: bool = False, **kw: Any ) -> SelfBindParameter: c = ClauseElement._clone(self, **kw) if not maintain_key and self.unique: @@ -1865,7 +1989,9 @@ class TypeClause(DQLDMLClauseElement): __visit_name__ = "typeclause" - _traverse_internals = [("type", InternalTraversal.dp_type)] + _traverse_internals: _TraverseInternalsType = [ + ("type", InternalTraversal.dp_type) + ] def __init__(self, type_): self.type = type_ @@ -1882,7 +2008,7 @@ class TextClause( roles.OrderByRole, roles.FromClauseRole, roles.SelectStatementRole, - roles.BinaryElementRole, + roles.BinaryElementRole[Any], roles.InElementRole, Executable, DQLDMLClauseElement, @@ -1909,7 +2035,7 @@ class TextClause( __visit_name__ = "textclause" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict), ("text", InternalTraversal.dp_string), ] @@ -1923,7 +2049,9 @@ class TextClause( _render_label_in_columns_clause = False - _hide_froms = () + @property + def _hide_froms(self) -> Iterable[FromClause]: + return () def __and__(self, other): # support use in select.where(), query.filter() @@ -1935,12 +2063,13 @@ class TextClause( # help in those cases where text() is # interpreted in a column expression situation - key = _label = None + key: Optional[str] = None + _label: Optional[str] = None _allow_label_resolve = False - def __init__(self, text): - self._bindparams = {} + def __init__(self, text: str): + self._bindparams: Dict[str, BindParameter[Any]] = {} def repl(m): self._bindparams[m.group(1)] = BindParameter(m.group(1)) @@ -1952,7 +2081,9 @@ class TextClause( @_generative def bindparams( - self: SelfTextClause, *binds, **names_to_values + self: SelfTextClause, + *binds: BindParameter[Any], + **names_to_values: Any, ) -> SelfTextClause: """Establish the values and/or types of bound parameters within this :class:`_expression.TextClause` construct. @@ -2205,7 +2336,7 @@ class TextClause( else col for col in cols ] - keyed_input_cols = [ + keyed_input_cols: List[ColumnClause[Any]] = [ ColumnClause(key, type_) for key, type_ in types.items() ] @@ -2230,7 +2361,7 @@ class TextClause( return self -class Null(SingletonConstant, roles.ConstExprRole, ColumnElement): +class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -2240,23 +2371,26 @@ class Null(SingletonConstant, roles.ConstExprRole, ColumnElement): __visit_name__ = "null" - _traverse_internals = [] + _traverse_internals: _TraverseInternalsType = [] + _singleton: Null @util.memoized_property def type(self): return type_api.NULLTYPE @classmethod - def _instance(cls): + def _instance(cls) -> Null: """Return a constant :class:`.Null` construct.""" - return Null() + return Null._singleton Null._create_singleton() -class False_(SingletonConstant, roles.ConstExprRole, ColumnElement): +class False_( + SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool] +): """Represent the ``false`` keyword, or equivalent, in a SQL statement. :class:`.False_` is accessed as a constant via the @@ -2265,24 +2399,25 @@ class False_(SingletonConstant, roles.ConstExprRole, ColumnElement): """ __visit_name__ = "false" - _traverse_internals = [] + _traverse_internals: _TraverseInternalsType = [] + _singleton: False_ @util.memoized_property def type(self): return type_api.BOOLEANTYPE - def _negate(self): - return True_() + def _negate(self) -> True_: + return True_._singleton @classmethod - def _instance(cls): - return False_() + def _instance(cls) -> False_: + return False_._singleton False_._create_singleton() -class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): +class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -2292,14 +2427,15 @@ class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): __visit_name__ = "true" - _traverse_internals = [] + _traverse_internals: _TraverseInternalsType = [] + _singleton: True_ @util.memoized_property def type(self): return type_api.BOOLEANTYPE - def _negate(self): - return False_() + def _negate(self) -> False_: + return False_._singleton @classmethod def _ifnone(cls, other): @@ -2309,8 +2445,8 @@ class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): return other @classmethod - def _instance(cls): - return True_() + def _instance(cls) -> True_: + return True_._singleton True_._create_singleton() @@ -2333,18 +2469,18 @@ class ClauseList( _is_clause_list = True - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("clauses", InternalTraversal.dp_clauseelement_list), ("operator", InternalTraversal.dp_operator), ] def __init__( self, - *clauses, - operator=operators.comma_op, - group=True, - group_contents=True, - _flatten_sub_clauses=False, + *clauses: _ColumnExpression[Any], + operator: OperatorType = operators.comma_op, + group: bool = True, + group_contents: bool = True, + _flatten_sub_clauses: bool = False, _literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole, ): self.operator = operator @@ -2405,8 +2541,8 @@ class ClauseList( coercions.expect(self._text_converter_role, clause) ) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group(self, against=None): @@ -2465,7 +2601,14 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): return lcc, [c.self_group(against=against) for c in convert_clauses] @classmethod - def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): + def _construct( + cls, + operator: OperatorType, + continue_on: Any, + skip_on: Any, + *clauses: _ColumnExpression[Any], + **kw: Any, + ) -> BooleanClauseList: lcc, convert_clauses = cls._process_clauses_for_boolean( operator, continue_on, @@ -2479,11 +2622,11 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): if lcc > 1: # multiple elements. Return regular BooleanClauseList # which will link elements against the operator. - return cls._construct_raw(operator, convert_clauses) + return cls._construct_raw(operator, convert_clauses) # type: ignore[no-any-return] # noqa E501 elif lcc == 1: # just one element. return it as a single boolean element, # not a list and discard the operator. - return convert_clauses[0] + return convert_clauses[0] # type: ignore[no-any-return] # noqa E501 else: # no elements period. deprecated use case. return an empty # ClauseList construct that generates nothing unless it has @@ -2500,7 +2643,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): }, version="1.4", ) - return cls._construct_raw(operator) + return cls._construct_raw(operator) # type: ignore[no-any-return] # noqa E501 @classmethod def _construct_for_whereclause(cls, clauses): @@ -2540,7 +2683,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): return self @classmethod - def and_(cls, *clauses): + def and_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList: r"""Produce a conjunction of expressions joined by ``AND``. See :func:`_sql.and_` for full documentation. @@ -2550,7 +2693,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): ) @classmethod - def or_(cls, *clauses): + def or_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList: """Produce a conjunction of expressions joined by ``OR``. See :func:`_sql.or_` for full documentation. @@ -2577,19 +2720,27 @@ and_ = BooleanClauseList.and_ or_ = BooleanClauseList.or_ -class Tuple(ClauseList, ColumnElement): +class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): """Represent a SQL tuple.""" __visit_name__ = "tuple" - _traverse_internals = ClauseList._traverse_internals + [] + _traverse_internals: _TraverseInternalsType = ( + ClauseList._traverse_internals + [] + ) + + type: TupleType @util.preload_module("sqlalchemy.sql.sqltypes") - def __init__(self, *clauses, types=None): + def __init__( + self, + *clauses: _ColumnExpression[Any], + types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, + ): sqltypes = util.preloaded.sql_sqltypes if types is None: - clauses = [ + init_clauses = [ coercions.expect(roles.ExpressionElementRole, c) for c in clauses ] @@ -2599,7 +2750,7 @@ class Tuple(ClauseList, ColumnElement): "Wrong number of elements for %d-tuple: %r " % (len(types), clauses) ) - clauses = [ + init_clauses = [ coercions.expect( roles.ExpressionElementRole, c, @@ -2608,8 +2759,8 @@ class Tuple(ClauseList, ColumnElement): for typ, c in zip(types, clauses) ] - self.type = sqltypes.TupleType(*[arg.type for arg in clauses]) - super(Tuple, self).__init__(*clauses) + self.type = sqltypes.TupleType(*[arg.type for arg in init_clauses]) + super(Tuple, self).__init__(*init_clauses) @property def _select_iterable(self): @@ -2672,7 +2823,7 @@ class Case(ColumnElement[_T]): __visit_name__ = "case" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("value", InternalTraversal.dp_clauseelement), ("whens", InternalTraversal.dp_clauseelement_tuples), ("else_", InternalTraversal.dp_clauseelement), @@ -2681,13 +2832,24 @@ class Case(ColumnElement[_T]): # for case(), the type is derived from the whens. so for the moment # users would have to cast() the case to get a specific type - def __init__(self, *whens, value=None, else_=None): + whens: List[typing_Tuple[ColumnElement[bool], ColumnElement[_T]]] + else_: Optional[ColumnElement[_T]] + value: Optional[ColumnElement[Any]] - whens = coercions._expression_collection_was_a_list( + def __init__( + self, + *whens: Union[ + typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, + ): + + new_whens: Iterable[Any] = coercions._expression_collection_was_a_list( "whens", "case", whens ) try: - whens = util.dictlike_iteritems(whens) + new_whens = util.dictlike_iteritems(new_whens) except TypeError: pass @@ -2700,7 +2862,7 @@ class Case(ColumnElement[_T]): ).self_group(), coercions.expect(roles.ExpressionElementRole, r), ) - for (c, r) in whens + for (c, r) in new_whens ] if whenlist: @@ -2713,7 +2875,7 @@ class Case(ColumnElement[_T]): else: self.value = coercions.expect(roles.ExpressionElementRole, value) - self.type = type_ + self.type = cast(_T, type_) self.whens = whenlist if else_ is not None: @@ -2721,14 +2883,14 @@ class Case(ColumnElement[_T]): else: self.else_ = None - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain(*[x._from_objects for x in self.get_children()]) ) -class Cast(WrapsColumnExpression, ColumnElement[_T]): +class Cast(WrapsColumnExpression[_T]): """Represent a ``CAST`` expression. :class:`.Cast` is produced using the :func:`.cast` factory function, @@ -2754,12 +2916,20 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]): __visit_name__ = "cast" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("clause", InternalTraversal.dp_clauseelement), ("typeclause", InternalTraversal.dp_clauseelement), ] - def __init__(self, expression, type_): + clause: ColumnElement[Any] + type: TypeEngine[_T] + typeclause: TypeClause + + def __init__( + self, + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], + ): self.type = type_api.to_instance(type_) self.clause = coercions.expect( roles.ExpressionElementRole, @@ -2769,8 +2939,8 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]): ) self.typeclause = TypeClause(self.type) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @property @@ -2778,7 +2948,7 @@ class Cast(WrapsColumnExpression, ColumnElement[_T]): return self.clause -class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): +class TypeCoerce(WrapsColumnExpression[_T]): """Represent a Python-side type-coercion wrapper. :class:`.TypeCoerce` supplies the :func:`_expression.type_coerce` @@ -2798,12 +2968,19 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): __visit_name__ = "type_coerce" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("clause", InternalTraversal.dp_clauseelement), ("type", InternalTraversal.dp_type), ] - def __init__(self, expression, type_): + clause: ColumnElement[Any] + type: TypeEngine[_T] + + def __init__( + self, + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], + ): self.type = type_api.to_instance(type_) self.clause = coercions.expect( roles.ExpressionElementRole, @@ -2812,8 +2989,8 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): apply_propagate_attrs=self, ) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @HasMemoized.memoized_attribute @@ -2837,27 +3014,30 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement[_T]): return self -class Extract(ColumnElement[_T]): +class Extract(ColumnElement[int]): """Represent a SQL EXTRACT clause, ``extract(field FROM expr)``.""" __visit_name__ = "extract" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("expr", InternalTraversal.dp_clauseelement), ("field", InternalTraversal.dp_string), ] - def __init__(self, field, expr): + expr: ColumnElement[Any] + field: str + + def __init__(self, field: str, expr: _ColumnExpression[Any]): self.type = type_api.INTEGERTYPE self.field = field self.expr = coercions.expect(roles.ExpressionElementRole, expr) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.expr._from_objects -class _label_reference(ColumnElement): +class _label_reference(ColumnElement[_T]): """Wrap a column expression as it appears in a 'reference' context. This expression is any that includes an _order_by_label_element, @@ -2872,26 +3052,30 @@ class _label_reference(ColumnElement): __visit_name__ = "label_reference" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] - def __init__(self, element): + def __init__(self, element: ColumnElement[_T]): self.element = element - @property - def _from_objects(self): - return () + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] -class _textual_label_reference(ColumnElement): +class _textual_label_reference(ColumnElement[Any]): __visit_name__ = "textual_label_reference" - _traverse_internals = [("element", InternalTraversal.dp_string)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_string) + ] - def __init__(self, element): + def __init__(self, element: str): self.element = element @util.memoized_property - def _text_clause(self): + def _text_clause(self) -> TextClause: return TextClause(self.element) @@ -2911,7 +3095,7 @@ class UnaryExpression(ColumnElement[_T]): __visit_name__ = "unary" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("operator", InternalTraversal.dp_operator), ("modifier", InternalTraversal.dp_operator), @@ -2919,11 +3103,11 @@ class UnaryExpression(ColumnElement[_T]): def __init__( self, - element, - operator=None, - modifier=None, - type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] = None, - wraps_column_expression=False, + element: ColumnElement[Any], + operator: Optional[OperatorType] = None, + modifier: Optional[OperatorType] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, + wraps_column_expression: bool = False, ): self.operator = operator self.modifier = modifier @@ -2935,7 +3119,10 @@ class UnaryExpression(ColumnElement[_T]): self.wraps_column_expression = wraps_column_expression @classmethod - def _create_nulls_first(cls, column): + def _create_nulls_first( + cls, + column: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.nulls_first_op, @@ -2943,7 +3130,10 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_nulls_last(cls, column): + def _create_nulls_last( + cls, + column: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.nulls_last_op, @@ -2951,7 +3141,9 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_desc(cls, column): + def _create_desc( + cls, column: _ColumnExpression[_T] + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.desc_op, @@ -2959,7 +3151,10 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_asc(cls, column): + def _create_asc( + cls, + column: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), modifier=operators.asc_op, @@ -2967,24 +3162,27 @@ class UnaryExpression(ColumnElement[_T]): ) @classmethod - def _create_distinct(cls, expr): - expr = coercions.expect(roles.ExpressionElementRole, expr) + def _create_distinct( + cls, + expr: _ColumnExpression[_T], + ) -> UnaryExpression[_T]: + col_expr = coercions.expect(roles.ExpressionElementRole, expr) return UnaryExpression( - expr, + col_expr, operator=operators.distinct_op, - type_=expr.type, + type_=col_expr.type, wraps_column_expression=False, ) @property - def _order_by_label_element(self): + def _order_by_label_element(self) -> Optional[Label[Any]]: if self.modifier in (operators.desc_op, operators.asc_op): return self.element._order_by_label_element else: return None - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects def _negate(self): @@ -3005,7 +3203,7 @@ class UnaryExpression(ColumnElement[_T]): return self -class CollectionAggregate(UnaryExpression): +class CollectionAggregate(UnaryExpression[_T]): """Forms the basis for right-hand collection operator modifiers ANY and ALL. @@ -3018,7 +3216,9 @@ class CollectionAggregate(UnaryExpression): inherit_cache = True @classmethod - def _create_any(cls, expr): + def _create_any( + cls, expr: _ColumnExpression[_T] + ) -> CollectionAggregate[_T]: expr = coercions.expect(roles.ExpressionElementRole, expr) expr = expr.self_group() @@ -3030,7 +3230,9 @@ class CollectionAggregate(UnaryExpression): ) @classmethod - def _create_all(cls, expr): + def _create_all( + cls, expr: _ColumnExpression[_T] + ) -> CollectionAggregate[_T]: expr = coercions.expect(roles.ExpressionElementRole, expr) expr = expr.self_group() return CollectionAggregate( @@ -3059,7 +3261,7 @@ class CollectionAggregate(UnaryExpression): ) -class AsBoolean(WrapsColumnExpression, UnaryExpression): +class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): inherit_cache = True def __init__(self, element, operator, negate): @@ -3101,7 +3303,7 @@ class BinaryExpression(ColumnElement[_T]): __visit_name__ = "binary" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("left", InternalTraversal.dp_clauseelement), ("right", InternalTraversal.dp_clauseelement), ("operator", InternalTraversal.dp_operator), @@ -3119,16 +3321,16 @@ class BinaryExpression(ColumnElement[_T]): """ + modifiers: Optional[Mapping[str, Any]] + def __init__( self, - left: ColumnElement, - right: Union[ColumnElement, ClauseList], - operator, - type_: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] - ] = None, - negate=None, - modifiers=None, + left: ColumnElement[Any], + right: Union[ColumnElement[Any], ClauseList], + operator: OperatorType, + type_: Optional[_TypeEngineArgument[_T]] = None, + negate: Optional[OperatorType] = None, + modifiers: Optional[Mapping[str, Any]] = None, ): # allow compatibility with libraries that # refer to BinaryExpression directly and pass strings @@ -3149,8 +3351,40 @@ class BinaryExpression(ColumnElement[_T]): self.modifiers = modifiers def __bool__(self): - if self.operator in (operator.eq, operator.ne): - return self.operator(*self._orig) + """Implement Python-side "bool" for BinaryExpression as a + simple "identity" check for the left and right attributes, + if the operator is "eq" or "ne". Otherwise the expression + continues to not support "bool" like all other column expressions. + + The rationale here is so that ColumnElement objects can be hashable. + What? Well, suppose you do this:: + + c1, c2 = column('x'), column('y') + s1 = set([c1, c2]) + + We do that **a lot**, columns inside of sets is an extremely basic + thing all over the ORM for example. + + So what happens if we do this? :: + + c1 in s1 + + Hashing means it will normally use ``__hash__()`` of the object, + but in case of hash collision, it's going to also do ``c1 == c1`` + and/or ``c1 == c2`` inside. Those operations need to return a + True/False value. But because we override ``==`` and ``!=``, they're + going to get a BinaryExpression. Hence we implement ``__bool__`` here + so that these comparisons behave in this particular context mostly + like regular object comparisons. Thankfully Python is OK with + that! Otherwise we'd have to use special set classes for columns + (which we used to do, decades ago). + + """ + if self.operator in (operators.eq, operators.ne): + # this is using the eq/ne operator given int hash values, + # rather than Operator, so that "bool" can be based on + # identity + return self.operator(*self._orig) # type: ignore else: raise TypeError("Boolean value of this clause is not defined") @@ -3167,8 +3401,8 @@ class BinaryExpression(ColumnElement[_T]): def is_comparison(self): return operators.is_comparison(self.operator) - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects def self_group(self, against=None): @@ -3192,7 +3426,7 @@ class BinaryExpression(ColumnElement[_T]): return super(BinaryExpression, self)._negate() -class Slice(ColumnElement): +class Slice(ColumnElement[Any]): """Represent SQL for a Python array-slice object. This is not a specific SQL construct at this level, but @@ -3202,7 +3436,7 @@ class Slice(ColumnElement): __visit_name__ = "slice" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("start", InternalTraversal.dp_clauseelement), ("stop", InternalTraversal.dp_clauseelement), ("step", InternalTraversal.dp_clauseelement), @@ -3234,7 +3468,7 @@ class Slice(ColumnElement): return self -class IndexExpression(BinaryExpression): +class IndexExpression(BinaryExpression[Any]): """Represent the class of expressions that are like an "index" operation.""" @@ -3246,6 +3480,8 @@ class GroupedElement(DQLDMLClauseElement): __visit_name__ = "grouping" + element: ClauseElement + def self_group(self, against=None): return self @@ -3253,15 +3489,19 @@ class GroupedElement(DQLDMLClauseElement): return self.element._ungroup() -class Grouping(GroupedElement, ColumnElement): +class Grouping(GroupedElement, ColumnElement[_T]): """Represent a grouping within a column expression""" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("type", InternalTraversal.dp_type), ] - def __init__(self, element): + element: Union[TextClause, ClauseList, ColumnElement[_T]] + + def __init__( + self, element: Union[TextClause, ClauseList, ColumnElement[_T]] + ): self.element = element self.type = getattr(element, "type", type_api.NULLTYPE) @@ -3272,21 +3512,21 @@ class Grouping(GroupedElement, ColumnElement): def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean - @property - def _tq_label(self): + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: return ( getattr(self.element, "_tq_label", None) or self._anon_name_label ) - @property - def _proxies(self): + @util.non_memoized_property + def _proxies(self) -> List[ColumnElement[Any]]: if isinstance(self.element, ColumnElement): return [self.element] else: return [] - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects def __getattr__(self, attr): @@ -3300,8 +3540,13 @@ class Grouping(GroupedElement, ColumnElement): self.type = state["type"] -RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED") -RANGE_CURRENT = util.symbol("RANGE_CURRENT") +class _OverRange(IntEnum): + RANGE_UNBOUNDED = 0 + RANGE_CURRENT = 1 + + +RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED +RANGE_CURRENT = _OverRange.RANGE_CURRENT class Over(ColumnElement[_T]): @@ -3316,7 +3561,7 @@ class Over(ColumnElement[_T]): __visit_name__ = "over" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("order_by", InternalTraversal.dp_clauseelement), ("partition_by", InternalTraversal.dp_clauseelement), @@ -3324,15 +3569,26 @@ class Over(ColumnElement[_T]): ("rows", InternalTraversal.dp_plain_obj), ] - order_by = None - partition_by = None + order_by: Optional[ClauseList] = None + partition_by: Optional[ClauseList] = None - element = None + element: ColumnElement[_T] """The underlying expression object to which this :class:`.Over` object refers towards.""" + range_: Optional[typing_Tuple[int, int]] + def __init__( - self, element, partition_by=None, order_by=None, range_=None, rows=None + self, + element: ColumnElement[_T], + partition_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + order_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ): self.element = element if order_by is not None: @@ -3368,10 +3624,15 @@ class Over(ColumnElement[_T]): self.rows, ) - def _interpret_range(self, range_): + def _interpret_range( + self, range_: typing_Tuple[Optional[int], Optional[int]] + ) -> typing_Tuple[int, int]: if not isinstance(range_, tuple) or len(range_) != 2: raise exc.ArgumentError("2-tuple expected for range/rows") + lower: int + upper: int + if range_[0] is None: lower = RANGE_UNBOUNDED else: @@ -3404,8 +3665,8 @@ class Over(ColumnElement[_T]): def type(self): return self.element.type - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain( *[ @@ -3436,14 +3697,16 @@ class WithinGroup(ColumnElement[_T]): __visit_name__ = "withingroup" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("order_by", InternalTraversal.dp_clauseelement), ] - order_by = None + order_by: Optional[ClauseList] = None - def __init__(self, element, *order_by): + def __init__( + self, element: FunctionElement[_T], *order_by: _ColumnExpression[Any] + ): self.element = element if order_by is not None: self.order_by = ClauseList( @@ -3451,7 +3714,9 @@ class WithinGroup(ColumnElement[_T]): ) def __reduce__(self): - return self.__class__, (self.element,) + tuple(self.order_by) + return self.__class__, (self.element,) + ( + tuple(self.order_by) if self.order_by is not None else () + ) def over(self, partition_by=None, order_by=None, range_=None, rows=None): """Produce an OVER clause against this :class:`.WithinGroup` @@ -3477,8 +3742,8 @@ class WithinGroup(ColumnElement[_T]): else: return self.element.type - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain( *[ @@ -3490,7 +3755,7 @@ class WithinGroup(ColumnElement[_T]): ) -class FunctionFilter(ColumnElement): +class FunctionFilter(ColumnElement[_T]): """Represent a function FILTER clause. This is a special operator against aggregate and window functions, @@ -3512,14 +3777,16 @@ class FunctionFilter(ColumnElement): __visit_name__ = "funcfilter" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("func", InternalTraversal.dp_clauseelement), ("criterion", InternalTraversal.dp_clauseelement), ] - criterion = None + criterion: Optional[ColumnElement[bool]] = None - def __init__(self, func, *criterion): + def __init__( + self, func: FunctionElement[_T], *criterion: _ColumnExpression[bool] + ): self.func = func self.filter(*criterion) @@ -3535,17 +3802,27 @@ class FunctionFilter(ColumnElement): """ - for criterion in list(criterion): - criterion = coercions.expect(roles.WhereHavingRole, criterion) + for crit in list(criterion): + crit = coercions.expect(roles.WhereHavingRole, crit) if self.criterion is not None: - self.criterion = self.criterion & criterion + self.criterion = self.criterion & crit else: - self.criterion = criterion + self.criterion = crit return self - def over(self, partition_by=None, order_by=None, range_=None, rows=None): + def over( + self, + partition_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + order_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + ) -> Over[_T]: """Produce an OVER clause against this filtered function. Used against aggregate or so-called "window" functions, @@ -3581,8 +3858,8 @@ class FunctionFilter(ColumnElement): def type(self): return self.func.type - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return list( itertools.chain( *[ @@ -3594,7 +3871,7 @@ class FunctionFilter(ColumnElement): ) -class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): +class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -3604,13 +3881,21 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): __visit_name__ = "label" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("name", InternalTraversal.dp_anon_name), - ("_type", InternalTraversal.dp_type), + ("type", InternalTraversal.dp_type), ("_element", InternalTraversal.dp_clauseelement), ] - def __init__(self, name, element, type_=None): + _element: ColumnElement[_T] + name: str + + def __init__( + self, + name: Optional[str], + element: _ColumnExpression[_T], + type_: Optional[_TypeEngineArgument[_T]] = None, + ): orig_element = element element = coercions.expect( roles.ExpressionElementRole, @@ -3635,11 +3920,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): self.key = self._tq_label = self._tq_key_label = self.name self._element = element - self._type = type_ + # self._type = type_ + self.type = type_api.to_instance( + type_ or getattr(self._element, "type", None) + ) self._proxies = [element] def __reduce__(self): - return self.__class__, (self.name, self._element, self._type) + return self.__class__, (self.name, self._element, self.type) @util.memoized_property def _is_implicitly_boolean(self): @@ -3653,14 +3941,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): def _order_by_label_element(self): return self - @util.memoized_property - def type(self): - return type_api.to_instance( - self._type or getattr(self._element, "type", None) - ) - @HasMemoized.memoized_attribute - def element(self): + def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) def self_group(self, against=None): @@ -3672,7 +3954,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): def _apply_to_inner(self, fn, *arg, **kw): sub_element = fn(*arg, **kw) if sub_element is not self._element: - return Label(self.name, sub_element, type_=self._type) + return Label(self.name, sub_element, type_=self.type) else: return self @@ -3693,8 +3975,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): ) self.key = self._tq_label = self._tq_key_label = self.name - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): @@ -3724,15 +4006,16 @@ class Label(roles.LabeledColumnExprRole, ColumnElement[_T]): e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) - if self._type is not None: - e.type = self._type + if self.type is not None: + e.type = self.type return self.key, e class NamedColumn(ColumnElement[_T]): is_literal = False - table = None + table: Optional[FromClause] = None + name: str def _compare_name_for_result(self, other): return (hasattr(other, "name") and self.name == other.name) or ( @@ -3740,7 +4023,7 @@ class NamedColumn(ColumnElement[_T]): ) @util.memoized_property - def description(self): + def description(self) -> str: return self.name @HasMemoized.memoized_attribute @@ -3759,7 +4042,7 @@ class NamedColumn(ColumnElement[_T]): return self._tq_label @HasMemoized.memoized_attribute - def _tq_label(self): + def _tq_label(self) -> Optional[str]: """table qualified label based on column name. for table-bound columns this is <tablename>_<columnname>; all other @@ -3776,7 +4059,9 @@ class NamedColumn(ColumnElement[_T]): def _non_anon_label(self): return self.name - def _gen_tq_label(self, name, dedupe_on_key=True): + def _gen_tq_label( + self, name: str, dedupe_on_key: bool = True + ) -> Optional[str]: return name def _bind_param(self, operator, obj, type_=None, expanding=False): @@ -3817,7 +4102,7 @@ class NamedColumn(ColumnElement[_T]): class ColumnClause( roles.DDLReferredColumnRole, - roles.LabeledColumnExprRole, + roles.LabeledColumnExprRole[_T], roles.StrAsPlainColumnRole, Immutable, NamedColumn[_T], @@ -3859,30 +4144,31 @@ class ColumnClause( """ - table = None - is_literal = False + table: Optional[FromClause] + is_literal: bool __visit_name__ = "column" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("table", InternalTraversal.dp_clauseelement), ("is_literal", InternalTraversal.dp_boolean), ] - onupdate = default = server_default = server_onupdate = None + onupdate: Optional[DefaultGenerator] = None + default: Optional[DefaultGenerator] = None + server_default: Optional[DefaultGenerator] = None + server_onupdate: Optional[DefaultGenerator] = None _is_multiparam_column = False def __init__( self, text: str, - type_: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] - ] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, is_literal: bool = False, - _selectable: Optional["FromClause"] = None, + _selectable: Optional[FromClause] = None, ): self.key = self.name = text self.table = _selectable @@ -3916,7 +4202,7 @@ class ColumnClause( return super(ColumnClause, self)._clone(**kw) @HasMemoized.memoized_attribute - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: t = self.table if t is not None: return [t] @@ -3953,7 +4239,9 @@ class ColumnClause( else: return other.proxy_set.intersection(self.proxy_set) - def _gen_tq_label(self, name, dedupe_on_key=True): + def _gen_tq_label( + self, name: str, dedupe_on_key: bool = True + ) -> Optional[str]: """generate table-qualified label for a table-bound column this is <tablename>_<columnname>. @@ -3962,22 +4250,24 @@ class ColumnClause( as well as the .columns collection on a Join object. """ + label: str t = self.table if self.is_literal: return None - elif t is not None and t.named_with_column: - if getattr(t, "schema", None): + elif t is not None and is_named_from_clause(t): + if has_schema_attr(t) and t.schema: label = t.schema.replace(".", "_") + "_" + t.name + "_" + name else: + assert not TYPE_CHECKING or isinstance(t, NamedFromClause) label = t.name + "_" + name # propagate name quoting rules for labels. - if getattr(name, "quote", None) is not None: - if isinstance(label, quoted_name): + if is_quoted_name(name) and name.quote is not None: + if is_quoted_name(label): label.quote = name.quote else: label = quoted_name(label, name.quote) - elif getattr(t.name, "quote", None) is not None: + elif is_quoted_name(t.name) and t.name.quote is not None: # can't get this situation to occur, so let's # assert false on it for now assert not isinstance(label, quoted_name) @@ -4046,16 +4336,16 @@ class ColumnClause( return c.key, c -class TableValuedColumn(NamedColumn): +class TableValuedColumn(NamedColumn[_T]): __visit_name__ = "table_valued_column" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("scalar_alias", InternalTraversal.dp_clauseelement), ] - def __init__(self, scalar_alias, type_): + def __init__(self, scalar_alias: NamedFromClause, type_: TypeEngine[_T]): self.scalar_alias = scalar_alias self.key = self.name = scalar_alias.name self.type = type_ @@ -4064,24 +4354,28 @@ class TableValuedColumn(NamedColumn): self.scalar_alias = clone(self.scalar_alias, **kw) self.key = self.name = self.scalar_alias.name - @property - def _from_objects(self): + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: return [self.scalar_alias] -class CollationClause(ColumnElement): +class CollationClause(ColumnElement[str]): __visit_name__ = "collation" - _traverse_internals = [("collation", InternalTraversal.dp_string)] + _traverse_internals: _TraverseInternalsType = [ + ("collation", InternalTraversal.dp_string) + ] @classmethod - def _create_collation_expression(cls, expression, collation): + def _create_collation_expression( + cls, expression: _ColumnExpression[str], collation: str + ) -> BinaryExpression[str]: expr = coercions.expect(roles.ExpressionElementRole, expression) return BinaryExpression( expr, CollationClause(collation), operators.collate, - type_=expression.type, + type_=expr.type, ) def __init__(self, collation): @@ -4163,6 +4457,8 @@ class quoted_name(util.MemoizedSlots, str): __slots__ = "quote", "lower", "upper" + quote: Optional[bool] + def __new__(cls, value, quote): if value is None: return None @@ -4196,10 +4492,10 @@ class quoted_name(util.MemoizedSlots, str): return str(self).upper() -def _find_columns(clause): +def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]: """locate Column objects within the given expression.""" - cols = util.column_set() + cols: Set[ColumnClause[Any]] = set() traverse(clause, {}, {"column": cols.add}) return cols @@ -4226,6 +4522,8 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): class AnnotatedColumnElement(Annotated): + _Annotated__element: ColumnElement[Any] + def __init__(self, element, values): Annotated.__init__(self, element, values) for attr in ( @@ -4265,7 +4563,7 @@ class AnnotatedColumnElement(Annotated): return self._Annotated__element.info @util.memoized_property - def _anon_name_label(self): + def _anon_name_label(self) -> str: return self._Annotated__element._anon_name_label @@ -4353,8 +4651,12 @@ class _anonymous_label(_truncated_label): @classmethod def safe_construct( - cls, seed, body, enclosing_label=None, sanitize_key=False - ) -> "_anonymous_label": + cls, + seed: int, + body: str, + enclosing_label: Optional[str] = None, + sanitize_key: bool = False, + ) -> _anonymous_label: if sanitize_key: body = re.sub(r"[%\(\) \$]+", "_", body).strip("_") diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index f08e71bcd..7db1631c8 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -12,6 +12,7 @@ from __future__ import annotations +from enum import IntEnum from operator import add as _uncast_add from operator import and_ as _uncast_and_ from operator import contains as _uncast_contains @@ -36,15 +37,17 @@ import typing from typing import Any from typing import Callable from typing import cast +from typing import Dict from typing import Generic from typing import Optional -from typing import overload +from typing import Set from typing import Type from typing import TypeVar from typing import Union from .. import exc from .. import util +from ..util.typing import Literal from ..util.typing import Protocol if typing.TYPE_CHECKING: @@ -52,8 +55,8 @@ if typing.TYPE_CHECKING: from .elements import ColumnElement from .type_api import TypeEngine -_OP_RETURN = TypeVar("_OP_RETURN", bound=Any, covariant=True) _T = TypeVar("_T", bound=Any) +_FN = TypeVar("_FN", bound=Callable[..., Any]) class OperatorType(Protocol): @@ -64,8 +67,12 @@ class OperatorType(Protocol): __name__: str def __call__( - self, left: "Operators[_OP_RETURN]", *other: Any, **kwargs: Any - ) -> "_OP_RETURN": + self, + left: "Operators", + right: Optional[Any] = None, + *other: Any, + **kwargs: Any, + ) -> "Operators": ... @@ -91,7 +98,7 @@ sub = cast(OperatorType, _uncast_sub) truediv = cast(OperatorType, _uncast_truediv) -class Operators(Generic[_OP_RETURN]): +class Operators: """Base of comparison and logical operators. Implements base methods @@ -108,7 +115,7 @@ class Operators(Generic[_OP_RETURN]): __slots__ = () - def __and__(self, other: Any) -> "Operators": + def __and__(self, other: Any) -> Operators: """Implement the ``&`` operator. When used with SQL expressions, results in an @@ -132,7 +139,7 @@ class Operators(Generic[_OP_RETURN]): """ return self.operate(and_, other) - def __or__(self, other: Any) -> "Operators": + def __or__(self, other: Any) -> Operators: """Implement the ``|`` operator. When used with SQL expressions, results in an @@ -156,7 +163,7 @@ class Operators(Generic[_OP_RETURN]): """ return self.operate(or_, other) - def __invert__(self) -> "Operators": + def __invert__(self) -> Operators: """Implement the ``~`` operator. When used with SQL expressions, results in a @@ -175,14 +182,14 @@ class Operators(Generic[_OP_RETURN]): def op( self, - opstring: Any, + opstring: str, precedence: int = 0, is_comparison: bool = False, return_type: Optional[ Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"] ] = None, - python_impl=None, - ) -> Callable[[Any], Any]: + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], Operators]: """Produce a generic operator function. e.g.:: @@ -200,7 +207,7 @@ class Operators(Generic[_OP_RETURN]): is a bitwise AND of the value in ``somecolumn``. - :param operator: a string which will be output as the infix operator + :param opstring: a string which will be output as the infix operator between this element and the expression passed to the generated function. @@ -263,14 +270,17 @@ class Operators(Generic[_OP_RETURN]): python_impl=python_impl, ) - def against(other: Any) -> _OP_RETURN: - return operator(self, other) + def against(other: Any) -> Operators: + return operator(self, other) # type: ignore return against def bool_op( - self, opstring: Any, precedence: int = 0, python_impl=None - ) -> Callable[[Any], Any]: + self, + opstring: str, + precedence: int = 0, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], Operators]: """Return a custom boolean operator. This method is shorthand for calling @@ -292,7 +302,7 @@ class Operators(Generic[_OP_RETURN]): def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> _OP_RETURN: + ) -> Operators: r"""Operate on an argument. This is the lowest level of operation, raises @@ -322,7 +332,7 @@ class Operators(Generic[_OP_RETURN]): def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> _OP_RETURN: + ) -> Operators: """Reverse operate on an argument. Usage is the same as :meth:`operate`. @@ -379,7 +389,7 @@ class custom_op(OperatorType, Generic[_T]): ] = None, natural_self_precedent: bool = False, eager_grouping: bool = False, - python_impl=None, + python_impl: Optional[Callable[..., Any]] = None, ): self.opstring = opstring self.precedence = precedence @@ -397,25 +407,17 @@ class custom_op(OperatorType, Generic[_T]): def __hash__(self) -> int: return id(self) - @overload def __call__( - self, left: "ColumnElement", right: Any, **kw - ) -> "BinaryExpression[_T]": - ... - - @overload - def __call__( - self, left: "Operators[_OP_RETURN]", right: Any, **kw - ) -> _OP_RETURN: - ... - - def __call__( - self, left: "Operators[_OP_RETURN]", right: Any, **kw - ) -> _OP_RETURN: + self, + left: Operators, + right: Optional[Any] = None, + *other: Any, + **kwargs: Any, + ) -> Operators: if hasattr(left, "__sa_operate__"): - return left.operate(self, right, **kw) + return left.operate(self, right, *other, **kwargs) elif self.python_impl: - return self.python_impl(left, right, **kw) + return self.python_impl(left, right, *other, **kwargs) # type: ignore # noqa E501 else: raise exc.InvalidRequestError( f"Custom operator {self.opstring!r} can't be used with " @@ -424,7 +426,7 @@ class custom_op(OperatorType, Generic[_T]): ) -class ColumnOperators(Operators[_OP_RETURN]): +class ColumnOperators(Operators): """Defines boolean, comparison, and other operators for :class:`_expression.ColumnElement` expressions. @@ -464,22 +466,22 @@ class ColumnOperators(Operators[_OP_RETURN]): __slots__ = () - timetuple = None + timetuple: Literal[None] = None """Hack, allows datetime objects to be compared on the LHS.""" if typing.TYPE_CHECKING: def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> "ColumnOperators": + ) -> ColumnOperators: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> "ColumnOperators": + ) -> ColumnOperators: ... - def __lt__(self, other: Any) -> "ColumnOperators": + def __lt__(self, other: Any) -> ColumnOperators: """Implement the ``<`` operator. In a column context, produces the clause ``a < b``. @@ -487,7 +489,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(lt, other) - def __le__(self, other: Any) -> "ColumnOperators": + def __le__(self, other: Any) -> ColumnOperators: """Implement the ``<=`` operator. In a column context, produces the clause ``a <= b``. @@ -498,7 +500,7 @@ class ColumnOperators(Operators[_OP_RETURN]): # TODO: not sure why we have this __hash__ = Operators.__hash__ # type: ignore - def __eq__(self, other: Any) -> "ColumnOperators": + def __eq__(self, other: Any) -> ColumnOperators: # type: ignore[override] """Implement the ``==`` operator. In a column context, produces the clause ``a = b``. @@ -507,7 +509,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(eq, other) - def __ne__(self, other: Any) -> "ColumnOperators": + def __ne__(self, other: Any) -> ColumnOperators: # type: ignore[override] """Implement the ``!=`` operator. In a column context, produces the clause ``a != b``. @@ -516,7 +518,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(ne, other) - def is_distinct_from(self, other: Any) -> "ColumnOperators": + def is_distinct_from(self, other: Any) -> ColumnOperators: """Implement the ``IS DISTINCT FROM`` operator. Renders "a IS DISTINCT FROM b" on most platforms; @@ -527,7 +529,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(is_distinct_from, other) - def is_not_distinct_from(self, other: Any) -> "ColumnOperators": + def is_not_distinct_from(self, other: Any) -> ColumnOperators: """Implement the ``IS NOT DISTINCT FROM`` operator. Renders "a IS NOT DISTINCT FROM b" on most platforms; @@ -545,7 +547,7 @@ class ColumnOperators(Operators[_OP_RETURN]): # deprecated 1.4; see #5435 isnot_distinct_from = is_not_distinct_from - def __gt__(self, other: Any) -> "ColumnOperators": + def __gt__(self, other: Any) -> ColumnOperators: """Implement the ``>`` operator. In a column context, produces the clause ``a > b``. @@ -553,7 +555,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(gt, other) - def __ge__(self, other: Any) -> "ColumnOperators": + def __ge__(self, other: Any) -> ColumnOperators: """Implement the ``>=`` operator. In a column context, produces the clause ``a >= b``. @@ -561,7 +563,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(ge, other) - def __neg__(self) -> "ColumnOperators": + def __neg__(self) -> ColumnOperators: """Implement the ``-`` operator. In a column context, produces the clause ``-a``. @@ -569,10 +571,10 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(neg) - def __contains__(self, other: Any) -> "ColumnOperators": + def __contains__(self, other: Any) -> ColumnOperators: return self.operate(contains, other) - def __getitem__(self, index: Any) -> "ColumnOperators": + def __getitem__(self, index: Any) -> ColumnOperators: """Implement the [] operator. This can be used by some database-specific types @@ -581,7 +583,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(getitem, index) - def __lshift__(self, other: Any) -> "ColumnOperators": + def __lshift__(self, other: Any) -> ColumnOperators: """implement the << operator. Not used by SQLAlchemy core, this is provided @@ -590,7 +592,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(lshift, other) - def __rshift__(self, other: Any) -> "ColumnOperators": + def __rshift__(self, other: Any) -> ColumnOperators: """implement the >> operator. Not used by SQLAlchemy core, this is provided @@ -599,7 +601,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(rshift, other) - def concat(self, other: Any) -> "ColumnOperators": + def concat(self, other: Any) -> ColumnOperators: """Implement the 'concat' operator. In a column context, produces the clause ``a || b``, @@ -608,7 +610,9 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(concat_op, other) - def like(self, other: Any, escape=None) -> "ColumnOperators": + def like( + self, other: Any, escape: Optional[str] = None + ) -> ColumnOperators: r"""Implement the ``like`` operator. In a column context, produces the expression:: @@ -633,7 +637,9 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(like_op, other, escape=escape) - def ilike(self, other: Any, escape=None) -> "ColumnOperators": + def ilike( + self, other: Any, escape: Optional[str] = None + ) -> ColumnOperators: r"""Implement the ``ilike`` operator, e.g. case insensitive LIKE. In a column context, produces an expression either of the form:: @@ -662,7 +668,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(ilike_op, other, escape=escape) - def in_(self, other: Any) -> "ColumnOperators": + def in_(self, other: Any) -> ColumnOperators: """Implement the ``in`` operator. In a column context, produces the clause ``column IN <other>``. @@ -751,7 +757,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(in_op, other) - def not_in(self, other: Any) -> "ColumnOperators": + def not_in(self, other: Any) -> ColumnOperators: """implement the ``NOT IN`` operator. This is equivalent to using negation with @@ -782,7 +788,9 @@ class ColumnOperators(Operators[_OP_RETURN]): # deprecated 1.4; see #5429 notin_ = not_in - def not_like(self, other: Any, escape=None) -> "ColumnOperators": + def not_like( + self, other: Any, escape: Optional[str] = None + ) -> ColumnOperators: """implement the ``NOT LIKE`` operator. This is equivalent to using negation with @@ -802,7 +810,9 @@ class ColumnOperators(Operators[_OP_RETURN]): # deprecated 1.4; see #5435 notlike = not_like - def not_ilike(self, other: Any, escape=None) -> "ColumnOperators": + def not_ilike( + self, other: Any, escape: Optional[str] = None + ) -> ColumnOperators: """implement the ``NOT ILIKE`` operator. This is equivalent to using negation with @@ -822,7 +832,7 @@ class ColumnOperators(Operators[_OP_RETURN]): # deprecated 1.4; see #5435 notilike = not_ilike - def is_(self, other: Any) -> "ColumnOperators": + def is_(self, other: Any) -> ColumnOperators: """Implement the ``IS`` operator. Normally, ``IS`` is generated automatically when comparing to a @@ -835,7 +845,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(is_, other) - def is_not(self, other: Any) -> "ColumnOperators": + def is_not(self, other: Any) -> ColumnOperators: """Implement the ``IS NOT`` operator. Normally, ``IS NOT`` is generated automatically when comparing to a @@ -856,8 +866,11 @@ class ColumnOperators(Operators[_OP_RETURN]): isnot = is_not def startswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnOperators": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnOperators: r"""Implement the ``startswith`` operator. Produces a LIKE expression that tests against a match for the start @@ -939,8 +952,11 @@ class ColumnOperators(Operators[_OP_RETURN]): ) def endswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnOperators": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnOperators: r"""Implement the 'endswith' operator. Produces a LIKE expression that tests against a match for the end @@ -1021,7 +1037,7 @@ class ColumnOperators(Operators[_OP_RETURN]): endswith_op, other, escape=escape, autoescape=autoescape ) - def contains(self, other: Any, **kw: Any) -> "ColumnOperators": + def contains(self, other: Any, **kw: Any) -> ColumnOperators: r"""Implement the 'contains' operator. Produces a LIKE expression that tests against a match for the middle @@ -1101,7 +1117,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(contains_op, other, **kw) - def match(self, other: Any, **kwargs) -> "ColumnOperators": + def match(self, other: Any, **kwargs: Any) -> ColumnOperators: """Implements a database-specific 'match' operator. :meth:`_sql.ColumnOperators.match` attempts to resolve to @@ -1125,7 +1141,9 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(match_op, other, **kwargs) - def regexp_match(self, pattern, flags=None) -> "ColumnOperators": + def regexp_match( + self, pattern: Any, flags: Optional[str] = None + ) -> ColumnOperators: """Implements a database-specific 'regexp match' operator. E.g.:: @@ -1174,8 +1192,8 @@ class ColumnOperators(Operators[_OP_RETURN]): return self.operate(regexp_match_op, pattern, flags=flags) def regexp_replace( - self, pattern, replacement, flags=None - ) -> "ColumnOperators": + self, pattern: Any, replacement: Any, flags: Optional[str] = None + ) -> ColumnOperators: """Implements a database-specific 'regexp replace' operator. E.g.:: @@ -1220,17 +1238,17 @@ class ColumnOperators(Operators[_OP_RETURN]): flags=flags, ) - def desc(self) -> "ColumnOperators": + def desc(self) -> ColumnOperators: """Produce a :func:`_expression.desc` clause against the parent object.""" return self.operate(desc_op) - def asc(self) -> "ColumnOperators": + def asc(self) -> ColumnOperators: """Produce a :func:`_expression.asc` clause against the parent object.""" return self.operate(asc_op) - def nulls_first(self) -> "ColumnOperators": + def nulls_first(self) -> ColumnOperators: """Produce a :func:`_expression.nulls_first` clause against the parent object. @@ -1243,7 +1261,7 @@ class ColumnOperators(Operators[_OP_RETURN]): # deprecated 1.4; see #5435 nullsfirst = nulls_first - def nulls_last(self) -> "ColumnOperators": + def nulls_last(self) -> ColumnOperators: """Produce a :func:`_expression.nulls_last` clause against the parent object. @@ -1256,7 +1274,7 @@ class ColumnOperators(Operators[_OP_RETURN]): # deprecated 1.4; see #5429 nullslast = nulls_last - def collate(self, collation) -> "ColumnOperators": + def collate(self, collation: str) -> ColumnOperators: """Produce a :func:`_expression.collate` clause against the parent object, given the collation string. @@ -1267,7 +1285,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(collate, collation) - def __radd__(self, other: Any) -> "ColumnOperators": + def __radd__(self, other: Any) -> ColumnOperators: """Implement the ``+`` operator in reverse. See :meth:`.ColumnOperators.__add__`. @@ -1275,7 +1293,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.reverse_operate(add, other) - def __rsub__(self, other: Any) -> "ColumnOperators": + def __rsub__(self, other: Any) -> ColumnOperators: """Implement the ``-`` operator in reverse. See :meth:`.ColumnOperators.__sub__`. @@ -1283,7 +1301,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.reverse_operate(sub, other) - def __rmul__(self, other: Any) -> "ColumnOperators": + def __rmul__(self, other: Any) -> ColumnOperators: """Implement the ``*`` operator in reverse. See :meth:`.ColumnOperators.__mul__`. @@ -1291,7 +1309,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.reverse_operate(mul, other) - def __rmod__(self, other: Any) -> "ColumnOperators": + def __rmod__(self, other: Any) -> ColumnOperators: """Implement the ``%`` operator in reverse. See :meth:`.ColumnOperators.__mod__`. @@ -1299,21 +1317,23 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.reverse_operate(mod, other) - def between(self, cleft, cright, symmetric=False) -> "ColumnOperators": + def between( + self, cleft: Any, cright: Any, symmetric: bool = False + ) -> ColumnOperators: """Produce a :func:`_expression.between` clause against the parent object, given the lower and upper range. """ return self.operate(between_op, cleft, cright, symmetric=symmetric) - def distinct(self) -> "ColumnOperators": + def distinct(self) -> ColumnOperators: """Produce a :func:`_expression.distinct` clause against the parent object. """ return self.operate(distinct_op) - def any_(self) -> "ColumnOperators": + def any_(self) -> ColumnOperators: """Produce an :func:`_expression.any_` clause against the parent object. @@ -1330,7 +1350,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(any_op) - def all_(self) -> "ColumnOperators": + def all_(self) -> ColumnOperators: """Produce an :func:`_expression.all_` clause against the parent object. @@ -1348,7 +1368,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(all_op) - def __add__(self, other: Any) -> "ColumnOperators": + def __add__(self, other: Any) -> ColumnOperators: """Implement the ``+`` operator. In a column context, produces the clause ``a + b`` @@ -1360,7 +1380,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(add, other) - def __sub__(self, other: Any) -> "ColumnOperators": + def __sub__(self, other: Any) -> ColumnOperators: """Implement the ``-`` operator. In a column context, produces the clause ``a - b``. @@ -1368,7 +1388,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(sub, other) - def __mul__(self, other: Any) -> "ColumnOperators": + def __mul__(self, other: Any) -> ColumnOperators: """Implement the ``*`` operator. In a column context, produces the clause ``a * b``. @@ -1376,7 +1396,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(mul, other) - def __mod__(self, other: Any) -> "ColumnOperators": + def __mod__(self, other: Any) -> ColumnOperators: """Implement the ``%`` operator. In a column context, produces the clause ``a % b``. @@ -1384,7 +1404,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(mod, other) - def __truediv__(self, other: Any) -> "ColumnOperators": + def __truediv__(self, other: Any) -> ColumnOperators: """Implement the ``/`` operator. In a column context, produces the clause ``a / b``, and @@ -1397,7 +1417,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(truediv, other) - def __rtruediv__(self, other: Any) -> "ColumnOperators": + def __rtruediv__(self, other: Any) -> ColumnOperators: """Implement the ``/`` operator in reverse. See :meth:`.ColumnOperators.__truediv__`. @@ -1405,7 +1425,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.reverse_operate(truediv, other) - def __floordiv__(self, other: Any) -> "ColumnOperators": + def __floordiv__(self, other: Any) -> ColumnOperators: """Implement the ``//`` operator. In a column context, produces the clause ``a / b``, @@ -1417,7 +1437,7 @@ class ColumnOperators(Operators[_OP_RETURN]): """ return self.operate(floordiv, other) - def __rfloordiv__(self, other: Any) -> "ColumnOperators": + def __rfloordiv__(self, other: Any) -> ColumnOperators: """Implement the ``//`` operator in reverse. See :meth:`.ColumnOperators.__floordiv__`. @@ -1426,43 +1446,47 @@ class ColumnOperators(Operators[_OP_RETURN]): return self.reverse_operate(floordiv, other) -_commutative = {eq, ne, add, mul} -_comparison = {eq, ne, lt, gt, ge, le} +_commutative: Set[Any] = {eq, ne, add, mul} +_comparison: Set[Any] = {eq, ne, lt, gt, ge, le} -def _operator_fn(fn): +def _operator_fn(fn: Callable[..., Any]) -> OperatorType: return cast(OperatorType, fn) -def commutative_op(fn): +def commutative_op(fn: _FN) -> _FN: _commutative.add(fn) return fn -def comparison_op(fn): +def comparison_op(fn: _FN) -> _FN: _comparison.add(fn) return fn -def from_(): +@_operator_fn +def from_() -> Any: raise NotImplementedError() +@_operator_fn @comparison_op -def function_as_comparison_op(): +def function_as_comparison_op() -> Any: raise NotImplementedError() -def as_(): +@_operator_fn +def as_() -> Any: raise NotImplementedError() -def exists(): +@_operator_fn +def exists() -> Any: raise NotImplementedError() @_operator_fn -def is_true(a): +def is_true(a: Any) -> Any: raise NotImplementedError() @@ -1471,7 +1495,7 @@ istrue = is_true @_operator_fn -def is_false(a): +def is_false(a: Any) -> Any: raise NotImplementedError() @@ -1481,13 +1505,13 @@ isfalse = is_false @comparison_op @_operator_fn -def is_distinct_from(a, b): +def is_distinct_from(a: Any, b: Any) -> Any: return a.is_distinct_from(b) @comparison_op @_operator_fn -def is_not_distinct_from(a, b): +def is_not_distinct_from(a: Any, b: Any) -> Any: return a.is_not_distinct_from(b) @@ -1497,13 +1521,13 @@ isnot_distinct_from = is_not_distinct_from @comparison_op @_operator_fn -def is_(a, b): +def is_(a: Any, b: Any) -> Any: return a.is_(b) @comparison_op @_operator_fn -def is_not(a, b): +def is_not(a: Any, b: Any) -> Any: return a.is_not(b) @@ -1512,24 +1536,24 @@ isnot = is_not @_operator_fn -def collate(a, b): +def collate(a: Any, b: Any) -> Any: return a.collate(b) @_operator_fn -def op(a, opstring, b): +def op(a: Any, opstring: str, b: Any) -> Any: return a.op(opstring)(b) @comparison_op @_operator_fn -def like_op(a, b, escape=None): +def like_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: return a.like(b, escape=escape) @comparison_op @_operator_fn -def not_like_op(a, b, escape=None): +def not_like_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: return a.notlike(b, escape=escape) @@ -1539,13 +1563,13 @@ notlike_op = not_like_op @comparison_op @_operator_fn -def ilike_op(a, b, escape=None): +def ilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: return a.ilike(b, escape=escape) @comparison_op @_operator_fn -def not_ilike_op(a, b, escape=None): +def not_ilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: return a.not_ilike(b, escape=escape) @@ -1555,13 +1579,13 @@ notilike_op = not_ilike_op @comparison_op @_operator_fn -def between_op(a, b, c, symmetric=False): +def between_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any: return a.between(b, c, symmetric=symmetric) @comparison_op @_operator_fn -def not_between_op(a, b, c, symmetric=False): +def not_between_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any: return ~a.between(b, c, symmetric=symmetric) @@ -1571,13 +1595,13 @@ notbetween_op = not_between_op @comparison_op @_operator_fn -def in_op(a, b): +def in_op(a: Any, b: Any) -> Any: return a.in_(b) @comparison_op @_operator_fn -def not_in_op(a, b): +def not_in_op(a: Any, b: Any) -> Any: return a.not_in(b) @@ -1586,21 +1610,23 @@ notin_op = not_in_op @_operator_fn -def distinct_op(a): +def distinct_op(a: Any) -> Any: return a.distinct() @_operator_fn -def any_op(a): +def any_op(a: Any) -> Any: return a.any_() @_operator_fn -def all_op(a): +def all_op(a: Any) -> Any: return a.all_() -def _escaped_like_impl(fn, other: Any, escape, autoescape): +def _escaped_like_impl( + fn: Callable[..., Any], other: Any, escape: Optional[str], autoescape: bool +) -> Any: if autoescape: if autoescape is not True: util.warn( @@ -1622,13 +1648,17 @@ def _escaped_like_impl(fn, other: Any, escape, autoescape): @comparison_op @_operator_fn -def startswith_op(a, b, escape=None, autoescape=False): +def startswith_op( + a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False +) -> Any: return _escaped_like_impl(a.startswith, b, escape, autoescape) @comparison_op @_operator_fn -def not_startswith_op(a, b, escape=None, autoescape=False): +def not_startswith_op( + a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False +) -> Any: return ~_escaped_like_impl(a.startswith, b, escape, autoescape) @@ -1638,13 +1668,17 @@ notstartswith_op = not_startswith_op @comparison_op @_operator_fn -def endswith_op(a, b, escape=None, autoescape=False): +def endswith_op( + a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False +) -> Any: return _escaped_like_impl(a.endswith, b, escape, autoescape) @comparison_op @_operator_fn -def not_endswith_op(a, b, escape=None, autoescape=False): +def not_endswith_op( + a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False +) -> Any: return ~_escaped_like_impl(a.endswith, b, escape, autoescape) @@ -1654,13 +1688,17 @@ notendswith_op = not_endswith_op @comparison_op @_operator_fn -def contains_op(a, b, escape=None, autoescape=False): +def contains_op( + a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False +) -> Any: return _escaped_like_impl(a.contains, b, escape, autoescape) @comparison_op @_operator_fn -def not_contains_op(a, b, escape=None, autoescape=False): +def not_contains_op( + a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False +) -> Any: return ~_escaped_like_impl(a.contains, b, escape, autoescape) @@ -1670,30 +1708,32 @@ notcontains_op = not_contains_op @comparison_op @_operator_fn -def match_op(a, b, **kw): +def match_op(a: Any, b: Any, **kw: Any) -> Any: return a.match(b, **kw) @comparison_op @_operator_fn -def regexp_match_op(a, b, flags=None): +def regexp_match_op(a: Any, b: Any, flags: Optional[str] = None) -> Any: return a.regexp_match(b, flags=flags) @comparison_op @_operator_fn -def not_regexp_match_op(a, b, flags=None): +def not_regexp_match_op(a: Any, b: Any, flags: Optional[str] = None) -> Any: return ~a.regexp_match(b, flags=flags) @_operator_fn -def regexp_replace_op(a, b, replacement, flags=None): +def regexp_replace_op( + a: Any, b: Any, replacement: Any, flags: Optional[str] = None +) -> Any: return a.regexp_replace(b, replacement=replacement, flags=flags) @comparison_op @_operator_fn -def not_match_op(a, b, **kw): +def not_match_op(a: Any, b: Any, **kw: Any) -> Any: return ~a.match(b, **kw) @@ -1702,32 +1742,32 @@ notmatch_op = not_match_op @_operator_fn -def comma_op(a, b): +def comma_op(a: Any, b: Any) -> Any: raise NotImplementedError() @_operator_fn -def filter_op(a, b): +def filter_op(a: Any, b: Any) -> Any: raise NotImplementedError() @_operator_fn -def concat_op(a, b): +def concat_op(a: Any, b: Any) -> Any: return a.concat(b) @_operator_fn -def desc_op(a): +def desc_op(a: Any) -> Any: return a.desc() @_operator_fn -def asc_op(a): +def asc_op(a: Any) -> Any: return a.asc() @_operator_fn -def nulls_first_op(a): +def nulls_first_op(a: Any) -> Any: return a.nulls_first() @@ -1736,7 +1776,7 @@ nullsfirst_op = nulls_first_op @_operator_fn -def nulls_last_op(a): +def nulls_last_op(a: Any) -> Any: return a.nulls_last() @@ -1745,28 +1785,28 @@ nullslast_op = nulls_last_op @_operator_fn -def json_getitem_op(a, b): +def json_getitem_op(a: Any, b: Any) -> Any: raise NotImplementedError() @_operator_fn -def json_path_getitem_op(a, b): +def json_path_getitem_op(a: Any, b: Any) -> Any: raise NotImplementedError() -def is_comparison(op): +def is_comparison(op: OperatorType) -> bool: return op in _comparison or isinstance(op, custom_op) and op.is_comparison -def is_commutative(op): +def is_commutative(op: OperatorType) -> bool: return op in _commutative -def is_ordering_modifier(op): +def is_ordering_modifier(op: OperatorType) -> bool: return op in (asc_op, desc_op, nulls_first_op, nulls_last_op) -def is_natural_self_precedent(op): +def is_natural_self_precedent(op: OperatorType) -> bool: return ( op in _natural_self_precedent or isinstance(op, custom_op) @@ -1777,14 +1817,14 @@ def is_natural_self_precedent(op): _booleans = (inv, is_true, is_false, and_, or_) -def is_boolean(op): +def is_boolean(op: OperatorType) -> bool: return is_comparison(op) or op in _booleans _mirror = {gt: lt, ge: le, lt: gt, le: ge} -def mirror(op): +def mirror(op: OperatorType) -> OperatorType: """rotate a comparison operator 180 degrees. Note this is not the same as negation. @@ -1796,7 +1836,7 @@ def mirror(op): _associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne]) -def is_associative(op): +def is_associative(op: OperatorType) -> bool: return op in _associative @@ -1809,11 +1849,17 @@ parenthesize (a op b). """ -_asbool = util.symbol("_asbool", canonical=-10) -_smallest = util.symbol("_smallest", canonical=-100) -_largest = util.symbol("_largest", canonical=100) +@_operator_fn +def _asbool(a: Any) -> Any: + raise NotImplementedError() -_PRECEDENCE = { + +class _OpLimit(IntEnum): + _smallest = -100 + _largest = 100 + + +_PRECEDENCE: Dict[OperatorType, int] = { from_: 15, function_as_comparison_op: 15, any_op: 15, @@ -1866,15 +1912,18 @@ _PRECEDENCE = { as_: -1, exists: 0, _asbool: -10, - _smallest: _smallest, - _largest: _largest, } -def is_precedent(operator, against): +def is_precedent(operator: OperatorType, against: OperatorType) -> bool: if operator is against and is_natural_self_precedent(operator): return False else: - return _PRECEDENCE.get( - operator, getattr(operator, "precedence", _smallest) - ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest)) + return bool( + _PRECEDENCE.get( + operator, getattr(operator, "precedence", _OpLimit._smallest) + ) + <= _PRECEDENCE.get( + against, getattr(against, "precedence", _OpLimit._largest) + ) + ) diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 1a7a5f4d4..4c4f49aa8 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -8,23 +8,28 @@ from __future__ import annotations import typing from typing import Any +from typing import Generic from typing import Iterable -from typing import Mapping from typing import Optional -from typing import Sequence +from typing import TypeVar from .. import util from ..util import TypingOnly from ..util.typing import Literal if typing.TYPE_CHECKING: + from ._typing import _PropagateAttrsType from .base import ColumnCollection from .elements import ClauseElement + from .elements import ColumnElement from .elements import Label from .selectable import FromClause from .selectable import Subquery +_T = TypeVar("_T", bound=Any) + + class SQLRole: """Define a "role" within a SQL statement structure. @@ -104,7 +109,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): _role_name = "Column expression or FROM clause" @property - def _select_iterable(self) -> Sequence[ColumnsClauseRole]: + def _select_iterable(self) -> Iterable[ColumnsClauseRole]: raise NotImplementedError() @@ -154,24 +159,24 @@ class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" -class ExpressionElementRole(SQLRole): +class ExpressionElementRole(Generic[_T], SQLRole): __slots__ = () _role_name = "SQL expression element" - def label(self, name: Optional[str]) -> Label[Any]: + def label(self, name: Optional[str]) -> Label[_T]: raise NotImplementedError() -class ConstExprRole(ExpressionElementRole): +class ConstExprRole(ExpressionElementRole[_T]): __slots__ = () _role_name = "Constant True/False/None expression" -class LabeledColumnExprRole(ExpressionElementRole): +class LabeledColumnExprRole(ExpressionElementRole[_T]): __slots__ = () -class BinaryElementRole(ExpressionElementRole): +class BinaryElementRole(ExpressionElementRole[_T]): __slots__ = () _role_name = "SQL expression element or literal value" @@ -235,7 +240,7 @@ class StatementRole(SQLRole): __slots__ = () _role_name = "Executable SQL or text() construct" - _propagate_attrs: Mapping[str, Any] = util.immutabledict() + _propagate_attrs: _PropagateAttrsType = util.immutabledict() class SelectStatementRole(StatementRole, ReturnsRowsRole): @@ -317,7 +322,18 @@ class HasClauseElement(TypingOnly): if typing.TYPE_CHECKING: - def __clause_element__(self) -> "ClauseElement": + def __clause_element__(self) -> ClauseElement: + ... + + +class HasColumnElementClauseElement(TypingOnly): + """indicates a class that has a __clause_element__() method""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + def __clause_element__(self) -> ColumnElement[Any]: ... @@ -328,5 +344,5 @@ class HasFromClauseElement(HasClauseElement, TypingOnly): if typing.TYPE_CHECKING: - def __clause_element__(self) -> "FromClause": + def __clause_element__(self) -> FromClause: ... diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 33e300bf6..78d524127 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2671,6 +2671,7 @@ class DefaultGenerator(Executable, SchemaItem): is_sequence = False is_server_default = False + is_scalar = False column = None def __init__(self, for_update=False): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 09befb078..0692483e9 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -20,9 +20,11 @@ from operator import attrgetter import typing from typing import Any as TODO_Any from typing import Any +from typing import Iterable from typing import NamedTuple from typing import Optional from typing import Tuple +from typing import TYPE_CHECKING from typing import TypeVar from . import cache_key @@ -454,7 +456,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): __visit_name__ = "fromclause" named_with_column = False - _hide_froms = [] + + @property + def _hide_froms(self) -> Iterable[FromClause]: + return () + + _is_clone_of: Optional[FromClause] schema = None """Define the 'schema' attribute for this :class:`_expression.FromClause`. @@ -667,7 +674,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self._cloned_set.intersection(other._cloned_set) @property - def description(self): + def description(self) -> str: """A brief description of this :class:`_expression.FromClause`. Used primarily for error message formatting. @@ -703,7 +710,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.columns @util.memoized_property - def columns(self): + def columns(self) -> ColumnCollection: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -787,19 +794,24 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): for key in ["_columns", "columns", "primary_key", "foreign_keys"]: self.__dict__.pop(key, None) - c = property( - attrgetter("columns"), - doc=""" - A named-based collection of :class:`_expression.ColumnElement` - objects maintained by this :class:`_expression.FromClause`. + # this is awkward. maybe there's a better way + if TYPE_CHECKING: + c: ColumnCollection + else: + c = property( + attrgetter("columns"), + doc=""" + A named-based collection of :class:`_expression.ColumnElement` + objects maintained by this :class:`_expression.FromClause`. - The :attr:`_sql.FromClause.c` attribute is an alias for the - :attr:`_sql.FromClause.columns` attribute. + The :attr:`_sql.FromClause.c` attribute is an alias for the + :attr:`_sql.FromClause.columns` attribute. - :return: a :class:`.ColumnCollection` + :return: a :class:`.ColumnCollection` + + """, + ) - """, - ) _select_iterable = property(attrgetter("columns")) def _init_collections(self): @@ -1015,7 +1027,7 @@ class Join(roles.DMLTableRole, FromClause): self.full = full @property - def description(self): + def description(self) -> str: return "Join object on %s(%d) and %s(%d)" % ( self.left.description, id(self.left), @@ -1289,7 +1301,7 @@ class Join(roles.DMLTableRole, FromClause): ) @property - def _hide_froms(self): + def _hide_froms(self) -> Iterable[FromClause]: return itertools.chain( *[_from_objects(x.left, x.right) for x in self._cloned_set] ) @@ -1370,7 +1382,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): self.element._refresh_for_new_column(column) @property - def description(self): + def description(self) -> str: name = self.name if isinstance(name, _anonymous_label): name = "anon_1" @@ -2301,6 +2313,8 @@ class FromGrouping(GroupedElement, FromClause): _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + element: FromClause + def __init__(self, element): self.element = coercions.expect(roles.FromClauseRole, element) @@ -2329,7 +2343,7 @@ class FromGrouping(GroupedElement, FromClause): return FromGrouping(self.element._anonymous_fromclause(**kw)) @property - def _hide_froms(self): + def _hide_froms(self) -> Iterable[FromClause]: return self.element._hide_froms @property @@ -2425,7 +2439,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): pass @util.memoized_property - def description(self): + def description(self) -> str: return self.name def append_column(self, c, **kw): diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b2b1d9bc2..e64ec0843 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -18,7 +18,6 @@ import json import pickle from typing import Any from typing import Sequence -from typing import Text as typing_Text from typing import Tuple from typing import TypeVar from typing import Union @@ -132,7 +131,7 @@ class Indexable: comparator_factory = Comparator -class String(Concatenable, TypeEngine[typing_Text]): +class String(Concatenable, TypeEngine[str]): """The base for all string and character types. @@ -2793,11 +2792,13 @@ class ARRAY( self.item_type._set_parent_with_dispatch(parent) -class TupleType(TypeEngine[Tuple[Any]]): +class TupleType(TypeEngine[Tuple[Any, ...]]): """represent the composite type of a Tuple.""" _is_tuple_type = True + types: List[TypeEngine[Any]] + def __init__(self, *types): self._fully_typed = NULLTYPE not in types self.types = [ @@ -2805,7 +2806,7 @@ class TupleType(TypeEngine[Tuple[Any]]): for item_type in types ] - def _resolve_values_to_types(self, value): + def _resolve_values_to_types(self, value: Any) -> TupleType: if self._fully_typed: return self else: diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index cf9487f93..1f3d50876 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -17,6 +17,7 @@ from typing import Any from typing import Callable from typing import Deque from typing import Dict +from typing import Iterable from typing import Set from typing import Tuple from typing import Type @@ -226,7 +227,9 @@ class HasCopyInternals(HasTraverseInternals): def _clone(self, **kw): raise NotImplementedError() - def _copy_internals(self, omit_attrs=(), **kw): + def _copy_internals( + self, omit_attrs: Iterable[str] = (), **kw: Any + ) -> None: """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index f76b4e462..55997556a 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -29,26 +29,19 @@ from .. import exc from .. import util # these are back-assigned by sqltypes. -if not typing.TYPE_CHECKING: - BOOLEANTYPE = None - INTEGERTYPE = None - NULLTYPE = None - STRINGTYPE = None - MATCHTYPE = None - INDEXABLE = None - TABLEVALUE = None - _resolve_value_to_type = None - if typing.TYPE_CHECKING: from .elements import ColumnElement from .operators import OperatorType - from .sqltypes import _resolve_value_to_type - from .sqltypes import Boolean as BOOLEANTYPE # noqa - from .sqltypes import Indexable as INDEXABLE # noqa - from .sqltypes import MatchType as MATCHTYPE # noqa - from .sqltypes import NULLTYPE + from .sqltypes import _resolve_value_to_type as _resolve_value_to_type + from .sqltypes import BOOLEANTYPE as BOOLEANTYPE + from .sqltypes import Indexable as INDEXABLE + from .sqltypes import INTEGERTYPE as INTEGERTYPE + from .sqltypes import MATCHTYPE as MATCHTYPE + from .sqltypes import NULLTYPE as NULLTYPE + _T = TypeVar("_T", bound=Any) +_TE = TypeVar("_TE", bound="TypeEngine[Any]") _CT = TypeVar("_CT", bound=Any) # replace with pep-673 when applicable @@ -95,7 +88,7 @@ class TypeEngine(Visitable, Generic[_T]): """ class Comparator( - ColumnOperators["ColumnElement"], + ColumnOperators, Generic[_CT], ): """Base class for custom comparison operations defined at the @@ -539,7 +532,7 @@ class TypeEngine(Visitable, Generic[_T]): return util.method_is_overridden(self, TypeEngine.bind_expression) @staticmethod - def _to_instance(cls_or_self): + def _to_instance(cls_or_self: Union[Type[_TE], _TE]) -> _TE: return to_instance(cls_or_self) def compare_values(self, x, y): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 0c41e440e..903aae648 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -866,7 +866,7 @@ def traverse( def cloned_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], - visitors: Mapping[str, _TraverseTransformCallableType], + visitors: Mapping[str, _TraverseCallableType[Any]], ) -> ExternallyTraversible: """Clone the given expression structure, allowing modifications by visitors. |
