summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/collections.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-28 16:19:43 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-05-03 15:58:45 -0400
commit1fa3e2e3814b4d28deca7426bb3f36e7fb515496 (patch)
tree9b07b8437b1190227c2e8c51f2e942936721000f /lib/sqlalchemy/orm/collections.py
parent6a496a5f40efe6d58b09eeca9320829789ceaa54 (diff)
downloadsqlalchemy-1fa3e2e3814b4d28deca7426bb3f36e7fb515496.tar.gz
pep484: attributes and related
also implements __slots__ for QueryableAttribute, InstrumentedAttribute, Relationship.Comparator. Change-Id: I47e823160706fc35a616f1179a06c7864089e5b5
Diffstat (limited to 'lib/sqlalchemy/orm/collections.py')
-rw-r--r--lib/sqlalchemy/orm/collections.py208
1 files changed, 146 insertions, 62 deletions
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index 717f1d0d6..da0da0fcf 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -4,7 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+# mypy: allow-untyped-defs, allow-untyped-calls
"""Support for collections of mapped entities.
@@ -109,17 +109,34 @@ import operator
import threading
import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import weakref
from .. import exc as sa_exc
from .. import util
from ..util.compat import inspect_getfullargspec
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from .attributes import CollectionAttributeImpl
from .mapped_collection import attribute_mapped_collection
from .mapped_collection import column_mapped_collection
from .mapped_collection import mapped_collection
from .mapped_collection import MappedCollection # noqa: F401
+ from .state import InstanceState
+
__all__ = [
"collection",
@@ -132,6 +149,28 @@ __all__ = [
__instrumentation_mutex = threading.Lock()
+_CollectionFactoryType = Callable[[], "_AdaptedCollectionProtocol"]
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
+_COL = TypeVar("_COL", bound="Collection[Any]")
+_FN = TypeVar("_FN", bound="Callable[..., Any]")
+
+
+class _CollectionConverterProtocol(Protocol):
+ def __call__(self, collection: _COL) -> _COL:
+ ...
+
+
+class _AdaptedCollectionProtocol(Protocol):
+ _sa_adapter: CollectionAdapter
+ _sa_appender: Callable[..., Any]
+ _sa_remover: Callable[..., Any]
+ _sa_iterator: Callable[..., Iterable[Any]]
+ _sa_converter: _CollectionConverterProtocol
+
+
class collection:
"""Decorators for entity collection classes.
@@ -396,8 +435,13 @@ class collection:
return decorator
-collection_adapter = operator.attrgetter("_sa_adapter")
-"""Fetch the :class:`.CollectionAdapter` for a collection."""
+if TYPE_CHECKING:
+
+ def collection_adapter(collection: Collection[Any]) -> CollectionAdapter:
+ """Fetch the :class:`.CollectionAdapter` for a collection."""
+
+else:
+ collection_adapter = operator.attrgetter("_sa_adapter")
class CollectionAdapter:
@@ -423,10 +467,33 @@ class CollectionAdapter:
"empty",
)
- def __init__(self, attr, owner_state, data):
+ attr: CollectionAttributeImpl
+ _key: str
+
+ # this is actually a weakref; see note in constructor
+ _data: Callable[..., _AdaptedCollectionProtocol]
+
+ owner_state: InstanceState[Any]
+ _converter: _CollectionConverterProtocol
+ invalidated: bool
+ empty: bool
+
+ def __init__(
+ self,
+ attr: CollectionAttributeImpl,
+ owner_state: InstanceState[Any],
+ data: _AdaptedCollectionProtocol,
+ ):
self.attr = attr
self._key = attr.key
- self._data = weakref.ref(data)
+
+ # this weakref stays referenced throughout the lifespan of
+ # CollectionAdapter. so while the weakref can return None, this
+ # is realistically only during garbage collection of this object, so
+ # we type this as a callable that returns _AdaptedCollectionProtocol
+ # in all cases.
+ self._data = weakref.ref(data) # type: ignore
+
self.owner_state = owner_state
data._sa_adapter = self
self._converter = data._sa_converter
@@ -437,7 +504,7 @@ class CollectionAdapter:
util.warn("This collection has been invalidated.")
@property
- def data(self):
+ def data(self) -> _AdaptedCollectionProtocol:
"The entity collection being adapted."
return self._data()
@@ -634,7 +701,10 @@ class CollectionAdapter:
def __setstate__(self, d):
self._key = d["key"]
self.owner_state = d["owner_state"]
- self._data = weakref.ref(d["data"])
+
+ # see note in constructor regarding this type: ignore
+ self._data = weakref.ref(d["data"]) # type: ignore
+
self._converter = d["data"]._sa_converter
d["data"]._sa_adapter = self
self.invalidated = d["invalidated"]
@@ -682,7 +752,9 @@ def bulk_replace(values, existing_adapter, new_adapter, initiator=None):
existing_adapter.fire_remove_event(member, initiator=initiator)
-def prepare_instrumentation(factory):
+def prepare_instrumentation(
+ factory: Union[Type[Collection[Any]], _CollectionFactoryType],
+) -> _CollectionFactoryType:
"""Prepare a callable for future use as a collection class factory.
Given a collection class factory (either a type or no-arg callable),
@@ -693,18 +765,30 @@ def prepare_instrumentation(factory):
into the run-time behavior of collection_class=InstrumentedList.
"""
+
+ impl_factory: _CollectionFactoryType
+
# Convert a builtin to 'Instrumented*'
if factory in __canned_instrumentation:
- factory = __canned_instrumentation[factory]
+ impl_factory = __canned_instrumentation[factory]
+ else:
+ impl_factory = cast(_CollectionFactoryType, factory)
+
+ cls: Union[_CollectionFactoryType, Type[Collection[Any]]]
# Create a specimen
- cls = type(factory())
+ cls = type(impl_factory())
# Did factory callable return a builtin?
if cls in __canned_instrumentation:
- # Wrap it so that it returns our 'Instrumented*'
- factory = __converting_factory(cls, factory)
- cls = factory()
+
+ # if so, just convert.
+ # in previous major releases, this codepath wasn't working and was
+ # not covered by tests. prior to that it supplied a "wrapper"
+ # function that would return the class, though the rationale for this
+ # case is not known
+ impl_factory = __canned_instrumentation[cls]
+ cls = type(impl_factory())
# Instrument the class if needed.
if __instrumentation_mutex.acquire():
@@ -714,26 +798,7 @@ def prepare_instrumentation(factory):
finally:
__instrumentation_mutex.release()
- return factory
-
-
-def __converting_factory(specimen_cls, original_factory):
- """Return a wrapper that converts a "canned" collection like
- set, dict, list into the Instrumented* version.
-
- """
-
- instrumented_cls = __canned_instrumentation[specimen_cls]
-
- def wrapper():
- collection = original_factory()
- return instrumented_cls(collection)
-
- # often flawed but better than nothing
- wrapper.__name__ = "%sWrapper" % original_factory.__name__
- wrapper.__doc__ = original_factory.__doc__
-
- return wrapper
+ return impl_factory
def _instrument_class(cls):
@@ -763,8 +828,8 @@ def _locate_roles_and_methods(cls):
"""
- roles = {}
- methods = {}
+ roles: Dict[str, str] = {}
+ methods: Dict[str, Tuple[Optional[str], Optional[int], Optional[str]]] = {}
for supercls in cls.__mro__:
for name, method in vars(supercls).items():
@@ -784,7 +849,9 @@ def _locate_roles_and_methods(cls):
# transfer instrumentation requests from decorated function
# to the combined queue
- before, after = None, None
+ before: Optional[Tuple[str, int]] = None
+ after: Optional[str] = None
+
if hasattr(method, "_sa_instrument_before"):
op, argument = method._sa_instrument_before
assert op in ("fire_append_event", "fire_remove_event")
@@ -809,6 +876,7 @@ def _setup_canned_roles(cls, roles, methods):
"""
collection_type = util.duck_type_collection(cls)
if collection_type in __interfaces:
+ assert collection_type is not None
canned_roles, decorators = __interfaces[collection_type]
for role, name in canned_roles.items():
roles.setdefault(role, name)
@@ -934,9 +1002,9 @@ def _instrument_membership_mutator(method, before, argument, after):
getattr(executor, after)(res, initiator)
return res
- wrapper._sa_instrumented = True
+ wrapper._sa_instrumented = True # type: ignore[attr-defined]
if hasattr(method, "_sa_instrument_role"):
- wrapper._sa_instrument_role = method._sa_instrument_role
+ wrapper._sa_instrument_role = method._sa_instrument_role # type: ignore[attr-defined] # noqa: E501
wrapper.__name__ = method.__name__
wrapper.__doc__ = method.__doc__
return wrapper
@@ -990,7 +1058,7 @@ def __before_pop(collection, _sa_initiator=None):
executor.fire_pre_remove_event(_sa_initiator)
-def _list_decorators():
+def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
"""Tailored instrumentation wrappers for any list-like class."""
def _tidy(fn):
@@ -1131,7 +1199,7 @@ def _list_decorators():
return l
-def _dict_decorators():
+def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
"""Tailored instrumentation wrappers for any dict-like mapping class."""
def _tidy(fn):
@@ -1255,7 +1323,7 @@ def _set_binops_check_loose(self: Any, obj: Any) -> bool:
)
-def _set_decorators():
+def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
"""Tailored instrumentation wrappers for any set-like class."""
def _tidy(fn):
@@ -1420,36 +1488,52 @@ def _set_decorators():
return l
-class InstrumentedList(list):
+class InstrumentedList(List[_T]):
"""An instrumented version of the built-in list."""
-class InstrumentedSet(set):
+class InstrumentedSet(Set[_T]):
"""An instrumented version of the built-in set."""
-class InstrumentedDict(dict):
+class InstrumentedDict(Dict[_KT, _VT]):
"""An instrumented version of the built-in dict."""
-__canned_instrumentation = {
- list: InstrumentedList,
- set: InstrumentedSet,
- dict: InstrumentedDict,
-}
-
-__interfaces = {
- list: (
- {"appender": "append", "remover": "remove", "iterator": "__iter__"},
- _list_decorators(),
- ),
- set: (
- {"appender": "add", "remover": "remove", "iterator": "__iter__"},
- _set_decorators(),
- ),
- # decorators are required for dicts and object collections.
- dict: ({"iterator": "values"}, _dict_decorators()),
-}
+__canned_instrumentation: util.immutabledict[
+ Any, _CollectionFactoryType
+] = util.immutabledict(
+ {
+ list: InstrumentedList,
+ set: InstrumentedSet,
+ dict: InstrumentedDict,
+ }
+)
+
+__interfaces: util.immutabledict[
+ Any,
+ Tuple[
+ Dict[str, str],
+ Dict[str, Callable[..., Any]],
+ ],
+] = util.immutabledict(
+ {
+ list: (
+ {
+ "appender": "append",
+ "remover": "remove",
+ "iterator": "__iter__",
+ },
+ _list_decorators(),
+ ),
+ set: (
+ {"appender": "add", "remover": "remove", "iterator": "__iter__"},
+ _set_decorators(),
+ ),
+ # decorators are required for dicts and object collections.
+ dict: ({"iterator": "values"}, _dict_decorators()),
+ }
+)
def __go(lcls):