summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/visitors.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-30 18:01:58 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-04 09:26:43 -0400
commit3b4d62f4f72e8dfad7f38db192a6a90a8551608c (patch)
treed0334c4bb52f803bd7dad661f2e6a12e25f5880c /lib/sqlalchemy/sql/visitors.py
parent4e603e23755f31278f27a45449120a8dea470a45 (diff)
downloadsqlalchemy-3b4d62f4f72e8dfad7f38db192a6a90a8551608c.tar.gz
pep484 - sql.selectable
the pep484 task becomes more intense as there is mounting pressure to come up with a consistency in how data moves from end-user to instance variable. current thinking is coming into: 1. there are _typing._XYZArgument objects that represent "what the user sent" 2. there's the roles, which represent a kind of "filter" for different kinds of objects. These are mostly important as the argument we pass to coerce(). 3. there's the thing that coerce() returns, which should be what the construct uses as its internal representation of the thing. This is _typing._XYZElement. but there's some controversy over whether or not we should pass actual ClauseElements around by their role or not. I think we shouldn't at the moment, but this makes the "role-ness" of something a little less portable. Like, we have to set DMLTableRole for TableClause, Join, and Alias, but then also we have to repeat those three types in order to set up _DMLTableElement. Other change introduced here, there was a deannotate=True for the left/right of a sql.join(). All tests pass without that. I'd rather not have that there as if we have a join(A, B) where A, B are mapped classes, we want them inside of the _annotations. The rationale seems to be performance, but this performance can be illustrated to be on the compile side which we hope is cached in the normal case. CTEs now accommodate for text selects including recursive. Get typing to accommodate "util.preloaded" cleanly; add "preloaded" as a real module. This seemed like we would have needed pep562 `__getattr__()` but we don't, just set names in globals() as we import them. References: #6810 Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r--lib/sqlalchemy/sql/visitors.py143
1 files changed, 133 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 903aae648..081faf1e9 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -28,6 +28,7 @@ from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -37,6 +38,7 @@ from .. import exc
from .. import util
from ..util import langhelpers
from ..util._has_cy import HAS_CYEXTENSION
+from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
@@ -599,8 +601,8 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
raise NotImplementedError()
def _copy_internals(
- self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any
- ) -> Self:
+ self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> None:
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
@@ -615,10 +617,24 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
_ET = TypeVar("_ET", bound=ExternallyTraversible)
+
+
_TraverseCallableType = Callable[[_ET], None]
-_TraverseTransformCallableType = Callable[
- [ExternallyTraversible], Optional[ExternallyTraversible]
-]
+
+
+class _CloneCallableType(Protocol):
+ def __call__(self, element: _ET, **kw: Any) -> _ET:
+ ...
+
+
+class _TraverseTransformCallableType(Protocol):
+ def __call__(
+ self, element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
+ ...
+
+
+_ExtT = TypeVar("_ExtT", bound="ExternalTraversal")
class ExternalTraversal:
@@ -640,7 +656,7 @@ class ExternalTraversal:
return meth(obj, **kw)
def iterate(
- self, obj: ExternallyTraversible
+ self, obj: Optional[ExternallyTraversible]
) -> Iterator[ExternallyTraversible]:
"""Traverse the given expression structure, returning an iterator
of all elements.
@@ -648,7 +664,17 @@ class ExternalTraversal:
"""
return iterate(obj, self.__traverse_options__)
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@@ -671,7 +697,7 @@ class ExternalTraversal:
yield v
v = getattr(v, "_next", None)
- def chain(self, visitor: ExternalTraversal) -> ExternalTraversal:
+ def chain(self: _ExtT, visitor: ExternalTraversal) -> _ExtT:
"""'Chain' an additional ExternalTraversal onto this ExternalTraversal
The chained visitor will receive all visit events after this one.
@@ -701,7 +727,17 @@ class CloningExternalTraversal(ExternalTraversal):
"""
return [self.traverse(x) for x in list_]
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
return cloned_traverse(
@@ -729,14 +765,25 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
"""
return None
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
def replace(
- elem: ExternallyTraversible,
+ element: ExternallyTraversible,
+ **kw: Any,
) -> Optional[ExternallyTraversible]:
for v in self.visitor_iterator:
- e = cast(ReplacingExternalTraversal, v).replace(elem)
+ e = cast(ReplacingExternalTraversal, v).replace(element)
if e is not None:
return e
@@ -754,7 +801,8 @@ ReplacingCloningVisitor = ReplacingExternalTraversal
def iterate(
- obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any] = util.EMPTY_DICT,
) -> Iterator[ExternallyTraversible]:
r"""Traverse the given expression structure, returning an iterator.
@@ -776,6 +824,9 @@ def iterate(
empty in modern usage.
"""
+ if obj is None:
+ return
+
yield obj
children = obj.get_children(**opts)
@@ -790,11 +841,29 @@ def iterate(
stack.append(t.get_children(**opts))
+@overload
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: Literal[None],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def traverse_using(
iterator: Iterable[ExternallyTraversible],
obj: ExternallyTraversible,
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: Optional[ExternallyTraversible],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Visit the given expression structure using the given iterator of
objects.
@@ -826,11 +895,29 @@ def traverse_using(
return obj
+@overload
+def traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure using the default
iterator.
@@ -863,11 +950,29 @@ def traverse(
return traverse_using(iterate(obj, opts), obj, visitors)
+@overload
+def cloned_traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def cloned_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def cloned_traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing modifications by
visitors.
@@ -931,11 +1036,29 @@ def cloned_traverse(
return obj
+@overload
+def replacement_traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> None:
+ ...
+
+
+@overload
def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
replace: _TraverseTransformCallableType,
) -> ExternallyTraversible:
+ ...
+
+
+def replacement_traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.