diff options
Diffstat (limited to 'lib/sqlalchemy/event/base.py')
| -rw-r--r-- | lib/sqlalchemy/event/base.py | 212 |
1 files changed, 139 insertions, 73 deletions
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 25d369240..0e0647036 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -15,21 +15,37 @@ at the class level of a particular ``_Dispatch`` class as well as within instances of ``_Dispatch``. """ -from typing import ClassVar +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import MutableMapping from typing import Optional +from typing import overload +from typing import Tuple from typing import Type +from typing import Union import weakref from .attr import _ClsLevelDispatch from .attr import _EmptyListener +from .attr import _InstanceLevelDispatch from .attr import _JoinedListener +from .registry import _ET +from .registry import _EventKey from .. import util -from ..util.typing import Protocol +from ..util.typing import Literal -_registrars = util.defaultdict(list) +_registrars: MutableMapping[ + str, List[Type[_HasEventsDispatch[Any]]] +] = util.defaultdict(list) -def _is_event_name(name): +def _is_event_name(name: str) -> bool: # _sa_event prefix is special to support internal-only event names. # most event names are just plain method names that aren't # underscored. @@ -45,17 +61,17 @@ class _UnpickleDispatch: """ - def __call__(self, _instance_cls): + def __call__(self, _instance_cls: Type[_ET]) -> _Dispatch[_ET]: for cls in _instance_cls.__mro__: if "dispatch" in cls.__dict__: - return cls.__dict__["dispatch"].dispatch._for_class( - _instance_cls - ) + return cast( + "_Dispatch[_ET]", cls.__dict__["dispatch"].dispatch + )._for_class(_instance_cls) else: raise AttributeError("No class with a 'dispatch' member present.") -class _Dispatch: +class _Dispatch(Generic[_ET]): """Mirror the event listening definitions of an Events class with listener collections. @@ -79,20 +95,35 @@ class _Dispatch: # so __dict__ is used in just that case and potentially others. __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners" - _empty_listener_reg = weakref.WeakKeyDictionary() + _empty_listener_reg: MutableMapping[ + Type[_ET], Dict[str, _EmptyListener[_ET]] + ] = weakref.WeakKeyDictionary() + + _empty_listeners: Dict[str, _EmptyListener[_ET]] + + _event_names: List[str] + + _instance_cls: Optional[Type[_ET]] - _events: Type["_HasEventsDispatch"] + _joined_dispatch_cls: Type[_JoinedDispatcher[_ET]] + + _events: Type[_HasEventsDispatch[_ET]] """reference back to the Events class. Bidirectional against _HasEventsDispatch.dispatch """ - def __init__(self, parent, instance_cls=None): + def __init__( + self, + parent: Optional[_Dispatch[_ET]], + instance_cls: Optional[Type[_ET]] = None, + ): self._parent = parent self._instance_cls = instance_cls if instance_cls: + assert parent is not None try: self._empty_listeners = self._empty_listener_reg[instance_cls] except KeyError: @@ -105,7 +136,7 @@ class _Dispatch: else: self._empty_listeners = {} - def __getattr__(self, name): + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: # Assign EmptyListeners as attributes on demand # to reduce startup time for new dispatch objects. try: @@ -117,24 +148,23 @@ class _Dispatch: return ls @property - def _event_descriptors(self): + def _event_descriptors(self) -> Iterator[_ClsLevelDispatch[_ET]]: for k in self._event_names: # Yield _ClsLevelDispatch related # to relevant event name. yield getattr(self, k) - @property - def _listen(self): - return self._events._listen + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self._events._listen(event_key, **kw) - def _for_class(self, instance_cls): + def _for_class(self, instance_cls: Type[_ET]) -> _Dispatch[_ET]: return self.__class__(self, instance_cls) - def _for_instance(self, instance): + def _for_instance(self, instance: _ET) -> _Dispatch[_ET]: instance_cls = instance.__class__ return self._for_class(instance_cls) - def _join(self, other): + def _join(self, other: _Dispatch[_ET]) -> _JoinedDispatcher[_ET]: """Create a 'join' of this :class:`._Dispatch` and another. This new dispatcher will dispatch events to both @@ -147,14 +177,15 @@ class _Dispatch: (_JoinedDispatcher,), {"__slots__": self._event_names}, ) - self.__class__._joined_dispatch_cls = cls return self._joined_dispatch_cls(self, other) - def __reduce__(self): + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: return _UnpickleDispatch(), (self._instance_cls,) - def _update(self, other, only_propagate=True): + def _update( + self, other: _Dispatch[_ET], only_propagate: bool = True + ) -> None: """Populate from the listeners in another :class:`_Dispatch` object.""" for ls in other._event_descriptors: @@ -164,32 +195,23 @@ class _Dispatch: ls, only_propagate=only_propagate ) - def _clear(self): + def _clear(self) -> None: for ls in self._event_descriptors: ls.for_modify(self).clear() -def _remove_dispatcher(cls): +def _remove_dispatcher(cls: Type[_HasEventsDispatch[_ET]]) -> None: for k in cls.dispatch._event_names: _registrars[k].remove(cls) if not _registrars[k]: del _registrars[k] -class _HasEventsDispatchProto(Protocol): - """protocol for non-event classes that will also receive the 'dispatch' - attribute in the form of a descriptor. - - """ - - dispatch: ClassVar["dispatcher"] - - -class _HasEventsDispatch: - _dispatch_target: Optional[Type[_HasEventsDispatchProto]] +class _HasEventsDispatch(Generic[_ET]): + _dispatch_target: Optional[Type[_ET]] """class which will receive the .dispatch collection""" - dispatch: _Dispatch + dispatch: _Dispatch[_ET] """reference back to the _Dispatch class. Bidirectional against _Dispatch._events @@ -202,19 +224,41 @@ class _HasEventsDispatch: cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__) + @classmethod + def _accept_with( + cls, target: Union[_ET, Type[_ET]] + ) -> Optional[Union[_ET, Type[_ET]]]: + raise NotImplementedError() + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: + raise NotImplementedError() + @staticmethod - def _set_dispatch(cls, dispatch_cls): + def _set_dispatch( + klass: Type[_HasEventsDispatch[_ET]], + dispatch_cls: Type[_Dispatch[_ET]], + ) -> _Dispatch[_ET]: # This allows an Events subclass to define additional utility # methods made available to the target via # "self.dispatch._events.<utilitymethod>" # @staticmethod to allow easy "super" calls while in a metaclass # constructor. - cls.dispatch = dispatch_cls(None) - dispatch_cls._events = cls - return cls.dispatch + klass.dispatch = dispatch_cls(None) + dispatch_cls._events = klass + return klass.dispatch @classmethod - def _create_dispatcher_class(cls, classname, bases, dict_): + def _create_dispatcher_class( + cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any] + ) -> None: """Create a :class:`._Dispatch` class corresponding to an :class:`.Events` class.""" @@ -227,14 +271,16 @@ class _HasEventsDispatch: dispatch_base = _Dispatch event_names = [k for k in dict_ if _is_event_name(k)] - dispatch_cls = type( - "%sDispatch" % classname, - (dispatch_base,), - {"__slots__": event_names}, + dispatch_cls = cast( + "Type[_Dispatch[_ET]]", + type( + "%sDispatch" % classname, + (dispatch_base,), # type: ignore + {"__slots__": event_names}, + ), ) dispatch_cls._event_names = event_names - dispatch_inst = cls._set_dispatch(cls, dispatch_cls) for k in dispatch_cls._event_names: setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) @@ -251,23 +297,28 @@ class _HasEventsDispatch: assert dispatch_target_cls is not None if ( hasattr(dispatch_target_cls, "__slots__") - and "_slots_dispatch" in dispatch_target_cls.__slots__ + and "_slots_dispatch" in dispatch_target_cls.__slots__ # type: ignore # noqa E501 ): dispatch_target_cls.dispatch = slots_dispatcher(cls) else: dispatch_target_cls.dispatch = dispatcher(cls) -class Events(_HasEventsDispatch): +class Events(_HasEventsDispatch[_ET]): """Define event listening functions for a particular target type.""" @classmethod - def _accept_with(cls, target): - def dispatch_is(*types): + def _accept_with( + cls, target: Union[_ET, Type[_ET]] + ) -> Optional[Union[_ET, Type[_ET]]]: + def dispatch_is(*types: Type[Any]) -> bool: return all(isinstance(target.dispatch, t) for t in types) - def dispatch_parent_is(t): - return isinstance(target.dispatch.parent, t) + def dispatch_parent_is(t: Type[Any]) -> bool: + + return isinstance( + cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t + ) # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. @@ -282,39 +333,45 @@ class Events(_HasEventsDispatch): ): return target + return None + @classmethod def _listen( cls, - event_key, - propagate=False, - insert=False, - named=False, - asyncio=False, - ): + event_key: _EventKey[_ET], + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: event_key.base_listen( propagate=propagate, insert=insert, named=named, asyncio=asyncio ) @classmethod - def _remove(cls, event_key): + def _remove(cls, event_key: _EventKey[_ET]) -> None: event_key.remove() @classmethod - def _clear(cls): + def _clear(cls) -> None: cls.dispatch._clear() -class _JoinedDispatcher: +class _JoinedDispatcher(Generic[_ET]): """Represent a connection between two _Dispatch objects.""" __slots__ = "local", "parent", "_instance_cls" - def __init__(self, local, parent): + local: _Dispatch[_ET] + parent: _Dispatch[_ET] + _instance_cls: Optional[Type[_ET]] + + def __init__(self, local: _Dispatch[_ET], parent: _Dispatch[_ET]): self.local = local self.parent = parent self._instance_cls = self.local._instance_cls - def __getattr__(self, name): + def __getattr__(self, name: str) -> _JoinedListener[_ET]: # Assign _JoinedListeners as attributes on demand # to reduce startup time for new dispatch objects. ls = getattr(self.local, name) @@ -322,16 +379,15 @@ class _JoinedDispatcher: setattr(self, ls.name, jl) return jl - @property - def _listen(self): - return self.parent._listen + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self.parent._listen(event_key, **kw) @property - def _events(self): + def _events(self) -> Type[_HasEventsDispatch[_ET]]: return self.parent._events -class dispatcher: +class dispatcher(Generic[_ET]): """Descriptor used by target classes to deliver the _Dispatch class at the class level and produce new _Dispatch instances for target @@ -339,11 +395,21 @@ class dispatcher: """ - def __init__(self, events): + def __init__(self, events: Type[_HasEventsDispatch[_ET]]): self.dispatch = events.dispatch self.events = events - def __get__(self, obj, cls): + @overload + def __get__( + self, obj: Literal[None], cls: Type[Any] + ) -> Type[_HasEventsDispatch[_ET]]: + ... + + @overload + def __get__(self, obj: Any, cls: Type[Any]) -> _HasEventsDispatch[_ET]: + ... + + def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: return self.dispatch @@ -358,8 +424,8 @@ class dispatcher: return disp -class slots_dispatcher(dispatcher): - def __get__(self, obj, cls): +class slots_dispatcher(dispatcher[_ET]): + def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: return self.dispatch |
