summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/event/attr.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-02-13 16:45:18 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-02-15 17:10:33 -0500
commit5c6081ddb03447697f909a03572b6d6d79e61b71 (patch)
tree8124ba2e9a496dcb6ac6ea92626804d261cc4c5d /lib/sqlalchemy/event/attr.py
parent619abb52b6f1ee023db0f85fd96ba9f88c8efa7b (diff)
downloadsqlalchemy-5c6081ddb03447697f909a03572b6d6d79e61b71.tar.gz
pep-484 for sqlalchemy.event; use future annotations
__future__.annotations mode allows us to use non-string annotations for argument and return types in most cases, but more importantly it removes a large amount of runtime overhead that would be spent in evaluating the annotations. Change-Id: I2f5b6126fe0019713fc50001be3627b664019ede References: #6810
Diffstat (limited to 'lib/sqlalchemy/event/attr.py')
-rw-r--r--lib/sqlalchemy/event/attr.py296
1 files changed, 224 insertions, 72 deletions
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
index a05966222..d1ae7a845 100644
--- a/lib/sqlalchemy/event/attr.py
+++ b/lib/sqlalchemy/event/attr.py
@@ -28,43 +28,89 @@ as well as support for subclass propagation (e.g. events assigned to
``Pool`` vs. ``QueuePool``) are all implemented here.
"""
+from __future__ import annotations
+
import collections
from itertools import chain
import threading
+from types import TracebackType
+import typing
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import Deque
+from typing import FrozenSet
+from typing import Generic
+from typing import Iterator
+from typing import MutableMapping
+from typing import MutableSequence
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+from typing import Union
import weakref
from . import legacy
from . import registry
+from .registry import _ET
+from .registry import _EventKey
+from .registry import _ListenerFnType
from .. import exc
from .. import util
from ..util.concurrency import AsyncAdaptedLock
+from ..util.typing import Protocol
+
+_T = TypeVar("_T", bound=Any)
+
+if typing.TYPE_CHECKING:
+ from .base import _Dispatch
+ from .base import _HasEventsDispatch
+ from .base import _JoinedDispatcher
-class RefCollection(util.MemoizedSlots):
+class RefCollection(util.MemoizedSlots, Generic[_ET]):
__slots__ = ("ref",)
- def _memoized_attr_ref(self):
+ ref: weakref.ref[RefCollection[_ET]]
+
+ def _memoized_attr_ref(self) -> weakref.ref[RefCollection[_ET]]:
return weakref.ref(self, registry._collection_gced)
-class _empty_collection:
- def append(self, element):
+class _empty_collection(Collection[_T]):
+ def append(self, element: _T) -> None:
+ pass
+
+ def appendleft(self, element: _T) -> None:
pass
- def extend(self, other):
+ def extend(self, other: Sequence[_T]) -> None:
pass
- def remove(self, element):
+ def remove(self, element: _T) -> None:
pass
- def __iter__(self):
+ def __contains__(self, element: Any) -> bool:
+ return False
+
+ def __iter__(self) -> Iterator[_T]:
return iter([])
- def clear(self):
+ def clear(self) -> None:
pass
+ def __len__(self) -> int:
+ return 0
+
+
+_ListenerFnSequenceType = Union[Deque[_T], _empty_collection[_T]]
+
-class _ClsLevelDispatch(RefCollection):
+class _ClsLevelDispatch(RefCollection[_ET]):
"""Class-level events on :class:`._Dispatch` classes."""
__slots__ = (
@@ -77,7 +123,20 @@ class _ClsLevelDispatch(RefCollection):
"__weakref__",
)
- def __init__(self, parent_dispatch_cls, fn):
+ clsname: str
+ name: str
+ arg_names: Sequence[str]
+ has_kw: bool
+ legacy_signatures: MutableSequence[legacy._LegacySignatureType]
+ _clslevel: MutableMapping[
+ Type[_ET], _ListenerFnSequenceType[_ListenerFnType]
+ ]
+
+ def __init__(
+ self,
+ parent_dispatch_cls: Type[_HasEventsDispatch[_ET]],
+ fn: _ListenerFnType,
+ ):
self.name = fn.__name__
self.clsname = parent_dispatch_cls.__name__
argspec = util.inspect_getfullargspec(fn)
@@ -94,7 +153,9 @@ class _ClsLevelDispatch(RefCollection):
self._clslevel = weakref.WeakKeyDictionary()
- def _adjust_fn_spec(self, fn, named):
+ def _adjust_fn_spec(
+ self, fn: _ListenerFnType, named: bool
+ ) -> _ListenerFnType:
if named:
fn = self._wrap_fn_for_kw(fn)
if self.legacy_signatures:
@@ -106,15 +167,15 @@ class _ClsLevelDispatch(RefCollection):
fn = legacy._wrap_fn_for_legacy(self, fn, argspec)
return fn
- def _wrap_fn_for_kw(self, fn):
- def wrap_kw(*args, **kw):
+ def _wrap_fn_for_kw(self, fn: _ListenerFnType) -> _ListenerFnType:
+ def wrap_kw(*args: Any, **kw: Any) -> Any:
argdict = dict(zip(self.arg_names, args))
argdict.update(kw)
return fn(**argdict)
return wrap_kw
- def insert(self, event_key, propagate):
+ def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
target = event_key.dispatch_target
assert isinstance(
target, type
@@ -125,6 +186,7 @@ class _ClsLevelDispatch(RefCollection):
)
for cls in util.walk_subclasses(target):
+ cls = cast(Type[_ET], cls)
if cls is not target and cls not in self._clslevel:
self.update_subclass(cls)
else:
@@ -133,7 +195,7 @@ class _ClsLevelDispatch(RefCollection):
self._clslevel[cls].appendleft(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
- def append(self, event_key, propagate):
+ def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
target = event_key.dispatch_target
assert isinstance(
target, type
@@ -143,6 +205,7 @@ class _ClsLevelDispatch(RefCollection):
"Can't assign an event directly to the %s class" % target
)
for cls in util.walk_subclasses(target):
+ cls = cast("Type[_ET]", cls)
if cls is not target and cls not in self._clslevel:
self.update_subclass(cls)
else:
@@ -151,39 +214,41 @@ class _ClsLevelDispatch(RefCollection):
self._clslevel[cls].append(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
- def _assign_cls_collection(self, target):
+ def _assign_cls_collection(self, target: Type[_ET]) -> None:
if getattr(target, "_sa_propagate_class_events", True):
self._clslevel[target] = collections.deque()
else:
self._clslevel[target] = _empty_collection()
- def update_subclass(self, target):
+ def update_subclass(self, target: Type[_ET]) -> None:
if target not in self._clslevel:
self._assign_cls_collection(target)
clslevel = self._clslevel[target]
for cls in target.__mro__[1:]:
+ cls = cast("Type[_ET]", cls)
if cls in self._clslevel:
clslevel.extend(
[fn for fn in self._clslevel[cls] if fn not in clslevel]
)
- def remove(self, event_key):
+ def remove(self, event_key: _EventKey[_ET]) -> None:
target = event_key.dispatch_target
for cls in util.walk_subclasses(target):
+ cls = cast("Type[_ET]", cls)
if cls in self._clslevel:
self._clslevel[cls].remove(event_key._listen_fn)
registry._removed_from_collection(event_key, self)
- def clear(self):
+ def clear(self) -> None:
"""Clear all class level listeners"""
- to_clear = set()
+ to_clear: Set[_ListenerFnType] = set()
for dispatcher in self._clslevel.values():
to_clear.update(dispatcher)
dispatcher.clear()
registry._clear(self, to_clear)
- def for_modify(self, obj):
+ def for_modify(self, obj: _Dispatch[_ET]) -> _ClsLevelDispatch[_ET]:
"""Return an event collection which can be modified.
For _ClsLevelDispatch at the class level of
@@ -193,14 +258,30 @@ class _ClsLevelDispatch(RefCollection):
return self
-class _InstanceLevelDispatch(RefCollection):
+class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]):
__slots__ = ()
- def _adjust_fn_spec(self, fn, named):
+ parent: _ClsLevelDispatch[_ET]
+
+ def _adjust_fn_spec(
+ self, fn: _ListenerFnType, named: bool
+ ) -> _ListenerFnType:
return self.parent._adjust_fn_spec(fn, named)
+ def __contains__(self, item: Any) -> bool:
+ raise NotImplementedError()
+
+ def __len__(self) -> int:
+ raise NotImplementedError()
+
+ def __iter__(self) -> Iterator[_ListenerFnType]:
+ raise NotImplementedError()
+
+ def __bool__(self) -> bool:
+ raise NotImplementedError()
+
-class _EmptyListener(_InstanceLevelDispatch):
+class _EmptyListener(_InstanceLevelDispatch[_ET]):
"""Serves as a proxy interface to the events
served by a _ClsLevelDispatch, when there are no
instance-level events present.
@@ -210,19 +291,22 @@ class _EmptyListener(_InstanceLevelDispatch):
"""
- propagate = frozenset()
- listeners = ()
-
__slots__ = "parent", "parent_listeners", "name"
- def __init__(self, parent, target_cls):
+ propagate: FrozenSet[_ListenerFnType] = frozenset()
+ listeners: Tuple[()] = ()
+ parent: _ClsLevelDispatch[_ET]
+ parent_listeners: _ListenerFnSequenceType[_ListenerFnType]
+ name: str
+
+ def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]):
if target_cls not in parent._clslevel:
parent.update_subclass(target_cls)
- self.parent = parent # _ClsLevelDispatch
+ self.parent = parent
self.parent_listeners = parent._clslevel[target_cls]
self.name = parent.name
- def for_modify(self, obj):
+ def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]:
"""Return an event collection which can be modified.
For _EmptyListener at the instance level of
@@ -231,6 +315,7 @@ class _EmptyListener(_InstanceLevelDispatch):
and returns it.
"""
+ assert obj._instance_cls is not None
result = _ListenerCollection(self.parent, obj._instance_cls)
if getattr(obj, self.name) is self:
setattr(obj, self.name, result)
@@ -238,41 +323,79 @@ class _EmptyListener(_InstanceLevelDispatch):
assert isinstance(getattr(obj, self.name), _JoinedListener)
return result
- def _needs_modify(self, *args, **kw):
+ def _needs_modify(self, *args: Any, **kw: Any) -> NoReturn:
raise NotImplementedError("need to call for_modify()")
- exec_once = (
- exec_once_unless_exception
- ) = insert = append = remove = clear = _needs_modify
+ def exec_once(self, *args: Any, **kw: Any) -> NoReturn:
+ self._needs_modify(*args, **kw)
+
+ def exec_once_unless_exception(self, *args: Any, **kw: Any) -> NoReturn:
+ self._needs_modify(*args, **kw)
+
+ def insert(self, *args: Any, **kw: Any) -> NoReturn:
+ self._needs_modify(*args, **kw)
- def __call__(self, *args, **kw):
+ def append(self, *args: Any, **kw: Any) -> NoReturn:
+ self._needs_modify(*args, **kw)
+
+ def remove(self, *args: Any, **kw: Any) -> NoReturn:
+ self._needs_modify(*args, **kw)
+
+ def clear(self, *args: Any, **kw: Any) -> NoReturn:
+ self._needs_modify(*args, **kw)
+
+ def __call__(self, *args: Any, **kw: Any) -> None:
"""Execute this event."""
for fn in self.parent_listeners:
fn(*args, **kw)
- def __len__(self):
+ def __contains__(self, item: Any) -> bool:
+ return item in self.parent_listeners
+
+ def __len__(self) -> int:
return len(self.parent_listeners)
- def __iter__(self):
+ def __iter__(self) -> Iterator[_ListenerFnType]:
return iter(self.parent_listeners)
- def __bool__(self):
+ def __bool__(self) -> bool:
return bool(self.parent_listeners)
__nonzero__ = __bool__
-class _CompoundListener(_InstanceLevelDispatch):
+class _MutexProtocol(Protocol):
+ def __enter__(self) -> bool:
+ ...
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ ...
+
+
+class _CompoundListener(_InstanceLevelDispatch[_ET]):
__slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once"
- def _set_asyncio(self):
+ _exec_once_mutex: _MutexProtocol
+ parent_listeners: Collection[_ListenerFnType]
+ listeners: Collection[_ListenerFnType]
+ _exec_once: bool
+ _exec_w_sync_once: bool
+
+ def _set_asyncio(self) -> None:
self._exec_once_mutex = AsyncAdaptedLock()
- def _memoized_attr__exec_once_mutex(self):
+ def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol:
return threading.Lock()
- def _exec_once_impl(self, retry_on_exception, *args, **kw):
+ def _exec_once_impl(
+ self, retry_on_exception: bool, *args: Any, **kw: Any
+ ) -> None:
with self._exec_once_mutex:
if not self._exec_once:
try:
@@ -285,14 +408,14 @@ class _CompoundListener(_InstanceLevelDispatch):
if not exception or not retry_on_exception:
self._exec_once = True
- def exec_once(self, *args, **kw):
+ def exec_once(self, *args: Any, **kw: Any) -> None:
"""Execute this event, but only if it has not been
executed already for this collection."""
if not self._exec_once:
self._exec_once_impl(False, *args, **kw)
- def exec_once_unless_exception(self, *args, **kw):
+ def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None:
"""Execute this event, but only if it has not been
executed already for this collection, or was called
by a previous exec_once_unless_exception call and
@@ -307,7 +430,7 @@ class _CompoundListener(_InstanceLevelDispatch):
if not self._exec_once:
self._exec_once_impl(True, *args, **kw)
- def _exec_w_sync_on_first_run(self, *args, **kw):
+ def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None:
"""Execute this event, and use a mutex if it has not been
executed already for this collection, or was called
by a previous _exec_w_sync_on_first_run call and
@@ -330,7 +453,7 @@ class _CompoundListener(_InstanceLevelDispatch):
else:
self(*args, **kw)
- def __call__(self, *args, **kw):
+ def __call__(self, *args: Any, **kw: Any) -> None:
"""Execute this event."""
for fn in self.parent_listeners:
@@ -338,19 +461,22 @@ class _CompoundListener(_InstanceLevelDispatch):
for fn in self.listeners:
fn(*args, **kw)
- def __len__(self):
+ def __contains__(self, item: Any) -> bool:
+ return item in self.parent_listeners or item in self.listeners
+
+ def __len__(self) -> int:
return len(self.parent_listeners) + len(self.listeners)
- def __iter__(self):
+ def __iter__(self) -> Iterator[_ListenerFnType]:
return chain(self.parent_listeners, self.listeners)
- def __bool__(self):
+ def __bool__(self) -> bool:
return bool(self.listeners or self.parent_listeners)
__nonzero__ = __bool__
-class _ListenerCollection(_CompoundListener):
+class _ListenerCollection(_CompoundListener[_ET]):
"""Instance-level attributes on instances of :class:`._Dispatch`.
Represents a collection of listeners.
@@ -369,7 +495,13 @@ class _ListenerCollection(_CompoundListener):
"__weakref__",
)
- def __init__(self, parent, target_cls):
+ parent_listeners: Collection[_ListenerFnType]
+ parent: _ClsLevelDispatch[_ET]
+ name: str
+ listeners: Deque[_ListenerFnType]
+ propagate: Set[_ListenerFnType]
+
+ def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]):
if target_cls not in parent._clslevel:
parent.update_subclass(target_cls)
self._exec_once = False
@@ -380,7 +512,7 @@ class _ListenerCollection(_CompoundListener):
self.listeners = collections.deque()
self.propagate = set()
- def for_modify(self, obj):
+ def for_modify(self, obj: _Dispatch[_ET]) -> _ListenerCollection[_ET]:
"""Return an event collection which can be modified.
For _ListenerCollection at the instance level of
@@ -389,10 +521,11 @@ class _ListenerCollection(_CompoundListener):
"""
return self
- def _update(self, other, only_propagate=True):
+ def _update(
+ self, other: _ListenerCollection[_ET], only_propagate: bool = True
+ ) -> None:
"""Populate from the listeners in another :class:`_Dispatch`
object."""
-
existing_listeners = self.listeners
existing_listener_set = set(existing_listeners)
self.propagate.update(other.propagate)
@@ -409,56 +542,75 @@ class _ListenerCollection(_CompoundListener):
to_associate = other.propagate.union(other_listeners)
registry._stored_in_collection_multi(self, other, to_associate)
- def insert(self, event_key, propagate):
+ def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
if event_key.prepend_to_list(self, self.listeners):
if propagate:
self.propagate.add(event_key._listen_fn)
- def append(self, event_key, propagate):
+ def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
if event_key.append_to_list(self, self.listeners):
if propagate:
self.propagate.add(event_key._listen_fn)
- def remove(self, event_key):
+ def remove(self, event_key: _EventKey[_ET]) -> None:
self.listeners.remove(event_key._listen_fn)
self.propagate.discard(event_key._listen_fn)
registry._removed_from_collection(event_key, self)
- def clear(self):
+ def clear(self) -> None:
registry._clear(self, self.listeners)
self.propagate.clear()
self.listeners.clear()
-class _JoinedListener(_CompoundListener):
- __slots__ = "parent", "name", "local", "parent_listeners"
+class _JoinedListener(_CompoundListener[_ET]):
+ __slots__ = "parent_dispatch", "name", "local", "parent_listeners"
+
+ parent_dispatch: _Dispatch[_ET]
+ name: str
+ local: _InstanceLevelDispatch[_ET]
+ parent_listeners: Collection[_ListenerFnType]
- def __init__(self, parent, name, local):
+ def __init__(
+ self,
+ parent_dispatch: _Dispatch[_ET],
+ name: str,
+ local: _EmptyListener[_ET],
+ ):
self._exec_once = False
- self.parent = parent
+ self.parent_dispatch = parent_dispatch
self.name = name
self.local = local
self.parent_listeners = self.local
- @property
- def listeners(self):
- return getattr(self.parent, self.name)
-
- def _adjust_fn_spec(self, fn, named):
+ if not typing.TYPE_CHECKING:
+ # first error, I don't really understand:
+ # Signature of "listeners" incompatible with
+ # supertype "_CompoundListener" [override]
+ # the name / return type are exactly the same
+ # second error is getattr_isn't typed, the cast() here
+ # adds too much method overhead
+ @property
+ def listeners(self) -> Collection[_ListenerFnType]:
+ return getattr(self.parent_dispatch, self.name)
+
+ def _adjust_fn_spec(
+ self, fn: _ListenerFnType, named: bool
+ ) -> _ListenerFnType:
return self.local._adjust_fn_spec(fn, named)
- def for_modify(self, obj):
+ def for_modify(self, obj: _JoinedDispatcher[_ET]) -> _JoinedListener[_ET]:
self.local = self.parent_listeners = self.local.for_modify(obj)
return self
- def insert(self, event_key, propagate):
+ def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
self.local.insert(event_key, propagate)
- def append(self, event_key, propagate):
+ def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
self.local.append(event_key, propagate)
- def remove(self, event_key):
+ def remove(self, event_key: _EventKey[_ET]) -> None:
self.local.remove(event_key)
- def clear(self):
+ def clear(self) -> None:
raise NotImplementedError()