summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/event/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/event/base.py')
-rw-r--r--lib/sqlalchemy/event/base.py212
1 files changed, 139 insertions, 73 deletions
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
index 25d369240..0e0647036 100644
--- a/lib/sqlalchemy/event/base.py
+++ b/lib/sqlalchemy/event/base.py
@@ -15,21 +15,37 @@ at the class level of a particular ``_Dispatch`` class as well as within
instances of ``_Dispatch``.
"""
-from typing import ClassVar
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Generic
+from typing import Iterator
+from typing import List
+from typing import MutableMapping
from typing import Optional
+from typing import overload
+from typing import Tuple
from typing import Type
+from typing import Union
import weakref
from .attr import _ClsLevelDispatch
from .attr import _EmptyListener
+from .attr import _InstanceLevelDispatch
from .attr import _JoinedListener
+from .registry import _ET
+from .registry import _EventKey
from .. import util
-from ..util.typing import Protocol
+from ..util.typing import Literal
-_registrars = util.defaultdict(list)
+_registrars: MutableMapping[
+ str, List[Type[_HasEventsDispatch[Any]]]
+] = util.defaultdict(list)
-def _is_event_name(name):
+def _is_event_name(name: str) -> bool:
# _sa_event prefix is special to support internal-only event names.
# most event names are just plain method names that aren't
# underscored.
@@ -45,17 +61,17 @@ class _UnpickleDispatch:
"""
- def __call__(self, _instance_cls):
+ def __call__(self, _instance_cls: Type[_ET]) -> _Dispatch[_ET]:
for cls in _instance_cls.__mro__:
if "dispatch" in cls.__dict__:
- return cls.__dict__["dispatch"].dispatch._for_class(
- _instance_cls
- )
+ return cast(
+ "_Dispatch[_ET]", cls.__dict__["dispatch"].dispatch
+ )._for_class(_instance_cls)
else:
raise AttributeError("No class with a 'dispatch' member present.")
-class _Dispatch:
+class _Dispatch(Generic[_ET]):
"""Mirror the event listening definitions of an Events class with
listener collections.
@@ -79,20 +95,35 @@ class _Dispatch:
# so __dict__ is used in just that case and potentially others.
__slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners"
- _empty_listener_reg = weakref.WeakKeyDictionary()
+ _empty_listener_reg: MutableMapping[
+ Type[_ET], Dict[str, _EmptyListener[_ET]]
+ ] = weakref.WeakKeyDictionary()
+
+ _empty_listeners: Dict[str, _EmptyListener[_ET]]
+
+ _event_names: List[str]
+
+ _instance_cls: Optional[Type[_ET]]
- _events: Type["_HasEventsDispatch"]
+ _joined_dispatch_cls: Type[_JoinedDispatcher[_ET]]
+
+ _events: Type[_HasEventsDispatch[_ET]]
"""reference back to the Events class.
Bidirectional against _HasEventsDispatch.dispatch
"""
- def __init__(self, parent, instance_cls=None):
+ def __init__(
+ self,
+ parent: Optional[_Dispatch[_ET]],
+ instance_cls: Optional[Type[_ET]] = None,
+ ):
self._parent = parent
self._instance_cls = instance_cls
if instance_cls:
+ assert parent is not None
try:
self._empty_listeners = self._empty_listener_reg[instance_cls]
except KeyError:
@@ -105,7 +136,7 @@ class _Dispatch:
else:
self._empty_listeners = {}
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]:
# Assign EmptyListeners as attributes on demand
# to reduce startup time for new dispatch objects.
try:
@@ -117,24 +148,23 @@ class _Dispatch:
return ls
@property
- def _event_descriptors(self):
+ def _event_descriptors(self) -> Iterator[_ClsLevelDispatch[_ET]]:
for k in self._event_names:
# Yield _ClsLevelDispatch related
# to relevant event name.
yield getattr(self, k)
- @property
- def _listen(self):
- return self._events._listen
+ def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None:
+ return self._events._listen(event_key, **kw)
- def _for_class(self, instance_cls):
+ def _for_class(self, instance_cls: Type[_ET]) -> _Dispatch[_ET]:
return self.__class__(self, instance_cls)
- def _for_instance(self, instance):
+ def _for_instance(self, instance: _ET) -> _Dispatch[_ET]:
instance_cls = instance.__class__
return self._for_class(instance_cls)
- def _join(self, other):
+ def _join(self, other: _Dispatch[_ET]) -> _JoinedDispatcher[_ET]:
"""Create a 'join' of this :class:`._Dispatch` and another.
This new dispatcher will dispatch events to both
@@ -147,14 +177,15 @@ class _Dispatch:
(_JoinedDispatcher,),
{"__slots__": self._event_names},
)
-
self.__class__._joined_dispatch_cls = cls
return self._joined_dispatch_cls(self, other)
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return _UnpickleDispatch(), (self._instance_cls,)
- def _update(self, other, only_propagate=True):
+ def _update(
+ self, other: _Dispatch[_ET], only_propagate: bool = True
+ ) -> None:
"""Populate from the listeners in another :class:`_Dispatch`
object."""
for ls in other._event_descriptors:
@@ -164,32 +195,23 @@ class _Dispatch:
ls, only_propagate=only_propagate
)
- def _clear(self):
+ def _clear(self) -> None:
for ls in self._event_descriptors:
ls.for_modify(self).clear()
-def _remove_dispatcher(cls):
+def _remove_dispatcher(cls: Type[_HasEventsDispatch[_ET]]) -> None:
for k in cls.dispatch._event_names:
_registrars[k].remove(cls)
if not _registrars[k]:
del _registrars[k]
-class _HasEventsDispatchProto(Protocol):
- """protocol for non-event classes that will also receive the 'dispatch'
- attribute in the form of a descriptor.
-
- """
-
- dispatch: ClassVar["dispatcher"]
-
-
-class _HasEventsDispatch:
- _dispatch_target: Optional[Type[_HasEventsDispatchProto]]
+class _HasEventsDispatch(Generic[_ET]):
+ _dispatch_target: Optional[Type[_ET]]
"""class which will receive the .dispatch collection"""
- dispatch: _Dispatch
+ dispatch: _Dispatch[_ET]
"""reference back to the _Dispatch class.
Bidirectional against _Dispatch._events
@@ -202,19 +224,41 @@ class _HasEventsDispatch:
cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__)
+ @classmethod
+ def _accept_with(
+ cls, target: Union[_ET, Type[_ET]]
+ ) -> Optional[Union[_ET, Type[_ET]]]:
+ raise NotImplementedError()
+
+ @classmethod
+ def _listen(
+ cls,
+ event_key: _EventKey[_ET],
+ propagate: bool = False,
+ insert: bool = False,
+ named: bool = False,
+ asyncio: bool = False,
+ ) -> None:
+ raise NotImplementedError()
+
@staticmethod
- def _set_dispatch(cls, dispatch_cls):
+ def _set_dispatch(
+ klass: Type[_HasEventsDispatch[_ET]],
+ dispatch_cls: Type[_Dispatch[_ET]],
+ ) -> _Dispatch[_ET]:
# This allows an Events subclass to define additional utility
# methods made available to the target via
# "self.dispatch._events.<utilitymethod>"
# @staticmethod to allow easy "super" calls while in a metaclass
# constructor.
- cls.dispatch = dispatch_cls(None)
- dispatch_cls._events = cls
- return cls.dispatch
+ klass.dispatch = dispatch_cls(None)
+ dispatch_cls._events = klass
+ return klass.dispatch
@classmethod
- def _create_dispatcher_class(cls, classname, bases, dict_):
+ def _create_dispatcher_class(
+ cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any]
+ ) -> None:
"""Create a :class:`._Dispatch` class corresponding to an
:class:`.Events` class."""
@@ -227,14 +271,16 @@ class _HasEventsDispatch:
dispatch_base = _Dispatch
event_names = [k for k in dict_ if _is_event_name(k)]
- dispatch_cls = type(
- "%sDispatch" % classname,
- (dispatch_base,),
- {"__slots__": event_names},
+ dispatch_cls = cast(
+ "Type[_Dispatch[_ET]]",
+ type(
+ "%sDispatch" % classname,
+ (dispatch_base,), # type: ignore
+ {"__slots__": event_names},
+ ),
)
dispatch_cls._event_names = event_names
-
dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
for k in dispatch_cls._event_names:
setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
@@ -251,23 +297,28 @@ class _HasEventsDispatch:
assert dispatch_target_cls is not None
if (
hasattr(dispatch_target_cls, "__slots__")
- and "_slots_dispatch" in dispatch_target_cls.__slots__
+ and "_slots_dispatch" in dispatch_target_cls.__slots__ # type: ignore # noqa E501
):
dispatch_target_cls.dispatch = slots_dispatcher(cls)
else:
dispatch_target_cls.dispatch = dispatcher(cls)
-class Events(_HasEventsDispatch):
+class Events(_HasEventsDispatch[_ET]):
"""Define event listening functions for a particular target type."""
@classmethod
- def _accept_with(cls, target):
- def dispatch_is(*types):
+ def _accept_with(
+ cls, target: Union[_ET, Type[_ET]]
+ ) -> Optional[Union[_ET, Type[_ET]]]:
+ def dispatch_is(*types: Type[Any]) -> bool:
return all(isinstance(target.dispatch, t) for t in types)
- def dispatch_parent_is(t):
- return isinstance(target.dispatch.parent, t)
+ def dispatch_parent_is(t: Type[Any]) -> bool:
+
+ return isinstance(
+ cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t
+ )
# Mapper, ClassManager, Session override this to
# also accept classes, scoped_sessions, sessionmakers, etc.
@@ -282,39 +333,45 @@ class Events(_HasEventsDispatch):
):
return target
+ return None
+
@classmethod
def _listen(
cls,
- event_key,
- propagate=False,
- insert=False,
- named=False,
- asyncio=False,
- ):
+ event_key: _EventKey[_ET],
+ propagate: bool = False,
+ insert: bool = False,
+ named: bool = False,
+ asyncio: bool = False,
+ ) -> None:
event_key.base_listen(
propagate=propagate, insert=insert, named=named, asyncio=asyncio
)
@classmethod
- def _remove(cls, event_key):
+ def _remove(cls, event_key: _EventKey[_ET]) -> None:
event_key.remove()
@classmethod
- def _clear(cls):
+ def _clear(cls) -> None:
cls.dispatch._clear()
-class _JoinedDispatcher:
+class _JoinedDispatcher(Generic[_ET]):
"""Represent a connection between two _Dispatch objects."""
__slots__ = "local", "parent", "_instance_cls"
- def __init__(self, local, parent):
+ local: _Dispatch[_ET]
+ parent: _Dispatch[_ET]
+ _instance_cls: Optional[Type[_ET]]
+
+ def __init__(self, local: _Dispatch[_ET], parent: _Dispatch[_ET]):
self.local = local
self.parent = parent
self._instance_cls = self.local._instance_cls
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> _JoinedListener[_ET]:
# Assign _JoinedListeners as attributes on demand
# to reduce startup time for new dispatch objects.
ls = getattr(self.local, name)
@@ -322,16 +379,15 @@ class _JoinedDispatcher:
setattr(self, ls.name, jl)
return jl
- @property
- def _listen(self):
- return self.parent._listen
+ def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None:
+ return self.parent._listen(event_key, **kw)
@property
- def _events(self):
+ def _events(self) -> Type[_HasEventsDispatch[_ET]]:
return self.parent._events
-class dispatcher:
+class dispatcher(Generic[_ET]):
"""Descriptor used by target classes to
deliver the _Dispatch class at the class level
and produce new _Dispatch instances for target
@@ -339,11 +395,21 @@ class dispatcher:
"""
- def __init__(self, events):
+ def __init__(self, events: Type[_HasEventsDispatch[_ET]]):
self.dispatch = events.dispatch
self.events = events
- def __get__(self, obj, cls):
+ @overload
+ def __get__(
+ self, obj: Literal[None], cls: Type[Any]
+ ) -> Type[_HasEventsDispatch[_ET]]:
+ ...
+
+ @overload
+ def __get__(self, obj: Any, cls: Type[Any]) -> _HasEventsDispatch[_ET]:
+ ...
+
+ def __get__(self, obj: Any, cls: Type[Any]) -> Any:
if obj is None:
return self.dispatch
@@ -358,8 +424,8 @@ class dispatcher:
return disp
-class slots_dispatcher(dispatcher):
- def __get__(self, obj, cls):
+class slots_dispatcher(dispatcher[_ET]):
+ def __get__(self, obj: Any, cls: Type[Any]) -> Any:
if obj is None:
return self.dispatch