diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-28 16:19:43 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-05-03 15:58:45 -0400 |
| commit | 1fa3e2e3814b4d28deca7426bb3f36e7fb515496 (patch) | |
| tree | 9b07b8437b1190227c2e8c51f2e942936721000f /lib/sqlalchemy/orm/collections.py | |
| parent | 6a496a5f40efe6d58b09eeca9320829789ceaa54 (diff) | |
| download | sqlalchemy-1fa3e2e3814b4d28deca7426bb3f36e7fb515496.tar.gz | |
pep484: attributes and related
also implements __slots__ for QueryableAttribute,
InstrumentedAttribute, Relationship.Comparator.
Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5
Diffstat (limited to 'lib/sqlalchemy/orm/collections.py')
| -rw-r--r-- | lib/sqlalchemy/orm/collections.py | 208 |
1 files changed, 146 insertions, 62 deletions
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 717f1d0d6..da0da0fcf 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +# mypy: allow-untyped-defs, allow-untyped-calls """Support for collections of mapped entities. @@ -109,17 +109,34 @@ import operator import threading import typing from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from .. import exc as sa_exc from .. import util from ..util.compat import inspect_getfullargspec +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from .attributes import CollectionAttributeImpl from .mapped_collection import attribute_mapped_collection from .mapped_collection import column_mapped_collection from .mapped_collection import mapped_collection from .mapped_collection import MappedCollection # noqa: F401 + from .state import InstanceState + __all__ = [ "collection", @@ -132,6 +149,28 @@ __all__ = [ __instrumentation_mutex = threading.Lock() +_CollectionFactoryType = Callable[[], "_AdaptedCollectionProtocol"] + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) +_COL = TypeVar("_COL", bound="Collection[Any]") +_FN = TypeVar("_FN", bound="Callable[..., Any]") + + +class _CollectionConverterProtocol(Protocol): + def __call__(self, collection: _COL) -> _COL: + ... + + +class _AdaptedCollectionProtocol(Protocol): + _sa_adapter: CollectionAdapter + _sa_appender: Callable[..., Any] + _sa_remover: Callable[..., Any] + _sa_iterator: Callable[..., Iterable[Any]] + _sa_converter: _CollectionConverterProtocol + + class collection: """Decorators for entity collection classes. @@ -396,8 +435,13 @@ class collection: return decorator -collection_adapter = operator.attrgetter("_sa_adapter") -"""Fetch the :class:`.CollectionAdapter` for a collection.""" +if TYPE_CHECKING: + + def collection_adapter(collection: Collection[Any]) -> CollectionAdapter: + """Fetch the :class:`.CollectionAdapter` for a collection.""" + +else: + collection_adapter = operator.attrgetter("_sa_adapter") class CollectionAdapter: @@ -423,10 +467,33 @@ class CollectionAdapter: "empty", ) - def __init__(self, attr, owner_state, data): + attr: CollectionAttributeImpl + _key: str + + # this is actually a weakref; see note in constructor + _data: Callable[..., _AdaptedCollectionProtocol] + + owner_state: InstanceState[Any] + _converter: _CollectionConverterProtocol + invalidated: bool + empty: bool + + def __init__( + self, + attr: CollectionAttributeImpl, + owner_state: InstanceState[Any], + data: _AdaptedCollectionProtocol, + ): self.attr = attr self._key = attr.key - self._data = weakref.ref(data) + + # this weakref stays referenced throughout the lifespan of + # CollectionAdapter. so while the weakref can return None, this + # is realistically only during garbage collection of this object, so + # we type this as a callable that returns _AdaptedCollectionProtocol + # in all cases. + self._data = weakref.ref(data) # type: ignore + self.owner_state = owner_state data._sa_adapter = self self._converter = data._sa_converter @@ -437,7 +504,7 @@ class CollectionAdapter: util.warn("This collection has been invalidated.") @property - def data(self): + def data(self) -> _AdaptedCollectionProtocol: "The entity collection being adapted." return self._data() @@ -634,7 +701,10 @@ class CollectionAdapter: def __setstate__(self, d): self._key = d["key"] self.owner_state = d["owner_state"] - self._data = weakref.ref(d["data"]) + + # see note in constructor regarding this type: ignore + self._data = weakref.ref(d["data"]) # type: ignore + self._converter = d["data"]._sa_converter d["data"]._sa_adapter = self self.invalidated = d["invalidated"] @@ -682,7 +752,9 @@ def bulk_replace(values, existing_adapter, new_adapter, initiator=None): existing_adapter.fire_remove_event(member, initiator=initiator) -def prepare_instrumentation(factory): +def prepare_instrumentation( + factory: Union[Type[Collection[Any]], _CollectionFactoryType], +) -> _CollectionFactoryType: """Prepare a callable for future use as a collection class factory. Given a collection class factory (either a type or no-arg callable), @@ -693,18 +765,30 @@ def prepare_instrumentation(factory): into the run-time behavior of collection_class=InstrumentedList. """ + + impl_factory: _CollectionFactoryType + # Convert a builtin to 'Instrumented*' if factory in __canned_instrumentation: - factory = __canned_instrumentation[factory] + impl_factory = __canned_instrumentation[factory] + else: + impl_factory = cast(_CollectionFactoryType, factory) + + cls: Union[_CollectionFactoryType, Type[Collection[Any]]] # Create a specimen - cls = type(factory()) + cls = type(impl_factory()) # Did factory callable return a builtin? if cls in __canned_instrumentation: - # Wrap it so that it returns our 'Instrumented*' - factory = __converting_factory(cls, factory) - cls = factory() + + # if so, just convert. + # in previous major releases, this codepath wasn't working and was + # not covered by tests. prior to that it supplied a "wrapper" + # function that would return the class, though the rationale for this + # case is not known + impl_factory = __canned_instrumentation[cls] + cls = type(impl_factory()) # Instrument the class if needed. if __instrumentation_mutex.acquire(): @@ -714,26 +798,7 @@ def prepare_instrumentation(factory): finally: __instrumentation_mutex.release() - return factory - - -def __converting_factory(specimen_cls, original_factory): - """Return a wrapper that converts a "canned" collection like - set, dict, list into the Instrumented* version. - - """ - - instrumented_cls = __canned_instrumentation[specimen_cls] - - def wrapper(): - collection = original_factory() - return instrumented_cls(collection) - - # often flawed but better than nothing - wrapper.__name__ = "%sWrapper" % original_factory.__name__ - wrapper.__doc__ = original_factory.__doc__ - - return wrapper + return impl_factory def _instrument_class(cls): @@ -763,8 +828,8 @@ def _locate_roles_and_methods(cls): """ - roles = {} - methods = {} + roles: Dict[str, str] = {} + methods: Dict[str, Tuple[Optional[str], Optional[int], Optional[str]]] = {} for supercls in cls.__mro__: for name, method in vars(supercls).items(): @@ -784,7 +849,9 @@ def _locate_roles_and_methods(cls): # transfer instrumentation requests from decorated function # to the combined queue - before, after = None, None + before: Optional[Tuple[str, int]] = None + after: Optional[str] = None + if hasattr(method, "_sa_instrument_before"): op, argument = method._sa_instrument_before assert op in ("fire_append_event", "fire_remove_event") @@ -809,6 +876,7 @@ def _setup_canned_roles(cls, roles, methods): """ collection_type = util.duck_type_collection(cls) if collection_type in __interfaces: + assert collection_type is not None canned_roles, decorators = __interfaces[collection_type] for role, name in canned_roles.items(): roles.setdefault(role, name) @@ -934,9 +1002,9 @@ def _instrument_membership_mutator(method, before, argument, after): getattr(executor, after)(res, initiator) return res - wrapper._sa_instrumented = True + wrapper._sa_instrumented = True # type: ignore[attr-defined] if hasattr(method, "_sa_instrument_role"): - wrapper._sa_instrument_role = method._sa_instrument_role + wrapper._sa_instrument_role = method._sa_instrument_role # type: ignore[attr-defined] # noqa: E501 wrapper.__name__ = method.__name__ wrapper.__doc__ = method.__doc__ return wrapper @@ -990,7 +1058,7 @@ def __before_pop(collection, _sa_initiator=None): executor.fire_pre_remove_event(_sa_initiator) -def _list_decorators(): +def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: """Tailored instrumentation wrappers for any list-like class.""" def _tidy(fn): @@ -1131,7 +1199,7 @@ def _list_decorators(): return l -def _dict_decorators(): +def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: """Tailored instrumentation wrappers for any dict-like mapping class.""" def _tidy(fn): @@ -1255,7 +1323,7 @@ def _set_binops_check_loose(self: Any, obj: Any) -> bool: ) -def _set_decorators(): +def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: """Tailored instrumentation wrappers for any set-like class.""" def _tidy(fn): @@ -1420,36 +1488,52 @@ def _set_decorators(): return l -class InstrumentedList(list): +class InstrumentedList(List[_T]): """An instrumented version of the built-in list.""" -class InstrumentedSet(set): +class InstrumentedSet(Set[_T]): """An instrumented version of the built-in set.""" -class InstrumentedDict(dict): +class InstrumentedDict(Dict[_KT, _VT]): """An instrumented version of the built-in dict.""" -__canned_instrumentation = { - list: InstrumentedList, - set: InstrumentedSet, - dict: InstrumentedDict, -} - -__interfaces = { - list: ( - {"appender": "append", "remover": "remove", "iterator": "__iter__"}, - _list_decorators(), - ), - set: ( - {"appender": "add", "remover": "remove", "iterator": "__iter__"}, - _set_decorators(), - ), - # decorators are required for dicts and object collections. - dict: ({"iterator": "values"}, _dict_decorators()), -} +__canned_instrumentation: util.immutabledict[ + Any, _CollectionFactoryType +] = util.immutabledict( + { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } +) + +__interfaces: util.immutabledict[ + Any, + Tuple[ + Dict[str, str], + Dict[str, Callable[..., Any]], + ], +] = util.immutabledict( + { + list: ( + { + "appender": "append", + "remover": "remove", + "iterator": "__iter__", + }, + _list_decorators(), + ), + set: ( + {"appender": "add", "remover": "remove", "iterator": "__iter__"}, + _set_decorators(), + ), + # decorators are required for dicts and object collections. + dict: ({"iterator": "values"}, _dict_decorators()), + } +) def __go(lcls): |
