summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-15 11:05:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-20 15:14:09 -0400
commitaeeff72e806420bf85e2e6723b1f941df38a3e1a (patch)
tree0bed521b4d7c4860f998e51ba5e318d18b2f5900 /lib/sqlalchemy/sql
parent13a8552053c21a9fa7ff6f992ed49ee92cca73e4 (diff)
downloadsqlalchemy-aeeff72e806420bf85e2e6723b1f941df38a3e1a.tar.gz
pep-484: ORM public API, constructors
for the moment, abandoning using @overload with relationship() and mapped_column(). The overloads are very difficult to get working at all, and the overloads that were there all wouldn't pass on mypy. various techniques of getting them to "work", meaning having right hand side dictate what's legal on the left, have mixed success and wont give consistent results; additionally, it's legal to have Optional / non-optional independent of nullable in any case for columns. relationship cases are less ambiguous but mypy was not going along with things. we have a comprehensive system of allowing left side annotations to drive the right side, in the absense of explicit settings on the right. so type-centric SQLAlchemy will be left-side driven just like dataclasses, and the various flags and switches on the right side will just not be needed very much. in other matters, one surprise, forgot to remove string support from orm.join(A, B, "somename") or do deprecations for it in 1.4. This is a really not-directly-used structure barely mentioned in the docs for many years, the example shows a relationship being used, not a string, so we will just change it to raise the usual error here. Change-Id: Iefbbb8d34548b538023890ab8b7c9a5d9496ec6e
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_elements_constructors.py7
-rw-r--r--lib/sqlalchemy/sql/_typing.py21
-rw-r--r--lib/sqlalchemy/sql/base.py17
-rw-r--r--lib/sqlalchemy/sql/coercions.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--lib/sqlalchemy/sql/ddl.py3
-rw-r--r--lib/sqlalchemy/sql/elements.py37
-rw-r--r--lib/sqlalchemy/sql/lambdas.py4
-rw-r--r--lib/sqlalchemy/sql/roles.py11
-rw-r--r--lib/sqlalchemy/sql/schema.py135
-rw-r--r--lib/sqlalchemy/sql/selectable.py4
-rw-r--r--lib/sqlalchemy/sql/util.py98
-rw-r--r--lib/sqlalchemy/sql/visitors.py8
13 files changed, 172 insertions, 176 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index ea21e01c6..605f75ec4 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -389,7 +389,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
def bindparam(
- key: str,
+ key: Optional[str],
value: Any = _NoArg.NO_ARG,
type_: Optional[TypeEngine[_T]] = None,
unique: bool = False,
@@ -521,6 +521,11 @@ def bindparam(
key, or if its length is too long and truncation is
required.
+ If omitted, an "anonymous" name is generated for the bound parameter;
+ when given a value to bind, the end result is equivalent to calling upon
+ the :func:`.literal` function with a value to bind, particularly
+ if the :paramref:`.bindparam.unique` parameter is also provided.
+
:param value:
Initial value for this bind param. Will be used at statement
execution time as the value for this parameter passed to the
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index b0a717a1a..53d29b628 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -2,13 +2,14 @@ from __future__ import annotations
import operator
from typing import Any
+from typing import Callable
from typing import Dict
+from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from sqlalchemy.sql.base import Executable
from . import roles
from .. import util
from ..inspection import Inspectable
@@ -16,6 +17,7 @@ from ..util.typing import Literal
from ..util.typing import Protocol
if TYPE_CHECKING:
+ from .base import Executable
from .compiler import Compiled
from .compiler import DDLCompiler
from .compiler import SQLCompiler
@@ -27,17 +29,20 @@ if TYPE_CHECKING:
from .elements import quoted_name
from .elements import SQLCoreOperations
from .elements import TextClause
+ from .lambdas import LambdaElement
from .roles import ColumnsClauseRole
from .roles import FromClauseRole
from .schema import Column
from .schema import DefaultGenerator
from .schema import Sequence
+ from .schema import Table
from .selectable import Alias
from .selectable import FromClause
from .selectable import Join
from .selectable import NamedFromClause
from .selectable import ReturnsRows
from .selectable import Select
+ from .selectable import Selectable
from .selectable import SelectBase
from .selectable import Subquery
from .selectable import TableClause
@@ -46,7 +51,6 @@ if TYPE_CHECKING:
from .type_api import TypeEngine
from ..util.typing import TypeGuard
-
_T = TypeVar("_T", bound=Any)
@@ -89,7 +93,11 @@ sets; select(...), insert().returning(...), etc.
"""
_ColumnExpressionArgument = Union[
- "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
+ "ColumnElement[_T]",
+ _HasClauseElement,
+ roles.ExpressionElementRole[_T],
+ Callable[[], "ColumnElement[_T]"],
+ "LambdaElement",
]
"""narrower "column expression" argument.
@@ -103,6 +111,7 @@ overall which brings in the TextClause object also.
"""
+
_InfoType = Dict[Any, Any]
"""the .info dictionary accepted and used throughout Core /ORM"""
@@ -169,6 +178,8 @@ _PropagateAttrsType = util.immutabledict[str, Any]
_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
+_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]]
+
if TYPE_CHECKING:
def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]:
@@ -195,6 +206,9 @@ if TYPE_CHECKING:
def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
...
+ def is_selectable(t: Any) -> TypeGuard[Selectable]:
+ ...
+
def is_select_base(
t: Union[Executable, ReturnsRows]
) -> TypeGuard[SelectBase]:
@@ -224,6 +238,7 @@ else:
is_from_clause = operator.attrgetter("_is_from_clause")
is_tuple_type = operator.attrgetter("_is_tuple_type")
is_table_value_type = operator.attrgetter("_is_table_value")
+ is_selectable = operator.attrgetter("is_selectable")
is_select_base = operator.attrgetter("_is_select_base")
is_select_statement = operator.attrgetter("_is_select_statement")
is_table = operator.attrgetter("_is_table")
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index f7692dbc2..f81878d55 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -218,7 +218,7 @@ def _generative(fn: _Fn) -> _Fn:
"""
- @util.decorator
+ @util.decorator # type: ignore
def _generative(
fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
) -> _SelfGenerativeType:
@@ -244,7 +244,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
for name in names
]
- @util.decorator
+ @util.decorator # type: ignore
def check(fn, *args, **kw):
# make pylance happy by not including "self" in the argument
# list
@@ -260,7 +260,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
raise exc.InvalidRequestError(msg)
return fn(self, *args, **kw)
- return check
+ return check # type: ignore
def _clone(element, **kw):
@@ -1750,15 +1750,14 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
self._collection.append((k, col))
self._colset.update(c for (k, c) in self._collection)
- # https://github.com/python/mypy/issues/12610
self._index.update(
- (idx, c) for idx, (k, c) in enumerate(self._collection) # type: ignore # noqa: E501
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
)
for col in replace_col:
self.replace(col)
def extend(self, iter_: Iterable[_NAMEDCOL]) -> None:
- self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501
+ self._populate_separate_keys((col.key, col) for col in iter_)
def remove(self, column: _NAMEDCOL) -> None:
if column not in self._colset:
@@ -1772,9 +1771,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
(k, c) for (k, c) in self._collection if c is not column
]
- # https://github.com/python/mypy/issues/12610
self._index.update(
- {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
)
# delete higher index
del self._index[len(self._collection)]
@@ -1827,9 +1825,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
self._index.clear()
- # https://github.com/python/mypy/issues/12610
self._index.update(
- {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
)
self._index.update(self._collection)
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 4bf45da9c..0659709ab 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -214,6 +214,7 @@ def expect(
Type[roles.ExpressionElementRole[Any]],
Type[roles.LimitOffsetRole],
Type[roles.WhereHavingRole],
+ Type[roles.OnClauseRole],
],
element: Any,
**kw: Any,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 938be0f81..c524a2602 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1078,7 +1078,7 @@ class SQLCompiler(Compiled):
return list(self.insert_prefetch) + list(self.update_prefetch)
@util.memoized_property
- def _global_attributes(self):
+ def _global_attributes(self) -> Dict[Any, Any]:
return {}
@util.memoized_instancemethod
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 6ac7c2448..052af6ac9 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -14,6 +14,7 @@ from __future__ import annotations
import typing
from typing import Any
from typing import Callable
+from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence as typing_Sequence
@@ -1143,7 +1144,7 @@ class SchemaDropper(InvokeDDLBase):
def sort_tables(
- tables: typing_Sequence["Table"],
+ tables: Iterable["Table"],
skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None,
extra_dependencies: Optional[
typing_Sequence[Tuple["Table", "Table"]]
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index ea0fa7996..34d5127ab 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -293,11 +293,18 @@ class ClauseElement(
__visit_name__ = "clause"
- _propagate_attrs: _PropagateAttrsType = util.immutabledict()
- """like annotations, however these propagate outwards liberally
- as SQL constructs are built, and are set up at construction time.
+ if TYPE_CHECKING:
- """
+ @util.memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ """like annotations, however these propagate outwards liberally
+ as SQL constructs are built, and are set up at construction time.
+
+ """
+ ...
+
+ else:
+ _propagate_attrs = util.EMPTY_DICT
@util.ro_memoized_property
def description(self) -> Optional[str]:
@@ -343,7 +350,9 @@ class ClauseElement(
def _from_objects(self) -> List[FromClause]:
return []
- def _set_propagate_attrs(self, values):
+ def _set_propagate_attrs(
+ self: SelfClauseElement, values: Mapping[str, Any]
+ ) -> SelfClauseElement:
# usually, self._propagate_attrs is empty here. one case where it's
# not is a subquery against ORM select, that is then pulled as a
# property of an aliased class. should all be good
@@ -526,13 +535,10 @@ class ClauseElement(
if unique:
bind._convert_to_unique()
- return cast(
- SelfClauseElement,
- cloned_traverse(
- self,
- {"maintain_key": True, "detect_subquery_cols": True},
- {"bindparam": visit_bindparam},
- ),
+ return cloned_traverse(
+ self,
+ {"maintain_key": True, "detect_subquery_cols": True},
+ {"bindparam": visit_bindparam},
)
def compare(self, other, **kw):
@@ -730,7 +736,9 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
# redefined with the specific types returned by ColumnElement hierarchies
if typing.TYPE_CHECKING:
- _propagate_attrs: _PropagateAttrsType
+ @util.non_memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ ...
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
@@ -2064,10 +2072,11 @@ class TextClause(
roles.OrderByRole,
roles.FromClauseRole,
roles.SelectStatementRole,
- roles.BinaryElementRole[Any],
roles.InElementRole,
Executable,
DQLDMLClauseElement,
+ roles.BinaryElementRole[Any],
+ inspection.Inspectable["TextClause"],
):
"""Represent a literal SQL text fragment.
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index da15c305f..4b220188f 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -444,7 +444,7 @@ class DeferredLambdaElement(LambdaElement):
def _invoke_user_fn(self, fn, *arg):
return fn(*self.lambda_args)
- def _resolve_with_args(self, *lambda_args):
+ def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
assert isinstance(self._rec, AnalyzedFunction)
tracker_fn = self._rec.tracker_instrumented_fn
expr = tracker_fn(*lambda_args)
@@ -478,7 +478,7 @@ class DeferredLambdaElement(LambdaElement):
for deferred_copy_internals in self._transforms:
expr = deferred_copy_internals(expr)
- return expr
+ return expr # type: ignore
def _copy_internals(
self, clone=_clone, deferred_copy_internals=None, **kw
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 577d868fd..231c70a5b 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -22,9 +22,7 @@ if TYPE_CHECKING:
from .base import _EntityNamespace
from .base import ColumnCollection
from .base import ReadOnlyColumnCollection
- from .elements import ClauseElement
from .elements import ColumnClause
- from .elements import ColumnElement
from .elements import Label
from .elements import NamedColumn
from .selectable import _SelectIterable
@@ -271,7 +269,14 @@ class StatementRole(SQLRole):
__slots__ = ()
_role_name = "Executable SQL or text() construct"
- _propagate_attrs: _PropagateAttrsType = util.immutabledict()
+ if TYPE_CHECKING:
+
+ @util.memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ ...
+
+ else:
+ _propagate_attrs = util.EMPTY_DICT
class SelectStatementRole(StatementRole, ReturnsRowsRole):
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 92b9cc62c..52ba60a62 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -144,9 +144,9 @@ class SchemaConst(Enum):
NULL_UNSPECIFIED = 3
"""Symbol indicating the "nullable" keyword was not passed to a Column.
- Normally we would expect None to be acceptable for this but some backends
- such as that of SQL Server place special signficance on a "nullability"
- value of None.
+ This is used to distinguish between the use case of passing
+ ``nullable=None`` to a :class:`.Column`, which has special meaning
+ on some backends such as SQL Server.
"""
@@ -308,7 +308,9 @@ class HasSchemaAttr(SchemaItem):
schema: Optional[str]
-class Table(DialectKWArgs, HasSchemaAttr, TableClause):
+class Table(
+ DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"]
+):
r"""Represent a table in a database.
e.g.::
@@ -1318,117 +1320,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
inherit_cache = True
key: str
- @overload
- def __init__(
- self,
- *args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- @overload
- def __init__(
- self,
- __name: str,
- *args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- @overload
def __init__(
self,
- __type: _TypeEngineArgument[_T],
- *args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- @overload
- def __init__(
- self,
- __name: str,
- __type: _TypeEngineArgument[_T],
+ __name_pos: Optional[
+ Union[str, _TypeEngineArgument[_T], SchemaEventTarget]
+ ] = None,
+ __type_pos: Optional[
+ Union[_TypeEngineArgument[_T], SchemaEventTarget]
+ ] = None,
*args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- def __init__(
- self,
- *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget],
name: Optional[str] = None,
type_: Optional[_TypeEngineArgument[_T]] = None,
autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
@@ -1440,7 +1340,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
info: Optional[_InfoType] = None,
nullable: Optional[
Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
+ ] = SchemaConst.NULL_UNSPECIFIED,
onupdate: Optional[Any] = None,
primary_key: bool = False,
server_default: Optional[_ServerDefaultType] = None,
@@ -1953,7 +1853,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
""" # noqa: E501, RST201, RST202
- l_args = list(args)
+ l_args = [__name_pos, __type_pos] + list(args)
del args
if l_args:
@@ -1963,6 +1863,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
"May not pass name positionally and as a keyword."
)
name = l_args.pop(0) # type: ignore
+ elif l_args[0] is None:
+ l_args.pop(0)
if l_args:
coltype = l_args[0]
@@ -1972,6 +1874,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
"May not pass type_ positionally and as a keyword."
)
type_ = l_args.pop(0) # type: ignore
+ elif l_args[0] is None:
+ l_args.pop(0)
if name is not None:
name = quoted_name(name, quote)
@@ -1989,7 +1893,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
self.primary_key = primary_key
self._user_defined_nullable = udn = nullable
-
if udn is not NULL_UNSPECIFIED:
self.nullable = udn
else:
@@ -5128,7 +5031,7 @@ class MetaData(HasSchemaAttr):
def clear(self) -> None:
"""Clear all Table objects from this MetaData."""
- dict.clear(self.tables)
+ dict.clear(self.tables) # type: ignore
self._schemas.clear()
self._fk_memos.clear()
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index aab3c678c..9d4d1d6c7 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -1223,7 +1223,9 @@ class Join(roles.DMLTableRole, FromClause):
@util.preload_module("sqlalchemy.sql.util")
def _populate_column_collection(self):
sqlutil = util.preloaded.sql_util
- columns = [c for c in self.left.c] + [c for c in self.right.c]
+ columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [
+ c for c in self.right.c
+ ]
self.primary_key.extend( # type: ignore
sqlutil.reduce_columns(
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 284343154..d08fef60a 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -17,7 +17,9 @@ from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
+from typing import Collection
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
@@ -32,15 +34,15 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
+from ._typing import is_text_clause
from .annotation import _deep_annotate as _deep_annotate
from .annotation import _deep_deannotate as _deep_deannotate
from .annotation import _shallow_annotate as _shallow_annotate
from .base import _expand_cloned
from .base import _from_objects
-from .base import ColumnSet
-from .cache_key import HasCacheKey # noqa
-from .ddl import sort_tables # noqa
-from .elements import _find_columns
+from .cache_key import HasCacheKey as HasCacheKey
+from .ddl import sort_tables as sort_tables
+from .elements import _find_columns as _find_columns
from .elements import _label_reference
from .elements import _textual_label_reference
from .elements import BindParameter
@@ -67,10 +69,13 @@ from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
+ from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import TextClause
from .roles import FromClauseRole
from .selectable import _JoinTargetElement
from .selectable import _OnClauseElement
+ from .selectable import _SelectIterable
from .selectable import Selectable
from .visitors import _TraverseCallableType
from .visitors import ExternallyTraversible
@@ -752,7 +757,29 @@ def splice_joins(
return ret
-def reduce_columns(columns, *clauses, **kw):
+@overload
+def reduce_columns(
+ columns: Iterable[ColumnElement[Any]],
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[ColumnElement[Any]]:
+ ...
+
+
+@overload
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[Union[ColumnElement[Any], TextClause]]:
+ ...
+
+
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Collection[Union[ColumnElement[Any], TextClause]]:
r"""given a list of columns, return a 'reduced' set based on natural
equivalents.
@@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw):
ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
only_synonyms = kw.pop("only_synonyms", False)
- columns = util.ordered_column_set(columns)
+ column_set = util.OrderedSet(columns)
+ cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
+ c for c in column_set if is_text_clause(c) # type: ignore
+ )
omit = util.column_set()
- for col in columns:
+ for col in cset_no_text:
for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
- for c in columns:
+ for c in cset_no_text:
if c is col:
continue
try:
@@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw):
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)])
+ chain(
+ *[c.proxy_set for c in cset_no_text.difference(omit)]
+ )
)
if binary.left in cols and binary.right in cols:
- for c in reversed(columns):
+ for c in reversed(cset_no_text):
if c.shares_lineage(binary.right) and (
not only_synonyms or c.name == binary.left.name
):
@@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw):
if clause is not None:
visitors.traverse(clause, {}, {"binary": visit_binary})
- return ColumnSet(columns.difference(omit))
+ return column_set.difference(omit)
def criterion_as_pairs(
@@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
adapt_on_names: bool = False,
@@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
class _ColumnLookup(Protocol):
- def __getitem__(
- self, key: ColumnElement[Any]
- ) -> Optional[ColumnElement[Any]]:
+ @overload
+ def __getitem__(self, key: None) -> None:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: _ET) -> _ET:
+ ...
+
+ def __getitem__(self, key: Any) -> Any:
...
@@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
adapt_required: bool = False,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
@@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter):
return ac
- def traverse(self, obj):
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
return self.columns[obj]
def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
@@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter):
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
- def adapt_check_present(self, col):
+ def adapt_check_present(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
newcol = self.columns[col]
if newcol is col and self._corresponding_column(col, True) is None:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 7363f9ddc..e0a66fbcf 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -961,12 +961,16 @@ def cloned_traverse(
...
+# a bit of controversy here, as the clone of the lead element
+# *could* in theory replace with an entirely different kind of element.
+# however this is really not how cloned_traverse is ever used internally
+# at least.
@overload
def cloned_traverse(
- obj: ExternallyTraversible,
+ obj: _ET,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
-) -> ExternallyTraversible:
+) -> _ET:
...