diff options
Diffstat (limited to 'lib/sqlalchemy/event/registry.py')
| -rw-r--r-- | lib/sqlalchemy/event/registry.py | 164 |
1 files changed, 128 insertions, 36 deletions
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index d831a332f..e20d3e0b5 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -14,15 +14,60 @@ membership in all those collections can be revoked at once, based on an equivalent :class:`._EventKey`. """ +from __future__ import annotations + import collections import types +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Deque +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union import weakref from .. import exc from .. import util +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from .attr import RefCollection + from .base import dispatcher + +_ListenerFnType = Callable[..., Any] +_ListenerFnKeyType = Union[int, Tuple[int, int]] +_EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType] + + +class _EventTargetType(Protocol): + """represents an event target, that is, something we can listen on + either with that target as a class or as an instance. + + Examples include: Connection, Mapper, Table, Session, + InstrumentedAttribute, Engine, Pool, Dialect. + """ -_key_to_collection = collections.defaultdict(dict) + dispatch: ClassVar[dispatcher[Any]] + + +_ET = TypeVar("_ET", bound=_EventTargetType) + +_RefCollectionToListenerType = Dict[ + "weakref.ref[RefCollection[Any]]", + "weakref.ref[_ListenerFnType]", +] + +_key_to_collection: Dict[ + _EventKeyTupleType, _RefCollectionToListenerType +] = collections.defaultdict(dict) """ Given an original listen() argument, can locate all listener collections and the listener fn contained @@ -34,7 +79,14 @@ listener collections and the listener fn contained } """ -_collection_to_key = collections.defaultdict(dict) +_ListenerToEventKeyType = Dict[ + "weakref.ref[_ListenerFnType]", + _EventKeyTupleType, +] +_collection_to_key: Dict[ + weakref.ref[RefCollection[Any]], + _ListenerToEventKeyType, +] = collections.defaultdict(dict) """ Given a _ListenerCollection or _ClsLevelListener, can locate all the original listen() arguments and the listener fn contained @@ -47,10 +99,13 @@ ref(listenercollection) -> { """ -def _collection_gced(ref): +def _collection_gced(ref: weakref.ref[Any]) -> None: # defaultdict, so can't get a KeyError if not _collection_to_key or ref not in _collection_to_key: return + + ref = cast("weakref.ref[RefCollection[_EventTargetType]]", ref) + listener_to_key = _collection_to_key.pop(ref) for key in listener_to_key.values(): if key in _key_to_collection: @@ -61,7 +116,9 @@ def _collection_gced(ref): _key_to_collection.pop(key) -def _stored_in_collection(event_key, owner): +def _stored_in_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> bool: key = event_key._key dispatch_reg = _key_to_collection[key] @@ -80,7 +137,9 @@ def _stored_in_collection(event_key, owner): return True -def _removed_from_collection(event_key, owner): +def _removed_from_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> None: key = event_key._key dispatch_reg = _key_to_collection[key] @@ -97,15 +156,19 @@ def _removed_from_collection(event_key, owner): listener_to_key.pop(listen_ref) -def _stored_in_collection_multi(newowner, oldowner, elements): +def _stored_in_collection_multi( + newowner: RefCollection[_ET], + oldowner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: if not elements: return - oldowner = oldowner.ref - newowner = newowner.ref + oldowner_ref = oldowner.ref + newowner_ref = newowner.ref - old_listener_to_key = _collection_to_key[oldowner] - new_listener_to_key = _collection_to_key[newowner] + old_listener_to_key = _collection_to_key[oldowner_ref] + new_listener_to_key = _collection_to_key[newowner_ref] for listen_fn in elements: listen_ref = weakref.ref(listen_fn) @@ -121,31 +184,34 @@ def _stored_in_collection_multi(newowner, oldowner, elements): except KeyError: continue - if newowner in dispatch_reg: - assert dispatch_reg[newowner] == listen_ref + if newowner_ref in dispatch_reg: + assert dispatch_reg[newowner_ref] == listen_ref else: - dispatch_reg[newowner] = listen_ref + dispatch_reg[newowner_ref] = listen_ref new_listener_to_key[listen_ref] = key -def _clear(owner, elements): +def _clear( + owner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: if not elements: return - owner = owner.ref - listener_to_key = _collection_to_key[owner] + owner_ref = owner.ref + listener_to_key = _collection_to_key[owner_ref] for listen_fn in elements: listen_ref = weakref.ref(listen_fn) key = listener_to_key[listen_ref] dispatch_reg = _key_to_collection[key] - dispatch_reg.pop(owner, None) + dispatch_reg.pop(owner_ref, None) if not dispatch_reg: del _key_to_collection[key] -class _EventKey: +class _EventKey(Generic[_ET]): """Represent :func:`.listen` arguments.""" __slots__ = ( @@ -157,10 +223,24 @@ class _EventKey: "dispatch_target", ) - def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None): + target: _ET + identifier: str + fn: _ListenerFnType + fn_key: _ListenerFnKeyType + dispatch_target: Any + _fn_wrap: Optional[_ListenerFnType] + + def __init__( + self, + target: _ET, + identifier: str, + fn: _ListenerFnType, + dispatch_target: Any, + _fn_wrap: Optional[_ListenerFnType] = None, + ): self.target = target self.identifier = identifier - self.fn = fn + self.fn = fn # type: ignore[assignment] if isinstance(fn, types.MethodType): self.fn_key = id(fn.__func__), id(fn.__self__) else: @@ -169,10 +249,10 @@ class _EventKey: self.dispatch_target = dispatch_target @property - def _key(self): + def _key(self) -> _EventKeyTupleType: return (id(self.target), self.identifier, self.fn_key) - def with_wrapper(self, fn_wrap): + def with_wrapper(self, fn_wrap: _ListenerFnType) -> _EventKey[_ET]: if fn_wrap is self._listen_fn: return self else: @@ -184,7 +264,7 @@ class _EventKey: _fn_wrap=fn_wrap, ) - def with_dispatch_target(self, dispatch_target): + def with_dispatch_target(self, dispatch_target: Any) -> _EventKey[_ET]: if dispatch_target is self.dispatch_target: return self else: @@ -196,7 +276,7 @@ class _EventKey: _fn_wrap=self.fn_wrap, ) - def listen(self, *args, **kw): + def listen(self, *args: Any, **kw: Any) -> None: once = kw.pop("once", False) once_unless_exception = kw.pop("_once_unless_exception", False) named = kw.pop("named", False) @@ -228,7 +308,7 @@ class _EventKey: else: self.dispatch_target.dispatch._listen(self, *args, **kw) - def remove(self): + def remove(self) -> None: key = self._key if key not in _key_to_collection: @@ -245,18 +325,18 @@ class _EventKey: if collection is not None and listener_fn is not None: collection.remove(self.with_wrapper(listener_fn)) - def contains(self): + def contains(self) -> bool: """Return True if this event key is registered to listen.""" return self._key in _key_to_collection def base_listen( self, - propagate=False, - insert=False, - named=False, - retval=None, - asyncio=False, - ): + propagate: bool = False, + insert: bool = False, + named: bool = False, + retval: Optional[bool] = None, + asyncio: bool = False, + ) -> None: target, identifier = self.dispatch_target, self.identifier @@ -272,21 +352,33 @@ class _EventKey: for_modify.append(self, propagate) @property - def _listen_fn(self): + def _listen_fn(self) -> _ListenerFnType: return self.fn_wrap or self.fn - def append_to_list(self, owner, list_): + def append_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: if _stored_in_collection(self, owner): list_.append(self._listen_fn) return True else: return False - def remove_from_list(self, owner, list_): + def remove_from_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> None: _removed_from_collection(self, owner) list_.remove(self._listen_fn) - def prepend_to_list(self, owner, list_): + def prepend_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: if _stored_in_collection(self, owner): list_.appendleft(self._listen_fn) return True |
