summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-15 11:05:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-20 15:14:09 -0400
commitaeeff72e806420bf85e2e6723b1f941df38a3e1a (patch)
tree0bed521b4d7c4860f998e51ba5e318d18b2f5900 /lib/sqlalchemy/sql/util.py
parent13a8552053c21a9fa7ff6f992ed49ee92cca73e4 (diff)
downloadsqlalchemy-aeeff72e806420bf85e2e6723b1f941df38a3e1a.tar.gz
pep-484: ORM public API, constructors
for the moment, abandoning using @overload with relationship() and mapped_column(). The overloads are very difficult to get working at all, and the overloads that were there all wouldn't pass on mypy. various techniques of getting them to "work", meaning having right hand side dictate what's legal on the left, have mixed success and wont give consistent results; additionally, it's legal to have Optional / non-optional independent of nullable in any case for columns. relationship cases are less ambiguous but mypy was not going along with things. we have a comprehensive system of allowing left side annotations to drive the right side, in the absense of explicit settings on the right. so type-centric SQLAlchemy will be left-side driven just like dataclasses, and the various flags and switches on the right side will just not be needed very much. in other matters, one surprise, forgot to remove string support from orm.join(A, B, "somename") or do deprecations for it in 1.4. This is a really not-directly-used structure barely mentioned in the docs for many years, the example shows a relationship being used, not a string, so we will just change it to raise the usual error here. Change-Id: Iefbbb8d34548b538023890ab8b7c9a5d9496ec6e
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py98
1 files changed, 76 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 284343154..d08fef60a 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -17,7 +17,9 @@ from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
+from typing import Collection
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
@@ -32,15 +34,15 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
+from ._typing import is_text_clause
from .annotation import _deep_annotate as _deep_annotate
from .annotation import _deep_deannotate as _deep_deannotate
from .annotation import _shallow_annotate as _shallow_annotate
from .base import _expand_cloned
from .base import _from_objects
-from .base import ColumnSet
-from .cache_key import HasCacheKey # noqa
-from .ddl import sort_tables # noqa
-from .elements import _find_columns
+from .cache_key import HasCacheKey as HasCacheKey
+from .ddl import sort_tables as sort_tables
+from .elements import _find_columns as _find_columns
from .elements import _label_reference
from .elements import _textual_label_reference
from .elements import BindParameter
@@ -67,10 +69,13 @@ from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
+ from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import TextClause
from .roles import FromClauseRole
from .selectable import _JoinTargetElement
from .selectable import _OnClauseElement
+ from .selectable import _SelectIterable
from .selectable import Selectable
from .visitors import _TraverseCallableType
from .visitors import ExternallyTraversible
@@ -752,7 +757,29 @@ def splice_joins(
return ret
-def reduce_columns(columns, *clauses, **kw):
+@overload
+def reduce_columns(
+ columns: Iterable[ColumnElement[Any]],
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[ColumnElement[Any]]:
+ ...
+
+
+@overload
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[Union[ColumnElement[Any], TextClause]]:
+ ...
+
+
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Collection[Union[ColumnElement[Any], TextClause]]:
r"""given a list of columns, return a 'reduced' set based on natural
equivalents.
@@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw):
ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
only_synonyms = kw.pop("only_synonyms", False)
- columns = util.ordered_column_set(columns)
+ column_set = util.OrderedSet(columns)
+ cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
+ c for c in column_set if is_text_clause(c) # type: ignore
+ )
omit = util.column_set()
- for col in columns:
+ for col in cset_no_text:
for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
- for c in columns:
+ for c in cset_no_text:
if c is col:
continue
try:
@@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw):
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)])
+ chain(
+ *[c.proxy_set for c in cset_no_text.difference(omit)]
+ )
)
if binary.left in cols and binary.right in cols:
- for c in reversed(columns):
+ for c in reversed(cset_no_text):
if c.shares_lineage(binary.right) and (
not only_synonyms or c.name == binary.left.name
):
@@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw):
if clause is not None:
visitors.traverse(clause, {}, {"binary": visit_binary})
- return ColumnSet(columns.difference(omit))
+ return column_set.difference(omit)
def criterion_as_pairs(
@@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
adapt_on_names: bool = False,
@@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
class _ColumnLookup(Protocol):
- def __getitem__(
- self, key: ColumnElement[Any]
- ) -> Optional[ColumnElement[Any]]:
+ @overload
+ def __getitem__(self, key: None) -> None:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: _ET) -> _ET:
+ ...
+
+ def __getitem__(self, key: Any) -> Any:
...
@@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
adapt_required: bool = False,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
@@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter):
return ac
- def traverse(self, obj):
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
return self.columns[obj]
def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
@@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter):
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
- def adapt_check_present(self, col):
+ def adapt_check_present(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
newcol = self.columns[col]
if newcol is col and self._corresponding_column(col, True) is None: