diff options
Diffstat (limited to 'lib/sqlalchemy/event')
| -rw-r--r-- | lib/sqlalchemy/event/__init__.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/event/api.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/event/attr.py | 296 | ||||
| -rw-r--r-- | lib/sqlalchemy/event/base.py | 212 | ||||
| -rw-r--r-- | lib/sqlalchemy/event/legacy.py | 77 | ||||
| -rw-r--r-- | lib/sqlalchemy/event/registry.py | 164 |
6 files changed, 581 insertions, 213 deletions
diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index a89bea894..2d10372ab 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -5,13 +5,15 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -from .api import CANCEL -from .api import contains -from .api import listen -from .api import listens_for -from .api import NO_RETVAL -from .api import remove -from .attr import RefCollection -from .base import dispatcher -from .base import Events +from __future__ import annotations + +from .api import CANCEL as CANCEL +from .api import contains as contains +from .api import listen as listen +from .api import listens_for as listens_for +from .api import NO_RETVAL as NO_RETVAL +from .api import remove as remove +from .attr import RefCollection as RefCollection +from .base import dispatcher as dispatcher +from .base import Events as Events from .legacy import _legacy_signature diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py index d2fd9473c..52f796b19 100644 --- a/lib/sqlalchemy/event/api.py +++ b/lib/sqlalchemy/event/api.py @@ -8,8 +8,15 @@ """Public API functions for the event system. """ +from __future__ import annotations + +from typing import Any +from typing import Callable + from .base import _registrars +from .registry import _ET from .registry import _EventKey +from .registry import _ListenerFnType from .. import exc from .. import util @@ -18,7 +25,9 @@ CANCEL = util.symbol("CANCEL") NO_RETVAL = util.symbol("NO_RETVAL") -def _event_key(target, identifier, fn): +def _event_key( + target: _ET, identifier: str, fn: _ListenerFnType +) -> _EventKey[_ET]: for evt_cls in _registrars[identifier]: tgt = evt_cls._accept_with(target) if tgt is not None: @@ -29,7 +38,9 @@ def _event_key(target, identifier, fn): ) -def listen(target, identifier, fn, *args, **kw): +def listen( + target: Any, identifier: str, fn: Callable[..., Any], *args: Any, **kw: Any +) -> None: """Register a listener function for the given target. The :func:`.listen` function is part of the primary interface for the @@ -113,7 +124,9 @@ def listen(target, identifier, fn, *args, **kw): _event_key(target, identifier, fn).listen(*args, **kw) -def listens_for(target, identifier, *args, **kw): +def listens_for( + target: Any, identifier: str, *args: Any, **kw: Any +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Decorate a function as a listener for the given target + identifier. The :func:`.listens_for` decorator is part of the primary interface for the @@ -154,14 +167,14 @@ def listens_for(target, identifier, *args, **kw): """ - def decorate(fn): + def decorate(fn: Callable[..., Any]) -> Callable[..., Any]: listen(target, identifier, fn, *args, **kw) return fn return decorate -def remove(target, identifier, fn): +def remove(target: Any, identifier: str, fn: Callable[..., Any]) -> None: """Remove an event listener. The arguments here should match exactly those which were sent to @@ -211,7 +224,7 @@ def remove(target, identifier, fn): _event_key(target, identifier, fn).remove() -def contains(target, identifier, fn): +def contains(target: Any, identifier: str, fn: Callable[..., Any]) -> bool: """Return True if the given target/ident/fn is set up to listen.""" return _event_key(target, identifier, fn).contains() diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index a05966222..d1ae7a845 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -28,43 +28,89 @@ as well as support for subclass propagation (e.g. events assigned to ``Pool`` vs. ``QueuePool``) are all implemented here. """ +from __future__ import annotations + import collections from itertools import chain import threading +from types import TracebackType +import typing +from typing import Any +from typing import cast +from typing import Collection +from typing import Deque +from typing import FrozenSet +from typing import Generic +from typing import Iterator +from typing import MutableMapping +from typing import MutableSequence +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union import weakref from . import legacy from . import registry +from .registry import _ET +from .registry import _EventKey +from .registry import _ListenerFnType from .. import exc from .. import util from ..util.concurrency import AsyncAdaptedLock +from ..util.typing import Protocol + +_T = TypeVar("_T", bound=Any) + +if typing.TYPE_CHECKING: + from .base import _Dispatch + from .base import _HasEventsDispatch + from .base import _JoinedDispatcher -class RefCollection(util.MemoizedSlots): +class RefCollection(util.MemoizedSlots, Generic[_ET]): __slots__ = ("ref",) - def _memoized_attr_ref(self): + ref: weakref.ref[RefCollection[_ET]] + + def _memoized_attr_ref(self) -> weakref.ref[RefCollection[_ET]]: return weakref.ref(self, registry._collection_gced) -class _empty_collection: - def append(self, element): +class _empty_collection(Collection[_T]): + def append(self, element: _T) -> None: + pass + + def appendleft(self, element: _T) -> None: pass - def extend(self, other): + def extend(self, other: Sequence[_T]) -> None: pass - def remove(self, element): + def remove(self, element: _T) -> None: pass - def __iter__(self): + def __contains__(self, element: Any) -> bool: + return False + + def __iter__(self) -> Iterator[_T]: return iter([]) - def clear(self): + def clear(self) -> None: pass + def __len__(self) -> int: + return 0 + + +_ListenerFnSequenceType = Union[Deque[_T], _empty_collection[_T]] + -class _ClsLevelDispatch(RefCollection): +class _ClsLevelDispatch(RefCollection[_ET]): """Class-level events on :class:`._Dispatch` classes.""" __slots__ = ( @@ -77,7 +123,20 @@ class _ClsLevelDispatch(RefCollection): "__weakref__", ) - def __init__(self, parent_dispatch_cls, fn): + clsname: str + name: str + arg_names: Sequence[str] + has_kw: bool + legacy_signatures: MutableSequence[legacy._LegacySignatureType] + _clslevel: MutableMapping[ + Type[_ET], _ListenerFnSequenceType[_ListenerFnType] + ] + + def __init__( + self, + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + fn: _ListenerFnType, + ): self.name = fn.__name__ self.clsname = parent_dispatch_cls.__name__ argspec = util.inspect_getfullargspec(fn) @@ -94,7 +153,9 @@ class _ClsLevelDispatch(RefCollection): self._clslevel = weakref.WeakKeyDictionary() - def _adjust_fn_spec(self, fn, named): + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: if named: fn = self._wrap_fn_for_kw(fn) if self.legacy_signatures: @@ -106,15 +167,15 @@ class _ClsLevelDispatch(RefCollection): fn = legacy._wrap_fn_for_legacy(self, fn, argspec) return fn - def _wrap_fn_for_kw(self, fn): - def wrap_kw(*args, **kw): + def _wrap_fn_for_kw(self, fn: _ListenerFnType) -> _ListenerFnType: + def wrap_kw(*args: Any, **kw: Any) -> Any: argdict = dict(zip(self.arg_names, args)) argdict.update(kw) return fn(**argdict) return wrap_kw - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: target = event_key.dispatch_target assert isinstance( target, type @@ -125,6 +186,7 @@ class _ClsLevelDispatch(RefCollection): ) for cls in util.walk_subclasses(target): + cls = cast(Type[_ET], cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: @@ -133,7 +195,7 @@ class _ClsLevelDispatch(RefCollection): self._clslevel[cls].appendleft(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: target = event_key.dispatch_target assert isinstance( target, type @@ -143,6 +205,7 @@ class _ClsLevelDispatch(RefCollection): "Can't assign an event directly to the %s class" % target ) for cls in util.walk_subclasses(target): + cls = cast("Type[_ET]", cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: @@ -151,39 +214,41 @@ class _ClsLevelDispatch(RefCollection): self._clslevel[cls].append(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def _assign_cls_collection(self, target): + def _assign_cls_collection(self, target: Type[_ET]) -> None: if getattr(target, "_sa_propagate_class_events", True): self._clslevel[target] = collections.deque() else: self._clslevel[target] = _empty_collection() - def update_subclass(self, target): + def update_subclass(self, target: Type[_ET]) -> None: if target not in self._clslevel: self._assign_cls_collection(target) clslevel = self._clslevel[target] for cls in target.__mro__[1:]: + cls = cast("Type[_ET]", cls) if cls in self._clslevel: clslevel.extend( [fn for fn in self._clslevel[cls] if fn not in clslevel] ) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: target = event_key.dispatch_target for cls in util.walk_subclasses(target): + cls = cast("Type[_ET]", cls) if cls in self._clslevel: self._clslevel[cls].remove(event_key._listen_fn) registry._removed_from_collection(event_key, self) - def clear(self): + def clear(self) -> None: """Clear all class level listeners""" - to_clear = set() + to_clear: Set[_ListenerFnType] = set() for dispatcher in self._clslevel.values(): to_clear.update(dispatcher) dispatcher.clear() registry._clear(self, to_clear) - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ClsLevelDispatch[_ET]: """Return an event collection which can be modified. For _ClsLevelDispatch at the class level of @@ -193,14 +258,30 @@ class _ClsLevelDispatch(RefCollection): return self -class _InstanceLevelDispatch(RefCollection): +class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]): __slots__ = () - def _adjust_fn_spec(self, fn, named): + parent: _ClsLevelDispatch[_ET] + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: return self.parent._adjust_fn_spec(fn, named) + def __contains__(self, item: Any) -> bool: + raise NotImplementedError() + + def __len__(self) -> int: + raise NotImplementedError() + + def __iter__(self) -> Iterator[_ListenerFnType]: + raise NotImplementedError() + + def __bool__(self) -> bool: + raise NotImplementedError() + -class _EmptyListener(_InstanceLevelDispatch): +class _EmptyListener(_InstanceLevelDispatch[_ET]): """Serves as a proxy interface to the events served by a _ClsLevelDispatch, when there are no instance-level events present. @@ -210,19 +291,22 @@ class _EmptyListener(_InstanceLevelDispatch): """ - propagate = frozenset() - listeners = () - __slots__ = "parent", "parent_listeners", "name" - def __init__(self, parent, target_cls): + propagate: FrozenSet[_ListenerFnType] = frozenset() + listeners: Tuple[()] = () + parent: _ClsLevelDispatch[_ET] + parent_listeners: _ListenerFnSequenceType[_ListenerFnType] + name: str + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) - self.parent = parent # _ClsLevelDispatch + self.parent = parent self.parent_listeners = parent._clslevel[target_cls] self.name = parent.name - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _EmptyListener at the instance level of @@ -231,6 +315,7 @@ class _EmptyListener(_InstanceLevelDispatch): and returns it. """ + assert obj._instance_cls is not None result = _ListenerCollection(self.parent, obj._instance_cls) if getattr(obj, self.name) is self: setattr(obj, self.name, result) @@ -238,41 +323,79 @@ class _EmptyListener(_InstanceLevelDispatch): assert isinstance(getattr(obj, self.name), _JoinedListener) return result - def _needs_modify(self, *args, **kw): + def _needs_modify(self, *args: Any, **kw: Any) -> NoReturn: raise NotImplementedError("need to call for_modify()") - exec_once = ( - exec_once_unless_exception - ) = insert = append = remove = clear = _needs_modify + def exec_once(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def insert(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) - def __call__(self, *args, **kw): + def append(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def remove(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def clear(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def __call__(self, *args: Any, **kw: Any) -> None: """Execute this event.""" for fn in self.parent_listeners: fn(*args, **kw) - def __len__(self): + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners + + def __len__(self) -> int: return len(self.parent_listeners) - def __iter__(self): + def __iter__(self) -> Iterator[_ListenerFnType]: return iter(self.parent_listeners) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.parent_listeners) __nonzero__ = __bool__ -class _CompoundListener(_InstanceLevelDispatch): +class _MutexProtocol(Protocol): + def __enter__(self) -> bool: + ... + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + ... + + +class _CompoundListener(_InstanceLevelDispatch[_ET]): __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once" - def _set_asyncio(self): + _exec_once_mutex: _MutexProtocol + parent_listeners: Collection[_ListenerFnType] + listeners: Collection[_ListenerFnType] + _exec_once: bool + _exec_w_sync_once: bool + + def _set_asyncio(self) -> None: self._exec_once_mutex = AsyncAdaptedLock() - def _memoized_attr__exec_once_mutex(self): + def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol: return threading.Lock() - def _exec_once_impl(self, retry_on_exception, *args, **kw): + def _exec_once_impl( + self, retry_on_exception: bool, *args: Any, **kw: Any + ) -> None: with self._exec_once_mutex: if not self._exec_once: try: @@ -285,14 +408,14 @@ class _CompoundListener(_InstanceLevelDispatch): if not exception or not retry_on_exception: self._exec_once = True - def exec_once(self, *args, **kw): + def exec_once(self, *args: Any, **kw: Any) -> None: """Execute this event, but only if it has not been executed already for this collection.""" if not self._exec_once: self._exec_once_impl(False, *args, **kw) - def exec_once_unless_exception(self, *args, **kw): + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: """Execute this event, but only if it has not been executed already for this collection, or was called by a previous exec_once_unless_exception call and @@ -307,7 +430,7 @@ class _CompoundListener(_InstanceLevelDispatch): if not self._exec_once: self._exec_once_impl(True, *args, **kw) - def _exec_w_sync_on_first_run(self, *args, **kw): + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: """Execute this event, and use a mutex if it has not been executed already for this collection, or was called by a previous _exec_w_sync_on_first_run call and @@ -330,7 +453,7 @@ class _CompoundListener(_InstanceLevelDispatch): else: self(*args, **kw) - def __call__(self, *args, **kw): + def __call__(self, *args: Any, **kw: Any) -> None: """Execute this event.""" for fn in self.parent_listeners: @@ -338,19 +461,22 @@ class _CompoundListener(_InstanceLevelDispatch): for fn in self.listeners: fn(*args, **kw) - def __len__(self): + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners or item in self.listeners + + def __len__(self) -> int: return len(self.parent_listeners) + len(self.listeners) - def __iter__(self): + def __iter__(self) -> Iterator[_ListenerFnType]: return chain(self.parent_listeners, self.listeners) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.listeners or self.parent_listeners) __nonzero__ = __bool__ -class _ListenerCollection(_CompoundListener): +class _ListenerCollection(_CompoundListener[_ET]): """Instance-level attributes on instances of :class:`._Dispatch`. Represents a collection of listeners. @@ -369,7 +495,13 @@ class _ListenerCollection(_CompoundListener): "__weakref__", ) - def __init__(self, parent, target_cls): + parent_listeners: Collection[_ListenerFnType] + parent: _ClsLevelDispatch[_ET] + name: str + listeners: Deque[_ListenerFnType] + propagate: Set[_ListenerFnType] + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) self._exec_once = False @@ -380,7 +512,7 @@ class _ListenerCollection(_CompoundListener): self.listeners = collections.deque() self.propagate = set() - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _ListenerCollection at the instance level of @@ -389,10 +521,11 @@ class _ListenerCollection(_CompoundListener): """ return self - def _update(self, other, only_propagate=True): + def _update( + self, other: _ListenerCollection[_ET], only_propagate: bool = True + ) -> None: """Populate from the listeners in another :class:`_Dispatch` object.""" - existing_listeners = self.listeners existing_listener_set = set(existing_listeners) self.propagate.update(other.propagate) @@ -409,56 +542,75 @@ class _ListenerCollection(_CompoundListener): to_associate = other.propagate.union(other_listeners) registry._stored_in_collection_multi(self, other, to_associate) - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: if event_key.prepend_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: if event_key.append_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: self.listeners.remove(event_key._listen_fn) self.propagate.discard(event_key._listen_fn) registry._removed_from_collection(event_key, self) - def clear(self): + def clear(self) -> None: registry._clear(self, self.listeners) self.propagate.clear() self.listeners.clear() -class _JoinedListener(_CompoundListener): - __slots__ = "parent", "name", "local", "parent_listeners" +class _JoinedListener(_CompoundListener[_ET]): + __slots__ = "parent_dispatch", "name", "local", "parent_listeners" + + parent_dispatch: _Dispatch[_ET] + name: str + local: _InstanceLevelDispatch[_ET] + parent_listeners: Collection[_ListenerFnType] - def __init__(self, parent, name, local): + def __init__( + self, + parent_dispatch: _Dispatch[_ET], + name: str, + local: _EmptyListener[_ET], + ): self._exec_once = False - self.parent = parent + self.parent_dispatch = parent_dispatch self.name = name self.local = local self.parent_listeners = self.local - @property - def listeners(self): - return getattr(self.parent, self.name) - - def _adjust_fn_spec(self, fn, named): + if not typing.TYPE_CHECKING: + # first error, I don't really understand: + # Signature of "listeners" incompatible with + # supertype "_CompoundListener" [override] + # the name / return type are exactly the same + # second error is getattr_isn't typed, the cast() here + # adds too much method overhead + @property + def listeners(self) -> Collection[_ListenerFnType]: + return getattr(self.parent_dispatch, self.name) + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: return self.local._adjust_fn_spec(fn, named) - def for_modify(self, obj): + def for_modify(self, obj: _JoinedDispatcher[_ET]) -> _JoinedListener[_ET]: self.local = self.parent_listeners = self.local.for_modify(obj) return self - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: self.local.insert(event_key, propagate) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: self.local.append(event_key, propagate) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: self.local.remove(event_key) - def clear(self): + def clear(self) -> None: raise NotImplementedError() diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 25d369240..0e0647036 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -15,21 +15,37 @@ at the class level of a particular ``_Dispatch`` class as well as within instances of ``_Dispatch``. """ -from typing import ClassVar +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import MutableMapping from typing import Optional +from typing import overload +from typing import Tuple from typing import Type +from typing import Union import weakref from .attr import _ClsLevelDispatch from .attr import _EmptyListener +from .attr import _InstanceLevelDispatch from .attr import _JoinedListener +from .registry import _ET +from .registry import _EventKey from .. import util -from ..util.typing import Protocol +from ..util.typing import Literal -_registrars = util.defaultdict(list) +_registrars: MutableMapping[ + str, List[Type[_HasEventsDispatch[Any]]] +] = util.defaultdict(list) -def _is_event_name(name): +def _is_event_name(name: str) -> bool: # _sa_event prefix is special to support internal-only event names. # most event names are just plain method names that aren't # underscored. @@ -45,17 +61,17 @@ class _UnpickleDispatch: """ - def __call__(self, _instance_cls): + def __call__(self, _instance_cls: Type[_ET]) -> _Dispatch[_ET]: for cls in _instance_cls.__mro__: if "dispatch" in cls.__dict__: - return cls.__dict__["dispatch"].dispatch._for_class( - _instance_cls - ) + return cast( + "_Dispatch[_ET]", cls.__dict__["dispatch"].dispatch + )._for_class(_instance_cls) else: raise AttributeError("No class with a 'dispatch' member present.") -class _Dispatch: +class _Dispatch(Generic[_ET]): """Mirror the event listening definitions of an Events class with listener collections. @@ -79,20 +95,35 @@ class _Dispatch: # so __dict__ is used in just that case and potentially others. __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners" - _empty_listener_reg = weakref.WeakKeyDictionary() + _empty_listener_reg: MutableMapping[ + Type[_ET], Dict[str, _EmptyListener[_ET]] + ] = weakref.WeakKeyDictionary() + + _empty_listeners: Dict[str, _EmptyListener[_ET]] + + _event_names: List[str] + + _instance_cls: Optional[Type[_ET]] - _events: Type["_HasEventsDispatch"] + _joined_dispatch_cls: Type[_JoinedDispatcher[_ET]] + + _events: Type[_HasEventsDispatch[_ET]] """reference back to the Events class. Bidirectional against _HasEventsDispatch.dispatch """ - def __init__(self, parent, instance_cls=None): + def __init__( + self, + parent: Optional[_Dispatch[_ET]], + instance_cls: Optional[Type[_ET]] = None, + ): self._parent = parent self._instance_cls = instance_cls if instance_cls: + assert parent is not None try: self._empty_listeners = self._empty_listener_reg[instance_cls] except KeyError: @@ -105,7 +136,7 @@ class _Dispatch: else: self._empty_listeners = {} - def __getattr__(self, name): + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: # Assign EmptyListeners as attributes on demand # to reduce startup time for new dispatch objects. try: @@ -117,24 +148,23 @@ class _Dispatch: return ls @property - def _event_descriptors(self): + def _event_descriptors(self) -> Iterator[_ClsLevelDispatch[_ET]]: for k in self._event_names: # Yield _ClsLevelDispatch related # to relevant event name. yield getattr(self, k) - @property - def _listen(self): - return self._events._listen + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self._events._listen(event_key, **kw) - def _for_class(self, instance_cls): + def _for_class(self, instance_cls: Type[_ET]) -> _Dispatch[_ET]: return self.__class__(self, instance_cls) - def _for_instance(self, instance): + def _for_instance(self, instance: _ET) -> _Dispatch[_ET]: instance_cls = instance.__class__ return self._for_class(instance_cls) - def _join(self, other): + def _join(self, other: _Dispatch[_ET]) -> _JoinedDispatcher[_ET]: """Create a 'join' of this :class:`._Dispatch` and another. This new dispatcher will dispatch events to both @@ -147,14 +177,15 @@ class _Dispatch: (_JoinedDispatcher,), {"__slots__": self._event_names}, ) - self.__class__._joined_dispatch_cls = cls return self._joined_dispatch_cls(self, other) - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return _UnpickleDispatch(), (self._instance_cls,) - def _update(self, other, only_propagate=True): + def _update( + self, other: _Dispatch[_ET], only_propagate: bool = True + ) -> None: """Populate from the listeners in another :class:`_Dispatch` object.""" for ls in other._event_descriptors: @@ -164,32 +195,23 @@ class _Dispatch: ls, only_propagate=only_propagate ) - def _clear(self): + def _clear(self) -> None: for ls in self._event_descriptors: ls.for_modify(self).clear() -def _remove_dispatcher(cls): +def _remove_dispatcher(cls: Type[_HasEventsDispatch[_ET]]) -> None: for k in cls.dispatch._event_names: _registrars[k].remove(cls) if not _registrars[k]: del _registrars[k] -class _HasEventsDispatchProto(Protocol): - """protocol for non-event classes that will also receive the 'dispatch' - attribute in the form of a descriptor. - - """ - - dispatch: ClassVar["dispatcher"] - - -class _HasEventsDispatch: - _dispatch_target: Optional[Type[_HasEventsDispatchProto]] +class _HasEventsDispatch(Generic[_ET]): + _dispatch_target: Optional[Type[_ET]] """class which will receive the .dispatch collection""" - dispatch: _Dispatch + dispatch: _Dispatch[_ET] """reference back to the _Dispatch class. Bidirectional against _Dispatch._events @@ -202,19 +224,41 @@ class _HasEventsDispatch: cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__) + @classmethod + def _accept_with( + cls, target: Union[_ET, Type[_ET]] + ) -> Optional[Union[_ET, Type[_ET]]]: + raise NotImplementedError() + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: + raise NotImplementedError() + @staticmethod - def _set_dispatch(cls, dispatch_cls): + def _set_dispatch( + klass: Type[_HasEventsDispatch[_ET]], + dispatch_cls: Type[_Dispatch[_ET]], + ) -> _Dispatch[_ET]: # This allows an Events subclass to define additional utility # methods made available to the target via # "self.dispatch._events.<utilitymethod>" # @staticmethod to allow easy "super" calls while in a metaclass # constructor. - cls.dispatch = dispatch_cls(None) - dispatch_cls._events = cls - return cls.dispatch + klass.dispatch = dispatch_cls(None) + dispatch_cls._events = klass + return klass.dispatch @classmethod - def _create_dispatcher_class(cls, classname, bases, dict_): + def _create_dispatcher_class( + cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any] + ) -> None: """Create a :class:`._Dispatch` class corresponding to an :class:`.Events` class.""" @@ -227,14 +271,16 @@ class _HasEventsDispatch: dispatch_base = _Dispatch event_names = [k for k in dict_ if _is_event_name(k)] - dispatch_cls = type( - "%sDispatch" % classname, - (dispatch_base,), - {"__slots__": event_names}, + dispatch_cls = cast( + "Type[_Dispatch[_ET]]", + type( + "%sDispatch" % classname, + (dispatch_base,), # type: ignore + {"__slots__": event_names}, + ), ) dispatch_cls._event_names = event_names - dispatch_inst = cls._set_dispatch(cls, dispatch_cls) for k in dispatch_cls._event_names: setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) @@ -251,23 +297,28 @@ class _HasEventsDispatch: assert dispatch_target_cls is not None if ( hasattr(dispatch_target_cls, "__slots__") - and "_slots_dispatch" in dispatch_target_cls.__slots__ + and "_slots_dispatch" in dispatch_target_cls.__slots__ # type: ignore # noqa E501 ): dispatch_target_cls.dispatch = slots_dispatcher(cls) else: dispatch_target_cls.dispatch = dispatcher(cls) -class Events(_HasEventsDispatch): +class Events(_HasEventsDispatch[_ET]): """Define event listening functions for a particular target type.""" @classmethod - def _accept_with(cls, target): - def dispatch_is(*types): + def _accept_with( + cls, target: Union[_ET, Type[_ET]] + ) -> Optional[Union[_ET, Type[_ET]]]: + def dispatch_is(*types: Type[Any]) -> bool: return all(isinstance(target.dispatch, t) for t in types) - def dispatch_parent_is(t): - return isinstance(target.dispatch.parent, t) + def dispatch_parent_is(t: Type[Any]) -> bool: + + return isinstance( + cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t + ) # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. @@ -282,39 +333,45 @@ class Events(_HasEventsDispatch): ): return target + return None + @classmethod def _listen( cls, - event_key, - propagate=False, - insert=False, - named=False, - asyncio=False, - ): + event_key: _EventKey[_ET], + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: event_key.base_listen( propagate=propagate, insert=insert, named=named, asyncio=asyncio ) @classmethod - def _remove(cls, event_key): + def _remove(cls, event_key: _EventKey[_ET]) -> None: event_key.remove() @classmethod - def _clear(cls): + def _clear(cls) -> None: cls.dispatch._clear() -class _JoinedDispatcher: +class _JoinedDispatcher(Generic[_ET]): """Represent a connection between two _Dispatch objects.""" __slots__ = "local", "parent", "_instance_cls" - def __init__(self, local, parent): + local: _Dispatch[_ET] + parent: _Dispatch[_ET] + _instance_cls: Optional[Type[_ET]] + + def __init__(self, local: _Dispatch[_ET], parent: _Dispatch[_ET]): self.local = local self.parent = parent self._instance_cls = self.local._instance_cls - def __getattr__(self, name): + def __getattr__(self, name: str) -> _JoinedListener[_ET]: # Assign _JoinedListeners as attributes on demand # to reduce startup time for new dispatch objects. ls = getattr(self.local, name) @@ -322,16 +379,15 @@ class _JoinedDispatcher: setattr(self, ls.name, jl) return jl - @property - def _listen(self): - return self.parent._listen + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self.parent._listen(event_key, **kw) @property - def _events(self): + def _events(self) -> Type[_HasEventsDispatch[_ET]]: return self.parent._events -class dispatcher: +class dispatcher(Generic[_ET]): """Descriptor used by target classes to deliver the _Dispatch class at the class level and produce new _Dispatch instances for target @@ -339,11 +395,21 @@ class dispatcher: """ - def __init__(self, events): + def __init__(self, events: Type[_HasEventsDispatch[_ET]]): self.dispatch = events.dispatch self.events = events - def __get__(self, obj, cls): + @overload + def __get__( + self, obj: Literal[None], cls: Type[Any] + ) -> Type[_HasEventsDispatch[_ET]]: + ... + + @overload + def __get__(self, obj: Any, cls: Type[Any]) -> _HasEventsDispatch[_ET]: + ... + + def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: return self.dispatch @@ -358,8 +424,8 @@ class dispatcher: return disp -class slots_dispatcher(dispatcher): - def __get__(self, obj, cls): +class slots_dispatcher(dispatcher[_ET]): + def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: return self.dispatch diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 053b47eaa..75e5be7fe 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -9,11 +9,34 @@ generation of deprecation notes and docstrings. """ - +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type + +from .registry import _ET +from .registry import _ListenerFnType from .. import util +from ..util.compat import FullArgSpec + +if typing.TYPE_CHECKING: + from .attr import _ClsLevelDispatch + from .base import _HasEventsDispatch + + +_LegacySignatureType = Tuple[str, List[str], Optional[Callable[..., Any]]] -def _legacy_signature(since, argnames, converter=None): +def _legacy_signature( + since: str, + argnames: List[str], + converter: Optional[Callable[..., Any]] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """legacy sig decorator @@ -25,16 +48,20 @@ def _legacy_signature(since, argnames, converter=None): """ - def leg(fn): + def leg(fn: Callable[..., Any]) -> Callable[..., Any]: if not hasattr(fn, "_legacy_signatures"): - fn._legacy_signatures = [] - fn._legacy_signatures.append((since, argnames, converter)) + fn._legacy_signatures = [] # type: ignore[attr-defined] + fn._legacy_signatures.append((since, argnames, converter)) # type: ignore[attr-defined] # noqa E501 return fn return leg -def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): +def _wrap_fn_for_legacy( + dispatch_collection: "_ClsLevelDispatch[_ET]", + fn: _ListenerFnType, + argspec: FullArgSpec, +) -> _ListenerFnType: for since, argnames, conv in dispatch_collection.legacy_signatures: if argnames[-1] == "**kw": has_kw = True @@ -64,34 +91,39 @@ def _wrap_fn_for_legacy(dispatch_collection, fn, argspec): ) ) - if conv: + if conv is not None: assert not has_kw - def wrap_leg(*args): + def wrap_leg(*args: Any, **kw: Any) -> Any: util.warn_deprecated(warning_txt, version=since) + assert conv is not None return fn(*conv(*args)) else: - def wrap_leg(*args, **kw): + def wrap_leg(*args: Any, **kw: Any) -> Any: util.warn_deprecated(warning_txt, version=since) argdict = dict(zip(dispatch_collection.arg_names, args)) - args = [argdict[name] for name in argnames] + args_from_dict = [argdict[name] for name in argnames] if has_kw: - return fn(*args, **kw) + return fn(*args_from_dict, **kw) else: - return fn(*args) + return fn(*args_from_dict) return wrap_leg else: return fn -def _indent(text, indent): +def _indent(text: str, indent: str) -> str: return "\n".join(indent + line for line in text.split("\n")) -def _standard_listen_example(dispatch_collection, sample_target, fn): +def _standard_listen_example( + dispatch_collection: "_ClsLevelDispatch[_ET]", + sample_target: Any, + fn: _ListenerFnType, +) -> str: example_kw_arg = _indent( "\n".join( "%(arg)s = kw['%(arg)s']" % {"arg": arg} @@ -128,7 +160,11 @@ def _standard_listen_example(dispatch_collection, sample_target, fn): return text -def _legacy_listen_examples(dispatch_collection, sample_target, fn): +def _legacy_listen_examples( + dispatch_collection: "_ClsLevelDispatch[_ET]", + sample_target: str, + fn: _ListenerFnType, +) -> str: text = "" for since, args, conv in dispatch_collection.legacy_signatures: text += ( @@ -152,7 +188,10 @@ def _legacy_listen_examples(dispatch_collection, sample_target, fn): return text -def _version_signature_changes(parent_dispatch_cls, dispatch_collection): +def _version_signature_changes( + parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], + dispatch_collection: "_ClsLevelDispatch[_ET]", +) -> str: since, args, conv = dispatch_collection.legacy_signatures[0] return ( "\n.. deprecated:: %(since)s\n" @@ -171,7 +210,11 @@ def _version_signature_changes(parent_dispatch_cls, dispatch_collection): ) -def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn): +def _augment_fn_docs( + dispatch_collection: "_ClsLevelDispatch[_ET]", + parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], + fn: _ListenerFnType, +) -> str: header = ( ".. container:: event_signatures\n\n" " Example argument forms::\n" diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index d831a332f..e20d3e0b5 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -14,15 +14,60 @@ membership in all those collections can be revoked at once, based on an equivalent :class:`._EventKey`. """ +from __future__ import annotations + import collections import types +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Deque +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union import weakref from .. import exc from .. import util +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from .attr import RefCollection + from .base import dispatcher + +_ListenerFnType = Callable[..., Any] +_ListenerFnKeyType = Union[int, Tuple[int, int]] +_EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType] + + +class _EventTargetType(Protocol): + """represents an event target, that is, something we can listen on + either with that target as a class or as an instance. + + Examples include: Connection, Mapper, Table, Session, + InstrumentedAttribute, Engine, Pool, Dialect. + """ -_key_to_collection = collections.defaultdict(dict) + dispatch: ClassVar[dispatcher[Any]] + + +_ET = TypeVar("_ET", bound=_EventTargetType) + +_RefCollectionToListenerType = Dict[ + "weakref.ref[RefCollection[Any]]", + "weakref.ref[_ListenerFnType]", +] + +_key_to_collection: Dict[ + _EventKeyTupleType, _RefCollectionToListenerType +] = collections.defaultdict(dict) """ Given an original listen() argument, can locate all listener collections and the listener fn contained @@ -34,7 +79,14 @@ listener collections and the listener fn contained } """ -_collection_to_key = collections.defaultdict(dict) +_ListenerToEventKeyType = Dict[ + "weakref.ref[_ListenerFnType]", + _EventKeyTupleType, +] +_collection_to_key: Dict[ + weakref.ref[RefCollection[Any]], + _ListenerToEventKeyType, +] = collections.defaultdict(dict) """ Given a _ListenerCollection or _ClsLevelListener, can locate all the original listen() arguments and the listener fn contained @@ -47,10 +99,13 @@ ref(listenercollection) -> { """ -def _collection_gced(ref): +def _collection_gced(ref: weakref.ref[Any]) -> None: # defaultdict, so can't get a KeyError if not _collection_to_key or ref not in _collection_to_key: return + + ref = cast("weakref.ref[RefCollection[_EventTargetType]]", ref) + listener_to_key = _collection_to_key.pop(ref) for key in listener_to_key.values(): if key in _key_to_collection: @@ -61,7 +116,9 @@ def _collection_gced(ref): _key_to_collection.pop(key) -def _stored_in_collection(event_key, owner): +def _stored_in_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> bool: key = event_key._key dispatch_reg = _key_to_collection[key] @@ -80,7 +137,9 @@ def _stored_in_collection(event_key, owner): return True -def _removed_from_collection(event_key, owner): +def _removed_from_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> None: key = event_key._key dispatch_reg = _key_to_collection[key] @@ -97,15 +156,19 @@ def _removed_from_collection(event_key, owner): listener_to_key.pop(listen_ref) -def _stored_in_collection_multi(newowner, oldowner, elements): +def _stored_in_collection_multi( + newowner: RefCollection[_ET], + oldowner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: if not elements: return - oldowner = oldowner.ref - newowner = newowner.ref + oldowner_ref = oldowner.ref + newowner_ref = newowner.ref - old_listener_to_key = _collection_to_key[oldowner] - new_listener_to_key = _collection_to_key[newowner] + old_listener_to_key = _collection_to_key[oldowner_ref] + new_listener_to_key = _collection_to_key[newowner_ref] for listen_fn in elements: listen_ref = weakref.ref(listen_fn) @@ -121,31 +184,34 @@ def _stored_in_collection_multi(newowner, oldowner, elements): except KeyError: continue - if newowner in dispatch_reg: - assert dispatch_reg[newowner] == listen_ref + if newowner_ref in dispatch_reg: + assert dispatch_reg[newowner_ref] == listen_ref else: - dispatch_reg[newowner] = listen_ref + dispatch_reg[newowner_ref] = listen_ref new_listener_to_key[listen_ref] = key -def _clear(owner, elements): +def _clear( + owner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: if not elements: return - owner = owner.ref - listener_to_key = _collection_to_key[owner] + owner_ref = owner.ref + listener_to_key = _collection_to_key[owner_ref] for listen_fn in elements: listen_ref = weakref.ref(listen_fn) key = listener_to_key[listen_ref] dispatch_reg = _key_to_collection[key] - dispatch_reg.pop(owner, None) + dispatch_reg.pop(owner_ref, None) if not dispatch_reg: del _key_to_collection[key] -class _EventKey: +class _EventKey(Generic[_ET]): """Represent :func:`.listen` arguments.""" __slots__ = ( @@ -157,10 +223,24 @@ class _EventKey: "dispatch_target", ) - def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None): + target: _ET + identifier: str + fn: _ListenerFnType + fn_key: _ListenerFnKeyType + dispatch_target: Any + _fn_wrap: Optional[_ListenerFnType] + + def __init__( + self, + target: _ET, + identifier: str, + fn: _ListenerFnType, + dispatch_target: Any, + _fn_wrap: Optional[_ListenerFnType] = None, + ): self.target = target self.identifier = identifier - self.fn = fn + self.fn = fn # type: ignore[assignment] if isinstance(fn, types.MethodType): self.fn_key = id(fn.__func__), id(fn.__self__) else: @@ -169,10 +249,10 @@ class _EventKey: self.dispatch_target = dispatch_target @property - def _key(self): + def _key(self) -> _EventKeyTupleType: return (id(self.target), self.identifier, self.fn_key) - def with_wrapper(self, fn_wrap): + def with_wrapper(self, fn_wrap: _ListenerFnType) -> _EventKey[_ET]: if fn_wrap is self._listen_fn: return self else: @@ -184,7 +264,7 @@ class _EventKey: _fn_wrap=fn_wrap, ) - def with_dispatch_target(self, dispatch_target): + def with_dispatch_target(self, dispatch_target: Any) -> _EventKey[_ET]: if dispatch_target is self.dispatch_target: return self else: @@ -196,7 +276,7 @@ class _EventKey: _fn_wrap=self.fn_wrap, ) - def listen(self, *args, **kw): + def listen(self, *args: Any, **kw: Any) -> None: once = kw.pop("once", False) once_unless_exception = kw.pop("_once_unless_exception", False) named = kw.pop("named", False) @@ -228,7 +308,7 @@ class _EventKey: else: self.dispatch_target.dispatch._listen(self, *args, **kw) - def remove(self): + def remove(self) -> None: key = self._key if key not in _key_to_collection: @@ -245,18 +325,18 @@ class _EventKey: if collection is not None and listener_fn is not None: collection.remove(self.with_wrapper(listener_fn)) - def contains(self): + def contains(self) -> bool: """Return True if this event key is registered to listen.""" return self._key in _key_to_collection def base_listen( self, - propagate=False, - insert=False, - named=False, - retval=None, - asyncio=False, - ): + propagate: bool = False, + insert: bool = False, + named: bool = False, + retval: Optional[bool] = None, + asyncio: bool = False, + ) -> None: target, identifier = self.dispatch_target, self.identifier @@ -272,21 +352,33 @@ class _EventKey: for_modify.append(self, propagate) @property - def _listen_fn(self): + def _listen_fn(self) -> _ListenerFnType: return self.fn_wrap or self.fn - def append_to_list(self, owner, list_): + def append_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: if _stored_in_collection(self, owner): list_.append(self._listen_fn) return True else: return False - def remove_from_list(self, owner, list_): + def remove_from_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> None: _removed_from_collection(self, owner) list_.remove(self._listen_fn) - def prepend_to_list(self, owner, list_): + def prepend_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: if _stored_in_collection(self, owner): list_.appendleft(self._listen_fn) return True |
