summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/interfaces.py14
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py162
-rw-r--r--lib/sqlalchemy/orm/util.py49
-rw-r--r--lib/sqlalchemy/sql/base.py10
-rw-r--r--lib/sqlalchemy/sql/cache_key.py15
-rw-r--r--lib/sqlalchemy/sql/elements.py4
-rw-r--r--lib/sqlalchemy/sql/traversals.py209
-rw-r--r--lib/sqlalchemy/sql/visitors.py105
-rw-r--r--lib/sqlalchemy/util/langhelpers.py4
9 files changed, 464 insertions, 108 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 08189a1b7..b9a5aaf51 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -64,20 +64,24 @@ __all__ = (
class ORMStatementRole(roles.StatementRole):
+ __slots__ = ()
_role_name = (
"Executable SQL or text() construct, including ORM " "aware objects"
)
class ORMColumnsClauseRole(roles.ColumnsClauseRole):
+ __slots__ = ()
_role_name = "ORM mapped entity, aliased entity, or Column expression"
class ORMEntityColumnsClauseRole(ORMColumnsClauseRole):
+ __slots__ = ()
_role_name = "ORM mapped or aliased entity"
class ORMFromClauseRole(roles.StrictFromClauseRole):
+ __slots__ = ()
_role_name = "ORM mapped entity, aliased entity, or FROM expression"
@@ -798,6 +802,8 @@ class CompileStateOption(HasCacheKey, ORMOption):
"""
+ __slots__ = ()
+
_is_compile_state = True
def process_compile_state(self, compile_state):
@@ -832,6 +838,8 @@ class LoaderOption(CompileStateOption):
"""
+ __slots__ = ()
+
def process_compile_state_replaced_entities(
self, compile_state, mapper_entities
):
@@ -846,6 +854,8 @@ class CriteriaOption(CompileStateOption):
"""
+ __slots__ = ()
+
_is_criteria_option = True
def get_global_criteria(self, attributes):
@@ -861,6 +871,8 @@ class UserDefinedOption(ORMOption):
"""
+ __slots__ = ("payload",)
+
_is_legacy_option = False
propagate_to_loaders = False
@@ -887,6 +899,8 @@ class UserDefinedOption(ORMOption):
class MapperOption(ORMOption):
"""Describe a modification to a Query"""
+ __slots__ = ()
+
_is_legacy_option = True
propagate_to_loaders = False
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index c2cfbb9fc..0f993b86c 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -13,6 +13,7 @@ from typing import Any
from typing import cast
from typing import Mapping
from typing import NoReturn
+from typing import Optional
from typing import Tuple
from typing import Union
@@ -32,9 +33,9 @@ from ..sql import and_
from ..sql import cache_key
from ..sql import coercions
from ..sql import roles
+from ..sql import traversals
from ..sql import visitors
from ..sql.base import _generative
-from ..sql.base import Generative
_RELATIONSHIP_TOKEN = "relationship"
_COLUMN_TOKEN = "column"
@@ -45,9 +46,11 @@ if typing.TYPE_CHECKING:
Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
-class _AbstractLoad(Generative, LoaderOption):
+class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
+ __slots__ = ("propagate_to_loaders",)
+
_is_strategy_option = True
- propagate_to_loaders = False
+ propagate_to_loaders: bool
def contains_eager(self, attr, alias=None, _is_chain=False):
r"""Indicate that the given attribute should be eagerly loaded from
@@ -882,13 +885,20 @@ class Load(_AbstractLoad):
"""
- _cache_key_traversal = [
+ __slots__ = (
+ "path",
+ "context",
+ )
+
+ _traverse_internals = [
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
(
"context",
visitors.InternalTraversal.dp_has_cache_key_list,
),
+ ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
]
+ _cache_key_traversal = None
path: PathRegistry
context: Tuple["_LoadElement", ...]
@@ -899,6 +909,7 @@ class Load(_AbstractLoad):
self.path = insp._path_registry
self.context = ()
+ self.propagate_to_loaders = False
def __str__(self):
return f"Load({self.path[0]})"
@@ -908,6 +919,7 @@ class Load(_AbstractLoad):
load = cls.__new__(cls)
load.path = path
load.context = ()
+ load.propagate_to_loaders = False
return load
def _adjust_for_extra_criteria(self, context):
@@ -1128,13 +1140,13 @@ class Load(_AbstractLoad):
self.context += (load_element,)
def __getstate__(self):
- d = self.__dict__.copy()
+ d = self._shallow_to_dict()
d["path"] = self.path.serialize()
return d
def __setstate__(self, state):
- self.__dict__.update(state)
- self.path = PathRegistry.deserialize(self.path)
+ state["path"] = PathRegistry.deserialize(state["path"])
+ self._shallow_from_dict(state)
SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
@@ -1143,16 +1155,27 @@ SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
class _WildcardLoad(_AbstractLoad):
"""represent a standalone '*' load operation"""
- _cache_key_traversal = [
+ __slots__ = ("strategy", "path", "local_opts")
+
+ _traverse_internals = [
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("path", visitors.ExtendedInternalTraversal.dp_plain_obj),
(
"local_opts",
visitors.ExtendedInternalTraversal.dp_string_multi_dict,
),
]
+ cache_key_traversal = None
- local_opts = util.EMPTY_DICT
- path: Tuple[str, ...] = ()
+ strategy: Optional[Tuple[Any, ...]]
+ local_opts: Mapping[str, Any]
+ path: Tuple[str, ...]
+ propagate_to_loaders = False
+
+ def __init__(self):
+ self.path = ()
+ self.strategy = None
+ self.local_opts = util.EMPTY_DICT
def _clone_for_bind_strategy(
self,
@@ -1171,16 +1194,6 @@ class _WildcardLoad(_AbstractLoad):
and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN)
)
- if attr == _DEFAULT_TOKEN:
- # for someload('*'), this currently does propagate=False,
- # to prevent it from taking effect for lazy loads.
- # it seems like adjusting for current_path for a lazy load etc.
- # should be taking care of that, so that the option still takes
- # effect for a refresh as well, but currently it does not.
- # probably should be adjusted to be more accurate re: current
- # path vs. refresh
- self.propagate_to_loaders = False
-
attr = f"{wildcard_key}:{attr}"
self.strategy = strategy
@@ -1310,13 +1323,16 @@ class _WildcardLoad(_AbstractLoad):
return None
def __getstate__(self):
- return self.__dict__.copy()
+ d = self._shallow_to_dict()
+ return d
def __setstate__(self, state):
- self.__dict__.update(state)
+ self._shallow_from_dict(state)
-class _LoadElement(cache_key.HasCacheKey):
+class _LoadElement(
+ cache_key.HasCacheKey, traversals.HasShallowCopy, visitors.Traversible
+):
"""represents strategy information to select for a LoaderStrategy
and pass options to it.
@@ -1328,40 +1344,66 @@ class _LoadElement(cache_key.HasCacheKey):
"""
- _cache_key_traversal = [
+ __slots__ = (
+ "path",
+ "strategy",
+ "propagate_to_loaders",
+ "local_opts",
+ "_extra_criteria",
+ "_reconcile_to_other",
+ )
+ __visit_name__ = "load_element"
+
+ _traverse_internals = [
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
(
"local_opts",
visitors.ExtendedInternalTraversal.dp_string_multi_dict,
),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+ ("propagate_to_loaders", visitors.InternalTraversal.dp_plain_obj),
+ ("_reconcile_to_other", visitors.InternalTraversal.dp_plain_obj),
]
+ _cache_key_traversal = None
- _extra_criteria = ()
+ _extra_criteria: Tuple[Any, ...]
- _reconcile_to_other = None
- strategy = None
+ _reconcile_to_other: Optional[bool]
+ strategy: Tuple[Any, ...]
path: PathRegistry
- propagate_to_loaders = False
+ propagate_to_loaders: bool
local_opts: Mapping[str, Any]
is_token_strategy: bool
is_class_strategy: bool
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return traversals.compare(self, other)
+
@property
def is_opts_only(self):
return bool(self.local_opts and self.strategy is None)
+ def _clone(self):
+ cls = self.__class__
+ s = cls.__new__(cls)
+
+ self._shallow_copy_to(s)
+ return s
+
def __getstate__(self):
- d = self.__dict__.copy()
+ d = self._shallow_to_dict()
d["path"] = self.path.serialize()
-
return d
def __setstate__(self, state):
state["path"] = PathRegistry.deserialize(state["path"])
- self.__dict__.update(state)
+ self._shallow_from_dict(state)
def _raise_for_no_match(self, parent_loader, mapper_entities):
path = parent_loader.path
@@ -1498,11 +1540,14 @@ class _LoadElement(cache_key.HasCacheKey):
opt.local_opts = (
util.immutabledict(local_opts) if local_opts else util.EMPTY_DICT
)
+ opt._extra_criteria = ()
if reconcile_to_other is not None:
opt._reconcile_to_other = reconcile_to_other
elif strategy is None and not local_opts:
opt._reconcile_to_other = True
+ else:
+ opt._reconcile_to_other = None
path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr)
@@ -1517,12 +1562,6 @@ class _LoadElement(cache_key.HasCacheKey):
def __init__(self, path, strategy, local_opts, propagate_to_loaders):
raise NotImplementedError()
- def _clone(self):
- cls = self.__class__
- s = cls.__new__(cls)
- s.__dict__ = self.__dict__.copy()
- return s
-
def _prepend_path_from(self, parent):
"""adjust the path of this :class:`._LoadElement` to be
a subpath of that of the given parent :class:`_orm.Load` object's
@@ -1617,20 +1656,28 @@ class _AttributeStrategyLoad(_LoadElement):
"""
- _cache_key_traversal = _LoadElement._cache_key_traversal + [
+ __slots__ = ("_of_type", "_path_with_polymorphic_path")
+
+ __visit_name__ = "attribute_strategy_load_element"
+
+ _traverse_internals = _LoadElement._traverse_internals + [
("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
- ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+ (
+ "_path_with_polymorphic_path",
+ visitors.ExtendedInternalTraversal.dp_has_cache_key,
+ ),
]
- _of_type: Union["Mapper", AliasedInsp, None] = None
- _path_with_polymorphic_path = None
+ _of_type: Union["Mapper", AliasedInsp, None]
+ _path_with_polymorphic_path: Optional[PathRegistry]
- inherit_cache = True
is_class_strategy = False
is_token_strategy = False
def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
assert attr is not None
+ self._of_type = None
+ self._path_with_polymorphic_path = None
insp, _, prop = _parse_attr_argument(attr)
if insp.is_property:
@@ -1832,12 +1879,14 @@ class _AttributeStrategyLoad(_LoadElement):
return [("loader", cast(PathRegistry, effective_path).natural_path)]
def __getstate__(self):
- d = self.__dict__.copy()
+ d = super().__getstate__()
+
+ # can't pickle this. See
+ # test_pickled.py -> test_lazyload_extra_criteria_not_supported
+ # where we should be emitting a warning for the usual case where this
+ # would be non-None
d["_extra_criteria"] = ()
- d["path"] = self.path.serialize()
- # TODO: we hope to do this logic only at compile time so that
- # we aren't carrying these extra attributes around
if self._path_with_polymorphic_path:
d[
"_path_with_polymorphic_path"
@@ -1854,14 +1903,19 @@ class _AttributeStrategyLoad(_LoadElement):
return d
def __setstate__(self, state):
- state["path"] = PathRegistry.deserialize(state["path"])
- self.__dict__.update(state)
- if "_path_with_polymorphic_path" in state:
+ super().__setstate__(state)
+
+ if state.get("_path_with_polymorphic_path", None):
self._path_with_polymorphic_path = PathRegistry.deserialize(
- self._path_with_polymorphic_path
+ state["_path_with_polymorphic_path"]
)
- if self._of_type is not None:
- self._of_type = inspect(self._of_type)
+ else:
+ self._path_with_polymorphic_path = None
+
+ if state.get("_of_type", None):
+ self._of_type = inspect(state["_of_type"])
+ else:
+ self._of_type = None
class _TokenStrategyLoad(_LoadElement):
@@ -1877,6 +1931,8 @@ class _TokenStrategyLoad(_LoadElement):
"""
+ __visit_name__ = "token_strategy_load_element"
+
inherit_cache = True
is_class_strategy = False
is_token_strategy = True
@@ -1962,6 +2018,8 @@ class _ClassStrategyLoad(_LoadElement):
is_class_strategy = True
is_token_strategy = False
+ __visit_name__ = "class_strategy_load_element"
+
def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
return path
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index e84517670..75f711007 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -45,6 +45,7 @@ from ..sql import util as sql_util
from ..sql import visitors
from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import ColumnCollection
+from ..util.langhelpers import MemoizedSlots
all_cascades = frozenset(
@@ -609,8 +610,9 @@ class AliasedClass:
class AliasedInsp(
ORMEntityColumnsClauseRole,
ORMFromClauseRole,
- sql_base.MemoizedHasCacheKey,
+ sql_base.HasCacheKey,
InspectionAttr,
+ MemoizedSlots,
):
"""Provide an inspection interface for an
:class:`.AliasedClass` object.
@@ -650,6 +652,30 @@ class AliasedInsp(
"""
+ __slots__ = (
+ "__weakref__",
+ "_weak_entity",
+ "mapper",
+ "selectable",
+ "name",
+ "_adapt_on_names",
+ "with_polymorphic_mappers",
+ "polymorphic_on",
+ "_use_mapper_path",
+ "_base_alias",
+ "represents_outer_join",
+ "persist_selectable",
+ "local_table",
+ "_is_with_polymorphic",
+ "_with_polymorphic_entities",
+ "_adapter",
+ "_target",
+ "__clause_element__",
+ "_memoized_values",
+ "_all_column_expressions",
+ "_nest_adapters",
+ )
+
def __init__(
self,
entity,
@@ -738,8 +764,7 @@ class AliasedInsp(
is_aliased_class = True
"always returns True"
- @util.memoized_instancemethod
- def __clause_element__(self):
+ def _memoized_method___clause_element__(self):
return self.selectable._annotate(
{
"parentmapper": self.mapper,
@@ -863,8 +888,7 @@ class AliasedInsp(
else:
assert False, "mapper %s doesn't correspond to %s" % (mapper, self)
- @util.memoized_property
- def _get_clause(self):
+ def _memoized_attr__get_clause(self):
onclause, replacemap = self.mapper._get_clause
return (
self._adapter.traverse(onclause),
@@ -874,12 +898,10 @@ class AliasedInsp(
},
)
- @util.memoized_property
- def _memoized_values(self):
+ def _memoized_attr__memoized_values(self):
return {}
- @util.memoized_property
- def _all_column_expressions(self):
+ def _memoized_attr__all_column_expressions(self):
if self._is_with_polymorphic:
cols_plus_keys = self.mapper._columns_plus_keys(
[ent.mapper for ent in self._with_polymorphic_entities]
@@ -965,6 +987,15 @@ class LoaderCriteriaOption(CriteriaOption):
"""
+ __slots__ = (
+ "root_entity",
+ "entity",
+ "deferred_where_criteria",
+ "where_criteria",
+ "include_aliases",
+ "propagate_to_loaders",
+ )
+
_traverse_internals = [
("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 74469b035..8ae8f8f65 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -17,6 +17,7 @@ from itertools import zip_longest
import operator
import re
import typing
+from typing import TypeVar
from . import roles
from . import visitors
@@ -571,11 +572,14 @@ class CompileState:
return decorate
+SelfGenerative = TypeVar("SelfGenerative", bound="Generative")
+
+
class Generative(HasMemoized):
"""Provide a method-chaining pattern in conjunction with the
@_generative decorator."""
- def _generate(self):
+ def _generate(self: SelfGenerative) -> SelfGenerative:
skip = self._memoized_keys
cls = self.__class__
s = cls.__new__(cls)
@@ -783,6 +787,8 @@ class Options(metaclass=_MetaOptions):
class CacheableOptions(Options, HasCacheKey):
+ __slots__ = ()
+
@hybridmethod
def _gen_cache_key(self, anon_map, bindparams):
return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
@@ -797,6 +803,8 @@ class CacheableOptions(Options, HasCacheKey):
class ExecutableOption(HasCopyInternals):
+ __slots__ = ()
+
_annotations = util.EMPTY_DICT
__visit_name__ = "executable_option"
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index 8dd44dbf0..42bd60353 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -47,6 +47,11 @@ class CacheTraverseTarget(enum.Enum):
class HasCacheKey:
"""Mixin for objects which can produce a cache key.
+ This class is usually in a hierarchy that starts with the
+ :class:`.HasTraverseInternals` base, but this is optional. Currently,
+ the class should be able to work on its own without including
+ :class:`.HasTraverseInternals`.
+
.. seealso::
:class:`.CacheKey`
@@ -55,6 +60,8 @@ class HasCacheKey:
"""
+ __slots__ = ()
+
_cache_key_traversal = NO_CACHE
_is_has_cache_key = True
@@ -106,11 +113,17 @@ class HasCacheKey:
_cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
if _cache_key_traversal is None:
try:
+ # this would be HasTraverseInternals
_cache_key_traversal = cls._traverse_internals
except AttributeError:
cls._generated_cache_key_traversal = NO_CACHE
return NO_CACHE
+ assert _cache_key_traversal is not NO_CACHE, (
+ f"class {cls} has _cache_key_traversal=NO_CACHE, "
+ "which conflicts with inherit_cache=True"
+ )
+
# TODO: wouldn't we instead get this from our superclass?
# also, our superclass may not have this yet, but in any case,
# we'd generate for the superclass that has it. this is a little
@@ -323,6 +336,8 @@ class HasCacheKey:
class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+ __slots__ = ()
+
@HasMemoized.memoized_instancemethod
def _generate_cache_key(self):
return HasCacheKey._generate_cache_key(self)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 43979b4ae..d14521ba7 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -46,7 +46,7 @@ from .traversals import HasCopyInternals
from .visitors import cloned_traverse
from .visitors import InternalTraversal
from .visitors import traverse
-from .visitors import Traversible
+from .visitors import Visitable
from .. import exc
from .. import inspection
from .. import util
@@ -126,7 +126,7 @@ def literal_column(text, type_=None):
return ColumnClause(text, type_=type_, is_literal=True)
-class CompilerElement(Traversible):
+class CompilerElement(Visitable):
"""base class for SQL elements that can be compiled to produce a
SQL string.
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 2fa3a0408..18fd1d4b8 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -10,12 +10,22 @@ import collections.abc as collections_abc
import itertools
from itertools import zip_longest
import operator
+import typing
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Type
+from typing import TypeVar
from . import operators
+from .cache_key import HasCacheKey
+from .visitors import _TraverseInternalsType
from .visitors import anon_map
+from .visitors import ExtendedInternalTraversal
+from .visitors import HasTraverseInternals
from .visitors import InternalTraversal
from .. import util
-
+from ..util import langhelpers
SKIP_TRAVERSE = util.symbol("skip_traverse")
COMPARE_FAILED = False
@@ -47,11 +57,158 @@ def _preconfigure_traversals(target_hierarchy):
)
+SelfHasShallowCopy = TypeVar("SelfHasShallowCopy", bound="HasShallowCopy")
+
+
+class HasShallowCopy(HasTraverseInternals):
+ """attribute-wide operations that are useful for classes that use
+ __slots__ and therefore can't operate on their attributes in a dictionary.
+
+
+ """
+
+ __slots__ = ()
+
+ if typing.TYPE_CHECKING:
+
+ def _generated_shallow_copy_traversal(
+ self: SelfHasShallowCopy, other: SelfHasShallowCopy
+ ) -> None:
+ ...
+
+ def _generated_shallow_from_dict_traversal(
+ self, d: Dict[str, Any]
+ ) -> None:
+ ...
+
+ def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]:
+ ...
+
+ @classmethod
+ def _generate_shallow_copy(
+ cls: Type[SelfHasShallowCopy],
+ internal_dispatch: _TraverseInternalsType,
+ method_name: str,
+ ) -> Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]:
+ code = "\n".join(
+ f" other.{attrname} = self.{attrname}"
+ for attrname, _ in internal_dispatch
+ )
+ meth_text = f"def {method_name}(self, other):\n{code}\n"
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+ @classmethod
+ def _generate_shallow_to_dict(
+ cls: Type[SelfHasShallowCopy],
+ internal_dispatch: _TraverseInternalsType,
+ method_name: str,
+ ) -> Callable[[SelfHasShallowCopy], Dict[str, Any]]:
+ code = ",\n".join(
+ f" '{attrname}': self.{attrname}"
+ for attrname, _ in internal_dispatch
+ )
+ meth_text = f"def {method_name}(self):\n return {{{code}}}\n"
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+ @classmethod
+ def _generate_shallow_from_dict(
+ cls: Type[SelfHasShallowCopy],
+ internal_dispatch: _TraverseInternalsType,
+ method_name: str,
+ ) -> Callable[[SelfHasShallowCopy, Dict[str, Any]], None]:
+ code = "\n".join(
+ f" self.{attrname} = d['{attrname}']"
+ for attrname, _ in internal_dispatch
+ )
+ meth_text = f"def {method_name}(self, d):\n{code}\n"
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+ def _shallow_from_dict(self, d: Dict) -> None:
+ cls = self.__class__
+
+ try:
+ shallow_from_dict = cls.__dict__[
+ "_generated_shallow_from_dict_traversal"
+ ]
+ except KeyError:
+ shallow_from_dict = (
+ cls._generated_shallow_from_dict_traversal # type: ignore
+ ) = self._generate_shallow_from_dict(
+ cls._traverse_internals,
+ "_generated_shallow_from_dict_traversal",
+ )
+
+ shallow_from_dict(self, d)
+
+ def _shallow_to_dict(self) -> Dict[str, Any]:
+ cls = self.__class__
+
+ try:
+ shallow_to_dict = cls.__dict__[
+ "_generated_shallow_to_dict_traversal"
+ ]
+ except KeyError:
+ shallow_to_dict = (
+ cls._generated_shallow_to_dict_traversal # type: ignore
+ ) = self._generate_shallow_to_dict(
+ cls._traverse_internals, "_generated_shallow_to_dict_traversal"
+ )
+
+ return shallow_to_dict(self)
+
+ def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy):
+ cls = self.__class__
+
+ try:
+ shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
+ except KeyError:
+ shallow_copy = (
+ cls._generated_shallow_copy_traversal # type: ignore
+ ) = self._generate_shallow_copy(
+ cls._traverse_internals, "_generated_shallow_copy_traversal"
+ )
+
+ shallow_copy(self, other)
+
+ def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy:
+ """Create a shallow copy"""
+ c = self.__class__.__new__(self.__class__)
+ self._shallow_copy_to(c)
+ return c
+
+
+SelfGenerativeOnTraversal = TypeVar(
+ "SelfGenerativeOnTraversal", bound="GenerativeOnTraversal"
+)
+
+
+class GenerativeOnTraversal(HasShallowCopy):
+ """Supplies Generative behavior but making use of traversals to shallow
+ copy.
+
+ .. seealso::
+
+ :class:`sqlalchemy.sql.base.Generative`
+
+
+ """
+
+ __slots__ = ()
+
+ def _generate(
+ self: SelfGenerativeOnTraversal,
+ ) -> SelfGenerativeOnTraversal:
+ cls = self.__class__
+ s = cls.__new__(cls)
+ self._shallow_copy_to(s)
+ return s
+
+
def _clone(element, **kw):
return element._clone()
-class HasCopyInternals:
+class HasCopyInternals(HasTraverseInternals):
__slots__ = ()
def _clone(self, **kw):
@@ -304,7 +461,9 @@ def _resolve_name_for_compare(element, name, anon_map, **kw):
return name
-class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+class TraversalComparatorStrategy(
+ ExtendedInternalTraversal, util.MemoizedSlots
+):
__slots__ = "stack", "cache", "anon_map"
def __init__(self):
@@ -377,6 +536,10 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
continue
dispatch = self.dispatch(left_visit_sym)
+ assert dispatch, (
+ f"{self.__class__} has no dispatch for "
+ f"'{self._dispatch_lookup[left_visit_sym]}'"
+ )
left_child = operator.attrgetter(left_attrname)(left)
right_child = operator.attrgetter(right_attrname)(right)
if left_child is None:
@@ -517,6 +680,46 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
):
return left == right
+ def visit_string_multi_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+
+ for lk, rk in zip_longest(
+ sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
+ ):
+ if lk != rk:
+ return COMPARE_FAILED
+
+ lv, rv = left[lk], right[rk]
+
+ lhc = isinstance(left, HasCacheKey)
+ rhc = isinstance(right, HasCacheKey)
+ if lhc and rhc:
+ if lv._gen_cache_key(
+ self.anon_map[0], []
+ ) != rv._gen_cache_key(self.anon_map[1], []):
+ return COMPARE_FAILED
+ elif lhc != rhc:
+ return COMPARE_FAILED
+ elif lv != rv:
+ return COMPARE_FAILED
+
+ def visit_multi(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+
+ lhc = isinstance(left, HasCacheKey)
+ rhc = isinstance(right, HasCacheKey)
+ if lhc and rhc:
+ if left._gen_cache_key(
+ self.anon_map[0], []
+ ) != right._gen_cache_key(self.anon_map[1], []):
+ return COMPARE_FAILED
+ elif lhc != rhc:
+ return COMPARE_FAILED
+ else:
+ return left == right
+
def visit_anon_name(
self, attrname, left_parent, left, right_parent, right, **kw
):
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 70c4dc133..78384782b 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -26,11 +26,14 @@ https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
from collections import deque
import itertools
import operator
+from typing import List
+from typing import Tuple
from .. import exc
from .. import util
from ..util import langhelpers
from ..util import symbol
+from ..util.langhelpers import _symbol
try:
from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
@@ -43,14 +46,67 @@ __all__ = [
"traverse",
"cloned_traverse",
"replacement_traverse",
- "Traversible",
+ "Visitable",
"ExternalTraversal",
"InternalTraversal",
]
+_TraverseInternalsType = List[Tuple[str, _symbol]]
-class Traversible:
- """Base class for visitable objects."""
+
+class HasTraverseInternals:
+ """base for classes that have a "traverse internals" element,
+ which defines all kinds of ways of traversing the elements of an object.
+
+ """
+
+ __slots__ = ()
+
+ _traverse_internals: _TraverseInternalsType
+
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(self, omit_attrs=(), **kw):
+ r"""Return immediate child :class:`.visitors.Visitable`
+ elements of this :class:`.visitors.Visitable`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
+
+class Visitable:
+ """Base class for visitable objects.
+
+ .. versionchanged:: 2.0 The :class:`.Visitable` class was named
+ :class:`.Traversible` in the 1.4 series; the name is changed back
+ to :class:`.Visitable` in 2.0 which is what it was prior to 1.4.
+
+ Both names remain importable in both 1.4 and 2.0 versions.
+
+ """
__slots__ = ()
@@ -120,38 +176,6 @@ class Traversible:
# allow generic classes in py3.9+
return cls
- @util.preload_module("sqlalchemy.sql.traversals")
- def get_children(self, omit_attrs=(), **kw):
- r"""Return immediate child :class:`.visitors.Traversible`
- elements of this :class:`.visitors.Traversible`.
-
- This is used for visit traversal.
-
- \**kw may contain flags that change the collection that is
- returned, for example to return a subset of items in order to
- cut down on larger traversals, or to return child items from a
- different context (such as schema-level collections instead of
- clause-level).
-
- """
-
- traversals = util.preloaded.sql_traversals
-
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return []
-
- dispatch = traversals._get_children.run_generated_dispatch
- return itertools.chain.from_iterable(
- meth(obj, **kw)
- for attrname, obj, meth in dispatch(
- self, traverse_internals, "_generated_get_children_traversal"
- )
- if attrname not in omit_attrs and obj is not None
- )
-
class _HasTraversalDispatch:
r"""Define infrastructure for the :class:`.InternalTraversal` class.
@@ -261,14 +285,14 @@ class InternalTraversal(_HasTraversalDispatch):
:class:`.InternalTraversible` will have the following methods automatically
implemented:
- * :meth:`.Traversible.get_children`
+ * :meth:`.HasTraverseInternals.get_children`
- * :meth:`.Traversible._copy_internals`
+ * :meth:`.HasTraverseInternals._copy_internals`
- * :meth:`.Traversible._gen_cache_key`
+ * :meth:`.HasCacheKey._gen_cache_key`
Subclasses can also implement these methods directly, particularly for the
- :meth:`.Traversible._copy_internals` method, when special steps
+ :meth:`.HasTraverseInternals._copy_internals` method, when special steps
are needed.
.. versionadded:: 1.4
@@ -625,7 +649,8 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
# backwards compatibility
-Visitable = Traversible
+Traversible = Visitable
+
ClauseVisitor = ExternalTraversal
CloningVisitor = CloningExternalTraversal
ReplacingCloningVisitor = ReplacingExternalTraversal
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 80ef3458c..8b65fb4cf 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -1156,7 +1156,9 @@ class MemoizedSlots:
raise AttributeError(key)
def __getattr__(self, key):
- if key.startswith("_memoized"):
+ if key.startswith("_memoized_attr_") or key.startswith(
+ "_memoized_method_"
+ ):
raise AttributeError(key)
elif hasattr(self, "_memoized_attr_%s" % key):
value = getattr(self, "_memoized_attr_%s" % key)()