diff options
Diffstat (limited to 'lib/sqlalchemy/event/attr.py')
| -rw-r--r-- | lib/sqlalchemy/event/attr.py | 53 |
1 files changed, 22 insertions, 31 deletions
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index f8d70a06a..21d0a2274 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -175,57 +175,48 @@ class _ClsLevelDispatch(RefCollection[_ET]): return wrap_kw - def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + def _do_insert_or_append( + self, event_key: _EventKey[_ET], is_append: bool + ) -> None: target = event_key.dispatch_target assert isinstance( target, type ), "Class-level Event targets must be classes." if not getattr(target, "_sa_propagate_class_events", True): raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target + f"Can't assign an event directly to the {target} class" ) - for cls in util.walk_subclasses(target): - cls = cast(Type[_ET], cls) - if cls is not target and cls not in self._clslevel: - self.update_subclass(cls) - else: - if cls not in self._clslevel: - self._assign_cls_collection(cls) - self._clslevel[cls].appendleft(event_key._listen_fn) - registry._stored_in_collection(event_key, self) + cls: Type[_ET] - def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: - target = event_key.dispatch_target - assert isinstance( - target, type - ), "Class-level Event targets must be classes." - if not getattr(target, "_sa_propagate_class_events", True): - raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target - ) for cls in util.walk_subclasses(target): - cls = cast("Type[_ET]", cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: if cls not in self._clslevel: - self._assign_cls_collection(cls) - self._clslevel[cls].append(event_key._listen_fn) + self.update_subclass(cls) + if is_append: + self._clslevel[cls].append(event_key._listen_fn) + else: + self._clslevel[cls].appendleft(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def _assign_cls_collection(self, target: Type[_ET]) -> None: - if getattr(target, "_sa_propagate_class_events", True): - self._clslevel[target] = collections.deque() - else: - self._clslevel[target] = _empty_collection() + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=False) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=True) def update_subclass(self, target: Type[_ET]) -> None: if target not in self._clslevel: - self._assign_cls_collection(target) + if getattr(target, "_sa_propagate_class_events", True): + self._clslevel[target] = collections.deque() + else: + self._clslevel[target] = _empty_collection() + clslevel = self._clslevel[target] + cls: Type[_ET] for cls in target.__mro__[1:]: - cls = cast("Type[_ET]", cls) if cls in self._clslevel: clslevel.extend( [fn for fn in self._clslevel[cls] if fn not in clslevel] @@ -233,8 +224,8 @@ class _ClsLevelDispatch(RefCollection[_ET]): def remove(self, event_key: _EventKey[_ET]) -> None: target = event_key.dispatch_target + cls: Type[_ET] for cls in util.walk_subclasses(target): - cls = cast("Type[_ET]", cls) if cls in self._clslevel: self._clslevel[cls].remove(event_key._listen_fn) registry._removed_from_collection(event_key, self) |
