diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/_elements_constructors.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 135 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 98 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 8 |
13 files changed, 172 insertions, 176 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index ea21e01c6..605f75ec4 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -389,7 +389,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: def bindparam( - key: str, + key: Optional[str], value: Any = _NoArg.NO_ARG, type_: Optional[TypeEngine[_T]] = None, unique: bool = False, @@ -521,6 +521,11 @@ def bindparam( key, or if its length is too long and truncation is required. + If omitted, an "anonymous" name is generated for the bound parameter; + when given a value to bind, the end result is equivalent to calling upon + the :func:`.literal` function with a value to bind, particularly + if the :paramref:`.bindparam.unique` parameter is also provided. + :param value: Initial value for this bind param. Will be used at statement execution time as the value for this parameter passed to the diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b0a717a1a..53d29b628 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -2,13 +2,14 @@ from __future__ import annotations import operator from typing import Any +from typing import Callable from typing import Dict +from typing import Set from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.sql.base import Executable from . import roles from .. import util from ..inspection import Inspectable @@ -16,6 +17,7 @@ from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from .base import Executable from .compiler import Compiled from .compiler import DDLCompiler from .compiler import SQLCompiler @@ -27,17 +29,20 @@ if TYPE_CHECKING: from .elements import quoted_name from .elements import SQLCoreOperations from .elements import TextClause + from .lambdas import LambdaElement from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column from .schema import DefaultGenerator from .schema import Sequence + from .schema import Table from .selectable import Alias from .selectable import FromClause from .selectable import Join from .selectable import NamedFromClause from .selectable import ReturnsRows from .selectable import Select + from .selectable import Selectable from .selectable import SelectBase from .selectable import Subquery from .selectable import TableClause @@ -46,7 +51,6 @@ if TYPE_CHECKING: from .type_api import TypeEngine from ..util.typing import TypeGuard - _T = TypeVar("_T", bound=Any) @@ -89,7 +93,11 @@ sets; select(...), insert().returning(...), etc. """ _ColumnExpressionArgument = Union[ - "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] + "ColumnElement[_T]", + _HasClauseElement, + roles.ExpressionElementRole[_T], + Callable[[], "ColumnElement[_T]"], + "LambdaElement", ] """narrower "column expression" argument. @@ -103,6 +111,7 @@ overall which brings in the TextClause object also. """ + _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" @@ -169,6 +178,8 @@ _PropagateAttrsType = util.immutabledict[str, Any] _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] +_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]] + if TYPE_CHECKING: def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: @@ -195,6 +206,9 @@ if TYPE_CHECKING: def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: ... + def is_selectable(t: Any) -> TypeGuard[Selectable]: + ... + def is_select_base( t: Union[Executable, ReturnsRows] ) -> TypeGuard[SelectBase]: @@ -224,6 +238,7 @@ else: is_from_clause = operator.attrgetter("_is_from_clause") is_tuple_type = operator.attrgetter("_is_tuple_type") is_table_value_type = operator.attrgetter("_is_table_value") + is_selectable = operator.attrgetter("is_selectable") is_select_base = operator.attrgetter("_is_select_base") is_select_statement = operator.attrgetter("_is_select_statement") is_table = operator.attrgetter("_is_table") diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f7692dbc2..f81878d55 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -218,7 +218,7 @@ def _generative(fn: _Fn) -> _Fn: """ - @util.decorator + @util.decorator # type: ignore def _generative( fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any ) -> _SelfGenerativeType: @@ -244,7 +244,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: for name in names ] - @util.decorator + @util.decorator # type: ignore def check(fn, *args, **kw): # make pylance happy by not including "self" in the argument # list @@ -260,7 +260,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: raise exc.InvalidRequestError(msg) return fn(self, *args, **kw) - return check + return check # type: ignore def _clone(element, **kw): @@ -1750,15 +1750,14 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._collection.append((k, col)) self._colset.update(c for (k, c) in self._collection) - # https://github.com/python/mypy/issues/12610 self._index.update( - (idx, c) for idx, (k, c) in enumerate(self._collection) # type: ignore # noqa: E501 + (idx, c) for idx, (k, c) in enumerate(self._collection) ) for col in replace_col: self.replace(col) def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: - self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501 + self._populate_separate_keys((col.key, col) for col in iter_) def remove(self, column: _NAMEDCOL) -> None: if column not in self._colset: @@ -1772,9 +1771,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): (k, c) for (k, c) in self._collection if c is not column ] - # https://github.com/python/mypy/issues/12610 self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501 + {idx: col for idx, (k, col) in enumerate(self._collection)} ) # delete higher index del self._index[len(self._collection)] @@ -1827,9 +1825,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._index.clear() - # https://github.com/python/mypy/issues/12610 self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501 + {idx: col for idx, (k, col) in enumerate(self._collection)} ) self._index.update(self._collection) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 4bf45da9c..0659709ab 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -214,6 +214,7 @@ def expect( Type[roles.ExpressionElementRole[Any]], Type[roles.LimitOffsetRole], Type[roles.WhereHavingRole], + Type[roles.OnClauseRole], ], element: Any, **kw: Any, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 938be0f81..c524a2602 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1078,7 +1078,7 @@ class SQLCompiler(Compiled): return list(self.insert_prefetch) + list(self.update_prefetch) @util.memoized_property - def _global_attributes(self): + def _global_attributes(self) -> Dict[Any, Any]: return {} @util.memoized_instancemethod diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 6ac7c2448..052af6ac9 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -14,6 +14,7 @@ from __future__ import annotations import typing from typing import Any from typing import Callable +from typing import Iterable from typing import List from typing import Optional from typing import Sequence as typing_Sequence @@ -1143,7 +1144,7 @@ class SchemaDropper(InvokeDDLBase): def sort_tables( - tables: typing_Sequence["Table"], + tables: Iterable["Table"], skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, extra_dependencies: Optional[ typing_Sequence[Tuple["Table", "Table"]] diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ea0fa7996..34d5127ab 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -293,11 +293,18 @@ class ClauseElement( __visit_name__ = "clause" - _propagate_attrs: _PropagateAttrsType = util.immutabledict() - """like annotations, however these propagate outwards liberally - as SQL constructs are built, and are set up at construction time. + if TYPE_CHECKING: - """ + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. + + """ + ... + + else: + _propagate_attrs = util.EMPTY_DICT @util.ro_memoized_property def description(self) -> Optional[str]: @@ -343,7 +350,9 @@ class ClauseElement( def _from_objects(self) -> List[FromClause]: return [] - def _set_propagate_attrs(self, values): + def _set_propagate_attrs( + self: SelfClauseElement, values: Mapping[str, Any] + ) -> SelfClauseElement: # usually, self._propagate_attrs is empty here. one case where it's # not is a subquery against ORM select, that is then pulled as a # property of an aliased class. should all be good @@ -526,13 +535,10 @@ class ClauseElement( if unique: bind._convert_to_unique() - return cast( - SelfClauseElement, - cloned_traverse( - self, - {"maintain_key": True, "detect_subquery_cols": True}, - {"bindparam": visit_bindparam}, - ), + return cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, ) def compare(self, other, **kw): @@ -730,7 +736,9 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): # redefined with the specific types returned by ColumnElement hierarchies if typing.TYPE_CHECKING: - _propagate_attrs: _PropagateAttrsType + @util.non_memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + ... def operate( self, op: OperatorType, *other: Any, **kwargs: Any @@ -2064,10 +2072,11 @@ class TextClause( roles.OrderByRole, roles.FromClauseRole, roles.SelectStatementRole, - roles.BinaryElementRole[Any], roles.InElementRole, Executable, DQLDMLClauseElement, + roles.BinaryElementRole[Any], + inspection.Inspectable["TextClause"], ): """Represent a literal SQL text fragment. diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index da15c305f..4b220188f 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -444,7 +444,7 @@ class DeferredLambdaElement(LambdaElement): def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) - def _resolve_with_args(self, *lambda_args): + def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement: assert isinstance(self._rec, AnalyzedFunction) tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) @@ -478,7 +478,7 @@ class DeferredLambdaElement(LambdaElement): for deferred_copy_internals in self._transforms: expr = deferred_copy_internals(expr) - return expr + return expr # type: ignore def _copy_internals( self, clone=_clone, deferred_copy_internals=None, **kw diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 577d868fd..231c70a5b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -22,9 +22,7 @@ if TYPE_CHECKING: from .base import _EntityNamespace from .base import ColumnCollection from .base import ReadOnlyColumnCollection - from .elements import ClauseElement from .elements import ColumnClause - from .elements import ColumnElement from .elements import Label from .elements import NamedColumn from .selectable import _SelectIterable @@ -271,7 +269,14 @@ class StatementRole(SQLRole): __slots__ = () _role_name = "Executable SQL or text() construct" - _propagate_attrs: _PropagateAttrsType = util.immutabledict() + if TYPE_CHECKING: + + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + ... + + else: + _propagate_attrs = util.EMPTY_DICT class SelectStatementRole(StatementRole, ReturnsRowsRole): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 92b9cc62c..52ba60a62 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -144,9 +144,9 @@ class SchemaConst(Enum): NULL_UNSPECIFIED = 3 """Symbol indicating the "nullable" keyword was not passed to a Column. - Normally we would expect None to be acceptable for this but some backends - such as that of SQL Server place special signficance on a "nullability" - value of None. + This is used to distinguish between the use case of passing + ``nullable=None`` to a :class:`.Column`, which has special meaning + on some backends such as SQL Server. """ @@ -308,7 +308,9 @@ class HasSchemaAttr(SchemaItem): schema: Optional[str] -class Table(DialectKWArgs, HasSchemaAttr, TableClause): +class Table( + DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"] +): r"""Represent a table in a database. e.g.:: @@ -1318,117 +1320,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): inherit_cache = True key: str - @overload - def __init__( - self, - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload - def __init__( - self, - __name: str, - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload def __init__( self, - __type: _TypeEngineArgument[_T], - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload - def __init__( - self, - __name: str, - __type: _TypeEngineArgument[_T], + __name_pos: Optional[ + Union[str, _TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[_T], SchemaEventTarget] + ] = None, *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - def __init__( - self, - *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget], name: Optional[str] = None, type_: Optional[_TypeEngineArgument[_T]] = None, autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", @@ -1440,7 +1340,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): info: Optional[_InfoType] = None, nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, + ] = SchemaConst.NULL_UNSPECIFIED, onupdate: Optional[Any] = None, primary_key: bool = False, server_default: Optional[_ServerDefaultType] = None, @@ -1953,7 +1853,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ # noqa: E501, RST201, RST202 - l_args = list(args) + l_args = [__name_pos, __type_pos] + list(args) del args if l_args: @@ -1963,6 +1863,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "May not pass name positionally and as a keyword." ) name = l_args.pop(0) # type: ignore + elif l_args[0] is None: + l_args.pop(0) if l_args: coltype = l_args[0] @@ -1972,6 +1874,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "May not pass type_ positionally and as a keyword." ) type_ = l_args.pop(0) # type: ignore + elif l_args[0] is None: + l_args.pop(0) if name is not None: name = quoted_name(name, quote) @@ -1989,7 +1893,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): self.primary_key = primary_key self._user_defined_nullable = udn = nullable - if udn is not NULL_UNSPECIFIED: self.nullable = udn else: @@ -5128,7 +5031,7 @@ class MetaData(HasSchemaAttr): def clear(self) -> None: """Clear all Table objects from this MetaData.""" - dict.clear(self.tables) + dict.clear(self.tables) # type: ignore self._schemas.clear() self._fk_memos.clear() diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index aab3c678c..9d4d1d6c7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1223,7 +1223,9 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns = [c for c in self.left.c] + [c for c in self.right.c] + columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ + c for c in self.right.c + ] self.primary_key.extend( # type: ignore sqlutil.reduce_columns( diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 284343154..d08fef60a 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -17,7 +17,9 @@ from typing import AbstractSet from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import Optional @@ -32,15 +34,15 @@ from . import coercions from . import operators from . import roles from . import visitors +from ._typing import is_text_clause from .annotation import _deep_annotate as _deep_annotate from .annotation import _deep_deannotate as _deep_deannotate from .annotation import _shallow_annotate as _shallow_annotate from .base import _expand_cloned from .base import _from_objects -from .base import ColumnSet -from .cache_key import HasCacheKey # noqa -from .ddl import sort_tables # noqa -from .elements import _find_columns +from .cache_key import HasCacheKey as HasCacheKey +from .ddl import sort_tables as sort_tables +from .elements import _find_columns as _find_columns from .elements import _label_reference from .elements import _textual_label_reference from .elements import BindParameter @@ -67,10 +69,13 @@ from ..util.typing import Protocol if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument + from .elements import TextClause from .roles import FromClauseRole from .selectable import _JoinTargetElement from .selectable import _OnClauseElement + from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType from .visitors import ExternallyTraversible @@ -752,7 +757,29 @@ def splice_joins( return ret -def reduce_columns(columns, *clauses, **kw): +@overload +def reduce_columns( + columns: Iterable[ColumnElement[Any]], + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[ColumnElement[Any]]: + ... + + +@overload +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[Union[ColumnElement[Any], TextClause]]: + ... + + +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Collection[Union[ColumnElement[Any], TextClause]]: r"""given a list of columns, return a 'reduced' set based on natural equivalents. @@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw): ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) only_synonyms = kw.pop("only_synonyms", False) - columns = util.ordered_column_set(columns) + column_set = util.OrderedSet(columns) + cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference( + c for c in column_set if is_text_clause(c) # type: ignore + ) omit = util.column_set() - for col in columns: + for col in cset_no_text: for fk in chain(*[c.foreign_keys for c in col.proxy_set]): - for c in columns: + for c in cset_no_text: if c is col: continue try: @@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw): def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)]) + chain( + *[c.proxy_set for c in cset_no_text.difference(omit)] + ) ) if binary.left in cols and binary.right in cols: - for c in reversed(columns): + for c in reversed(cset_no_text): if c.shares_lineage(binary.right) and ( not only_synonyms or c.name == binary.left.name ): @@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw): if clause is not None: visitors.traverse(clause, {}, {"binary": visit_binary}) - return ColumnSet(columns.difference(omit)) + return column_set.difference(omit) def criterion_as_pairs( @@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, adapt_on_names: bool = False, @@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): class _ColumnLookup(Protocol): - def __getitem__( - self, key: ColumnElement[Any] - ) -> Optional[ColumnElement[Any]]: + @overload + def __getitem__(self, key: None) -> None: + ... + + @overload + def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: + ... + + @overload + def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: + ... + + @overload + def __getitem__(self, key: _ET) -> _ET: + ... + + def __getitem__(self, key: Any) -> Any: ... @@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, adapt_required: bool = False, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, @@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter): return ac - def traverse(self, obj): + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload + def traverse(self, obj: _ET) -> _ET: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: return self.columns[obj] def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: @@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter): adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process - def adapt_check_present(self, col): + def adapt_check_present( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: newcol = self.columns[col] if newcol is col and self._corresponding_column(col, True) is None: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7363f9ddc..e0a66fbcf 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -961,12 +961,16 @@ def cloned_traverse( ... +# a bit of controversy here, as the clone of the lead element +# *could* in theory replace with an entirely different kind of element. +# however this is really not how cloned_traverse is ever used internally +# at least. @overload def cloned_traverse( - obj: ExternallyTraversible, + obj: _ET, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: +) -> _ET: ... |
