diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-02-13 16:45:18 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-02-15 17:10:33 -0500 |
| commit | 5c6081ddb03447697f909a03572b6d6d79e61b71 (patch) | |
| tree | 8124ba2e9a496dcb6ac6ea92626804d261cc4c5d /lib/sqlalchemy/event/attr.py | |
| parent | 619abb52b6f1ee023db0f85fd96ba9f88c8efa7b (diff) | |
| download | sqlalchemy-5c6081ddb03447697f909a03572b6d6d79e61b71.tar.gz | |
pep-484 for sqlalchemy.event; use future annotations
__future__.annotations mode allows us to use non-string
annotations for argument and return types in most cases,
but more importantly it removes a large amount of runtime
overhead that would be spent in evaluating the annotations.
Change-Id: I2f5b6126fe0019713fc50001be3627b664019ede
References: #6810
Diffstat (limited to 'lib/sqlalchemy/event/attr.py')
| -rw-r--r-- | lib/sqlalchemy/event/attr.py | 296 |
1 files changed, 224 insertions, 72 deletions
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index a05966222..d1ae7a845 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -28,43 +28,89 @@ as well as support for subclass propagation (e.g. events assigned to ``Pool`` vs. ``QueuePool``) are all implemented here. """ +from __future__ import annotations + import collections from itertools import chain import threading +from types import TracebackType +import typing +from typing import Any +from typing import cast +from typing import Collection +from typing import Deque +from typing import FrozenSet +from typing import Generic +from typing import Iterator +from typing import MutableMapping +from typing import MutableSequence +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union import weakref from . import legacy from . import registry +from .registry import _ET +from .registry import _EventKey +from .registry import _ListenerFnType from .. import exc from .. import util from ..util.concurrency import AsyncAdaptedLock +from ..util.typing import Protocol + +_T = TypeVar("_T", bound=Any) + +if typing.TYPE_CHECKING: + from .base import _Dispatch + from .base import _HasEventsDispatch + from .base import _JoinedDispatcher -class RefCollection(util.MemoizedSlots): +class RefCollection(util.MemoizedSlots, Generic[_ET]): __slots__ = ("ref",) - def _memoized_attr_ref(self): + ref: weakref.ref[RefCollection[_ET]] + + def _memoized_attr_ref(self) -> weakref.ref[RefCollection[_ET]]: return weakref.ref(self, registry._collection_gced) -class _empty_collection: - def append(self, element): +class _empty_collection(Collection[_T]): + def append(self, element: _T) -> None: + pass + + def appendleft(self, element: _T) -> None: pass - def extend(self, other): + def extend(self, other: Sequence[_T]) -> None: pass - def remove(self, element): + def remove(self, element: _T) -> None: pass - def __iter__(self): + def __contains__(self, element: Any) -> bool: + return False + + def __iter__(self) -> Iterator[_T]: return iter([]) - def clear(self): + def clear(self) -> None: pass + def __len__(self) -> int: + return 0 + + +_ListenerFnSequenceType = Union[Deque[_T], _empty_collection[_T]] + -class _ClsLevelDispatch(RefCollection): +class _ClsLevelDispatch(RefCollection[_ET]): """Class-level events on :class:`._Dispatch` classes.""" __slots__ = ( @@ -77,7 +123,20 @@ class _ClsLevelDispatch(RefCollection): "__weakref__", ) - def __init__(self, parent_dispatch_cls, fn): + clsname: str + name: str + arg_names: Sequence[str] + has_kw: bool + legacy_signatures: MutableSequence[legacy._LegacySignatureType] + _clslevel: MutableMapping[ + Type[_ET], _ListenerFnSequenceType[_ListenerFnType] + ] + + def __init__( + self, + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + fn: _ListenerFnType, + ): self.name = fn.__name__ self.clsname = parent_dispatch_cls.__name__ argspec = util.inspect_getfullargspec(fn) @@ -94,7 +153,9 @@ class _ClsLevelDispatch(RefCollection): self._clslevel = weakref.WeakKeyDictionary() - def _adjust_fn_spec(self, fn, named): + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: if named: fn = self._wrap_fn_for_kw(fn) if self.legacy_signatures: @@ -106,15 +167,15 @@ class _ClsLevelDispatch(RefCollection): fn = legacy._wrap_fn_for_legacy(self, fn, argspec) return fn - def _wrap_fn_for_kw(self, fn): - def wrap_kw(*args, **kw): + def _wrap_fn_for_kw(self, fn: _ListenerFnType) -> _ListenerFnType: + def wrap_kw(*args: Any, **kw: Any) -> Any: argdict = dict(zip(self.arg_names, args)) argdict.update(kw) return fn(**argdict) return wrap_kw - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: target = event_key.dispatch_target assert isinstance( target, type @@ -125,6 +186,7 @@ class _ClsLevelDispatch(RefCollection): ) for cls in util.walk_subclasses(target): + cls = cast(Type[_ET], cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: @@ -133,7 +195,7 @@ class _ClsLevelDispatch(RefCollection): self._clslevel[cls].appendleft(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: target = event_key.dispatch_target assert isinstance( target, type @@ -143,6 +205,7 @@ class _ClsLevelDispatch(RefCollection): "Can't assign an event directly to the %s class" % target ) for cls in util.walk_subclasses(target): + cls = cast("Type[_ET]", cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: @@ -151,39 +214,41 @@ class _ClsLevelDispatch(RefCollection): self._clslevel[cls].append(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def _assign_cls_collection(self, target): + def _assign_cls_collection(self, target: Type[_ET]) -> None: if getattr(target, "_sa_propagate_class_events", True): self._clslevel[target] = collections.deque() else: self._clslevel[target] = _empty_collection() - def update_subclass(self, target): + def update_subclass(self, target: Type[_ET]) -> None: if target not in self._clslevel: self._assign_cls_collection(target) clslevel = self._clslevel[target] for cls in target.__mro__[1:]: + cls = cast("Type[_ET]", cls) if cls in self._clslevel: clslevel.extend( [fn for fn in self._clslevel[cls] if fn not in clslevel] ) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: target = event_key.dispatch_target for cls in util.walk_subclasses(target): + cls = cast("Type[_ET]", cls) if cls in self._clslevel: self._clslevel[cls].remove(event_key._listen_fn) registry._removed_from_collection(event_key, self) - def clear(self): + def clear(self) -> None: """Clear all class level listeners""" - to_clear = set() + to_clear: Set[_ListenerFnType] = set() for dispatcher in self._clslevel.values(): to_clear.update(dispatcher) dispatcher.clear() registry._clear(self, to_clear) - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ClsLevelDispatch[_ET]: """Return an event collection which can be modified. For _ClsLevelDispatch at the class level of @@ -193,14 +258,30 @@ class _ClsLevelDispatch(RefCollection): return self -class _InstanceLevelDispatch(RefCollection): +class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]): __slots__ = () - def _adjust_fn_spec(self, fn, named): + parent: _ClsLevelDispatch[_ET] + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: return self.parent._adjust_fn_spec(fn, named) + def __contains__(self, item: Any) -> bool: + raise NotImplementedError() + + def __len__(self) -> int: + raise NotImplementedError() + + def __iter__(self) -> Iterator[_ListenerFnType]: + raise NotImplementedError() + + def __bool__(self) -> bool: + raise NotImplementedError() + -class _EmptyListener(_InstanceLevelDispatch): +class _EmptyListener(_InstanceLevelDispatch[_ET]): """Serves as a proxy interface to the events served by a _ClsLevelDispatch, when there are no instance-level events present. @@ -210,19 +291,22 @@ class _EmptyListener(_InstanceLevelDispatch): """ - propagate = frozenset() - listeners = () - __slots__ = "parent", "parent_listeners", "name" - def __init__(self, parent, target_cls): + propagate: FrozenSet[_ListenerFnType] = frozenset() + listeners: Tuple[()] = () + parent: _ClsLevelDispatch[_ET] + parent_listeners: _ListenerFnSequenceType[_ListenerFnType] + name: str + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) - self.parent = parent # _ClsLevelDispatch + self.parent = parent self.parent_listeners = parent._clslevel[target_cls] self.name = parent.name - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _EmptyListener at the instance level of @@ -231,6 +315,7 @@ class _EmptyListener(_InstanceLevelDispatch): and returns it. """ + assert obj._instance_cls is not None result = _ListenerCollection(self.parent, obj._instance_cls) if getattr(obj, self.name) is self: setattr(obj, self.name, result) @@ -238,41 +323,79 @@ class _EmptyListener(_InstanceLevelDispatch): assert isinstance(getattr(obj, self.name), _JoinedListener) return result - def _needs_modify(self, *args, **kw): + def _needs_modify(self, *args: Any, **kw: Any) -> NoReturn: raise NotImplementedError("need to call for_modify()") - exec_once = ( - exec_once_unless_exception - ) = insert = append = remove = clear = _needs_modify + def exec_once(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def insert(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) - def __call__(self, *args, **kw): + def append(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def remove(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def clear(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def __call__(self, *args: Any, **kw: Any) -> None: """Execute this event.""" for fn in self.parent_listeners: fn(*args, **kw) - def __len__(self): + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners + + def __len__(self) -> int: return len(self.parent_listeners) - def __iter__(self): + def __iter__(self) -> Iterator[_ListenerFnType]: return iter(self.parent_listeners) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.parent_listeners) __nonzero__ = __bool__ -class _CompoundListener(_InstanceLevelDispatch): +class _MutexProtocol(Protocol): + def __enter__(self) -> bool: + ... + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + ... + + +class _CompoundListener(_InstanceLevelDispatch[_ET]): __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once" - def _set_asyncio(self): + _exec_once_mutex: _MutexProtocol + parent_listeners: Collection[_ListenerFnType] + listeners: Collection[_ListenerFnType] + _exec_once: bool + _exec_w_sync_once: bool + + def _set_asyncio(self) -> None: self._exec_once_mutex = AsyncAdaptedLock() - def _memoized_attr__exec_once_mutex(self): + def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol: return threading.Lock() - def _exec_once_impl(self, retry_on_exception, *args, **kw): + def _exec_once_impl( + self, retry_on_exception: bool, *args: Any, **kw: Any + ) -> None: with self._exec_once_mutex: if not self._exec_once: try: @@ -285,14 +408,14 @@ class _CompoundListener(_InstanceLevelDispatch): if not exception or not retry_on_exception: self._exec_once = True - def exec_once(self, *args, **kw): + def exec_once(self, *args: Any, **kw: Any) -> None: """Execute this event, but only if it has not been executed already for this collection.""" if not self._exec_once: self._exec_once_impl(False, *args, **kw) - def exec_once_unless_exception(self, *args, **kw): + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: """Execute this event, but only if it has not been executed already for this collection, or was called by a previous exec_once_unless_exception call and @@ -307,7 +430,7 @@ class _CompoundListener(_InstanceLevelDispatch): if not self._exec_once: self._exec_once_impl(True, *args, **kw) - def _exec_w_sync_on_first_run(self, *args, **kw): + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: """Execute this event, and use a mutex if it has not been executed already for this collection, or was called by a previous _exec_w_sync_on_first_run call and @@ -330,7 +453,7 @@ class _CompoundListener(_InstanceLevelDispatch): else: self(*args, **kw) - def __call__(self, *args, **kw): + def __call__(self, *args: Any, **kw: Any) -> None: """Execute this event.""" for fn in self.parent_listeners: @@ -338,19 +461,22 @@ class _CompoundListener(_InstanceLevelDispatch): for fn in self.listeners: fn(*args, **kw) - def __len__(self): + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners or item in self.listeners + + def __len__(self) -> int: return len(self.parent_listeners) + len(self.listeners) - def __iter__(self): + def __iter__(self) -> Iterator[_ListenerFnType]: return chain(self.parent_listeners, self.listeners) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.listeners or self.parent_listeners) __nonzero__ = __bool__ -class _ListenerCollection(_CompoundListener): +class _ListenerCollection(_CompoundListener[_ET]): """Instance-level attributes on instances of :class:`._Dispatch`. Represents a collection of listeners. @@ -369,7 +495,13 @@ class _ListenerCollection(_CompoundListener): "__weakref__", ) - def __init__(self, parent, target_cls): + parent_listeners: Collection[_ListenerFnType] + parent: _ClsLevelDispatch[_ET] + name: str + listeners: Deque[_ListenerFnType] + propagate: Set[_ListenerFnType] + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): if target_cls not in parent._clslevel: parent.update_subclass(target_cls) self._exec_once = False @@ -380,7 +512,7 @@ class _ListenerCollection(_CompoundListener): self.listeners = collections.deque() self.propagate = set() - def for_modify(self, obj): + def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]: """Return an event collection which can be modified. For _ListenerCollection at the instance level of @@ -389,10 +521,11 @@ class _ListenerCollection(_CompoundListener): """ return self - def _update(self, other, only_propagate=True): + def _update( + self, other: _ListenerCollection[_ET], only_propagate: bool = True + ) -> None: """Populate from the listeners in another :class:`_Dispatch` object.""" - existing_listeners = self.listeners existing_listener_set = set(existing_listeners) self.propagate.update(other.propagate) @@ -409,56 +542,75 @@ class _ListenerCollection(_CompoundListener): to_associate = other.propagate.union(other_listeners) registry._stored_in_collection_multi(self, other, to_associate) - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: if event_key.prepend_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: if event_key.append_to_list(self, self.listeners): if propagate: self.propagate.add(event_key._listen_fn) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: self.listeners.remove(event_key._listen_fn) self.propagate.discard(event_key._listen_fn) registry._removed_from_collection(event_key, self) - def clear(self): + def clear(self) -> None: registry._clear(self, self.listeners) self.propagate.clear() self.listeners.clear() -class _JoinedListener(_CompoundListener): - __slots__ = "parent", "name", "local", "parent_listeners" +class _JoinedListener(_CompoundListener[_ET]): + __slots__ = "parent_dispatch", "name", "local", "parent_listeners" + + parent_dispatch: _Dispatch[_ET] + name: str + local: _InstanceLevelDispatch[_ET] + parent_listeners: Collection[_ListenerFnType] - def __init__(self, parent, name, local): + def __init__( + self, + parent_dispatch: _Dispatch[_ET], + name: str, + local: _EmptyListener[_ET], + ): self._exec_once = False - self.parent = parent + self.parent_dispatch = parent_dispatch self.name = name self.local = local self.parent_listeners = self.local - @property - def listeners(self): - return getattr(self.parent, self.name) - - def _adjust_fn_spec(self, fn, named): + if not typing.TYPE_CHECKING: + # first error, I don't really understand: + # Signature of "listeners" incompatible with + # supertype "_CompoundListener" [override] + # the name / return type are exactly the same + # second error is getattr_isn't typed, the cast() here + # adds too much method overhead + @property + def listeners(self) -> Collection[_ListenerFnType]: + return getattr(self.parent_dispatch, self.name) + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: return self.local._adjust_fn_spec(fn, named) - def for_modify(self, obj): + def for_modify(self, obj: _JoinedDispatcher[_ET]) -> _JoinedListener[_ET]: self.local = self.parent_listeners = self.local.for_modify(obj) return self - def insert(self, event_key, propagate): + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: self.local.insert(event_key, propagate) - def append(self, event_key, propagate): + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: self.local.append(event_key, propagate) - def remove(self, event_key): + def remove(self, event_key: _EventKey[_ET]) -> None: self.local.remove(event_key) - def clear(self): + def clear(self) -> None: raise NotImplementedError() |
