summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/event/registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/event/registry.py')
-rw-r--r--lib/sqlalchemy/event/registry.py164
1 files changed, 128 insertions, 36 deletions
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
index d831a332f..e20d3e0b5 100644
--- a/lib/sqlalchemy/event/registry.py
+++ b/lib/sqlalchemy/event/registry.py
@@ -14,15 +14,60 @@ membership in all those collections can be revoked at once, based on
an equivalent :class:`._EventKey`.
"""
+from __future__ import annotations
+
import collections
import types
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import ClassVar
+from typing import Deque
+from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
import weakref
from .. import exc
from .. import util
+from ..util.typing import Protocol
+
+if typing.TYPE_CHECKING:
+ from .attr import RefCollection
+ from .base import dispatcher
+
+_ListenerFnType = Callable[..., Any]
+_ListenerFnKeyType = Union[int, Tuple[int, int]]
+_EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType]
+
+
+class _EventTargetType(Protocol):
+ """represents an event target, that is, something we can listen on
+ either with that target as a class or as an instance.
+
+ Examples include: Connection, Mapper, Table, Session,
+ InstrumentedAttribute, Engine, Pool, Dialect.
+ """
-_key_to_collection = collections.defaultdict(dict)
+ dispatch: ClassVar[dispatcher[Any]]
+
+
+_ET = TypeVar("_ET", bound=_EventTargetType)
+
+_RefCollectionToListenerType = Dict[
+ "weakref.ref[RefCollection[Any]]",
+ "weakref.ref[_ListenerFnType]",
+]
+
+_key_to_collection: Dict[
+ _EventKeyTupleType, _RefCollectionToListenerType
+] = collections.defaultdict(dict)
"""
Given an original listen() argument, can locate all
listener collections and the listener fn contained
@@ -34,7 +79,14 @@ listener collections and the listener fn contained
}
"""
-_collection_to_key = collections.defaultdict(dict)
+_ListenerToEventKeyType = Dict[
+ "weakref.ref[_ListenerFnType]",
+ _EventKeyTupleType,
+]
+_collection_to_key: Dict[
+ weakref.ref[RefCollection[Any]],
+ _ListenerToEventKeyType,
+] = collections.defaultdict(dict)
"""
Given a _ListenerCollection or _ClsLevelListener, can locate
all the original listen() arguments and the listener fn contained
@@ -47,10 +99,13 @@ ref(listenercollection) -> {
"""
-def _collection_gced(ref):
+def _collection_gced(ref: weakref.ref[Any]) -> None:
# defaultdict, so can't get a KeyError
if not _collection_to_key or ref not in _collection_to_key:
return
+
+ ref = cast("weakref.ref[RefCollection[_EventTargetType]]", ref)
+
listener_to_key = _collection_to_key.pop(ref)
for key in listener_to_key.values():
if key in _key_to_collection:
@@ -61,7 +116,9 @@ def _collection_gced(ref):
_key_to_collection.pop(key)
-def _stored_in_collection(event_key, owner):
+def _stored_in_collection(
+ event_key: _EventKey[_ET], owner: RefCollection[_ET]
+) -> bool:
key = event_key._key
dispatch_reg = _key_to_collection[key]
@@ -80,7 +137,9 @@ def _stored_in_collection(event_key, owner):
return True
-def _removed_from_collection(event_key, owner):
+def _removed_from_collection(
+ event_key: _EventKey[_ET], owner: RefCollection[_ET]
+) -> None:
key = event_key._key
dispatch_reg = _key_to_collection[key]
@@ -97,15 +156,19 @@ def _removed_from_collection(event_key, owner):
listener_to_key.pop(listen_ref)
-def _stored_in_collection_multi(newowner, oldowner, elements):
+def _stored_in_collection_multi(
+ newowner: RefCollection[_ET],
+ oldowner: RefCollection[_ET],
+ elements: Iterable[_ListenerFnType],
+) -> None:
if not elements:
return
- oldowner = oldowner.ref
- newowner = newowner.ref
+ oldowner_ref = oldowner.ref
+ newowner_ref = newowner.ref
- old_listener_to_key = _collection_to_key[oldowner]
- new_listener_to_key = _collection_to_key[newowner]
+ old_listener_to_key = _collection_to_key[oldowner_ref]
+ new_listener_to_key = _collection_to_key[newowner_ref]
for listen_fn in elements:
listen_ref = weakref.ref(listen_fn)
@@ -121,31 +184,34 @@ def _stored_in_collection_multi(newowner, oldowner, elements):
except KeyError:
continue
- if newowner in dispatch_reg:
- assert dispatch_reg[newowner] == listen_ref
+ if newowner_ref in dispatch_reg:
+ assert dispatch_reg[newowner_ref] == listen_ref
else:
- dispatch_reg[newowner] = listen_ref
+ dispatch_reg[newowner_ref] = listen_ref
new_listener_to_key[listen_ref] = key
-def _clear(owner, elements):
+def _clear(
+ owner: RefCollection[_ET],
+ elements: Iterable[_ListenerFnType],
+) -> None:
if not elements:
return
- owner = owner.ref
- listener_to_key = _collection_to_key[owner]
+ owner_ref = owner.ref
+ listener_to_key = _collection_to_key[owner_ref]
for listen_fn in elements:
listen_ref = weakref.ref(listen_fn)
key = listener_to_key[listen_ref]
dispatch_reg = _key_to_collection[key]
- dispatch_reg.pop(owner, None)
+ dispatch_reg.pop(owner_ref, None)
if not dispatch_reg:
del _key_to_collection[key]
-class _EventKey:
+class _EventKey(Generic[_ET]):
"""Represent :func:`.listen` arguments."""
__slots__ = (
@@ -157,10 +223,24 @@ class _EventKey:
"dispatch_target",
)
- def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None):
+ target: _ET
+ identifier: str
+ fn: _ListenerFnType
+ fn_key: _ListenerFnKeyType
+ dispatch_target: Any
+ _fn_wrap: Optional[_ListenerFnType]
+
+ def __init__(
+ self,
+ target: _ET,
+ identifier: str,
+ fn: _ListenerFnType,
+ dispatch_target: Any,
+ _fn_wrap: Optional[_ListenerFnType] = None,
+ ):
self.target = target
self.identifier = identifier
- self.fn = fn
+ self.fn = fn # type: ignore[assignment]
if isinstance(fn, types.MethodType):
self.fn_key = id(fn.__func__), id(fn.__self__)
else:
@@ -169,10 +249,10 @@ class _EventKey:
self.dispatch_target = dispatch_target
@property
- def _key(self):
+ def _key(self) -> _EventKeyTupleType:
return (id(self.target), self.identifier, self.fn_key)
- def with_wrapper(self, fn_wrap):
+ def with_wrapper(self, fn_wrap: _ListenerFnType) -> _EventKey[_ET]:
if fn_wrap is self._listen_fn:
return self
else:
@@ -184,7 +264,7 @@ class _EventKey:
_fn_wrap=fn_wrap,
)
- def with_dispatch_target(self, dispatch_target):
+ def with_dispatch_target(self, dispatch_target: Any) -> _EventKey[_ET]:
if dispatch_target is self.dispatch_target:
return self
else:
@@ -196,7 +276,7 @@ class _EventKey:
_fn_wrap=self.fn_wrap,
)
- def listen(self, *args, **kw):
+ def listen(self, *args: Any, **kw: Any) -> None:
once = kw.pop("once", False)
once_unless_exception = kw.pop("_once_unless_exception", False)
named = kw.pop("named", False)
@@ -228,7 +308,7 @@ class _EventKey:
else:
self.dispatch_target.dispatch._listen(self, *args, **kw)
- def remove(self):
+ def remove(self) -> None:
key = self._key
if key not in _key_to_collection:
@@ -245,18 +325,18 @@ class _EventKey:
if collection is not None and listener_fn is not None:
collection.remove(self.with_wrapper(listener_fn))
- def contains(self):
+ def contains(self) -> bool:
"""Return True if this event key is registered to listen."""
return self._key in _key_to_collection
def base_listen(
self,
- propagate=False,
- insert=False,
- named=False,
- retval=None,
- asyncio=False,
- ):
+ propagate: bool = False,
+ insert: bool = False,
+ named: bool = False,
+ retval: Optional[bool] = None,
+ asyncio: bool = False,
+ ) -> None:
target, identifier = self.dispatch_target, self.identifier
@@ -272,21 +352,33 @@ class _EventKey:
for_modify.append(self, propagate)
@property
- def _listen_fn(self):
+ def _listen_fn(self) -> _ListenerFnType:
return self.fn_wrap or self.fn
- def append_to_list(self, owner, list_):
+ def append_to_list(
+ self,
+ owner: RefCollection[_ET],
+ list_: Deque[_ListenerFnType],
+ ) -> bool:
if _stored_in_collection(self, owner):
list_.append(self._listen_fn)
return True
else:
return False
- def remove_from_list(self, owner, list_):
+ def remove_from_list(
+ self,
+ owner: RefCollection[_ET],
+ list_: Deque[_ListenerFnType],
+ ) -> None:
_removed_from_collection(self, owner)
list_.remove(self._listen_fn)
- def prepend_to_list(self, owner, list_):
+ def prepend_to_list(
+ self,
+ owner: RefCollection[_ET],
+ list_: Deque[_ListenerFnType],
+ ) -> bool:
if _stored_in_collection(self, owner):
list_.appendleft(self._listen_fn)
return True