summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-08-31 11:07:23 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-08-31 14:31:34 -0400
commitd3e0b8e750d864766148cdf1a658a601079eed46 (patch)
tree7b6ee55bbd18e6fa73b299f46b231abbae8780f5
parentec65def6bffa94d1c89ae5896e4d7e85f9abe84a (diff)
downloadsqlalchemy-d3e0b8e750d864766148cdf1a658a601079eed46.tar.gz
run update_subclass anytime we add new clslevel dispatch
Fixed event listening issue where event listeners added to a superclass would be lost if a subclass were created which then had its own listeners associated. The practical example is that of the :class:`.sessionmaker` class created after events have been associated with the :class:`_orm.Session` class. Fixes: #8467 Change-Id: I9bdba8769147e30110a09900d4a577e833ac3af9
-rw-r--r--doc/build/changelog/unreleased_14/8467.rst9
-rw-r--r--lib/sqlalchemy/event/attr.py53
-rw-r--r--lib/sqlalchemy/sql/annotation.py3
-rw-r--r--lib/sqlalchemy/util/langhelpers.py2
-rw-r--r--test/base/test_events.py29
-rw-r--r--test/orm/test_events.py29
6 files changed, 92 insertions, 33 deletions
diff --git a/doc/build/changelog/unreleased_14/8467.rst b/doc/build/changelog/unreleased_14/8467.rst
new file mode 100644
index 000000000..7626f50a3
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8467.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: bug, events, orm
+ :tickets: 8467
+
+ Fixed event listening issue where event listeners added to a superclass
+ would be lost if a subclass were created which then had its own listeners
+ associated. The practical example is that of the :class:`.sessionmaker`
+ class created after events have been associated with the
+ :class:`_orm.Session` class.
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
index f8d70a06a..21d0a2274 100644
--- a/lib/sqlalchemy/event/attr.py
+++ b/lib/sqlalchemy/event/attr.py
@@ -175,57 +175,48 @@ class _ClsLevelDispatch(RefCollection[_ET]):
return wrap_kw
- def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+ def _do_insert_or_append(
+ self, event_key: _EventKey[_ET], is_append: bool
+ ) -> None:
target = event_key.dispatch_target
assert isinstance(
target, type
), "Class-level Event targets must be classes."
if not getattr(target, "_sa_propagate_class_events", True):
raise exc.InvalidRequestError(
- "Can't assign an event directly to the %s class" % target
+ f"Can't assign an event directly to the {target} class"
)
- for cls in util.walk_subclasses(target):
- cls = cast(Type[_ET], cls)
- if cls is not target and cls not in self._clslevel:
- self.update_subclass(cls)
- else:
- if cls not in self._clslevel:
- self._assign_cls_collection(cls)
- self._clslevel[cls].appendleft(event_key._listen_fn)
- registry._stored_in_collection(event_key, self)
+ cls: Type[_ET]
- def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
- target = event_key.dispatch_target
- assert isinstance(
- target, type
- ), "Class-level Event targets must be classes."
- if not getattr(target, "_sa_propagate_class_events", True):
- raise exc.InvalidRequestError(
- "Can't assign an event directly to the %s class" % target
- )
for cls in util.walk_subclasses(target):
- cls = cast("Type[_ET]", cls)
if cls is not target and cls not in self._clslevel:
self.update_subclass(cls)
else:
if cls not in self._clslevel:
- self._assign_cls_collection(cls)
- self._clslevel[cls].append(event_key._listen_fn)
+ self.update_subclass(cls)
+ if is_append:
+ self._clslevel[cls].append(event_key._listen_fn)
+ else:
+ self._clslevel[cls].appendleft(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
- def _assign_cls_collection(self, target: Type[_ET]) -> None:
- if getattr(target, "_sa_propagate_class_events", True):
- self._clslevel[target] = collections.deque()
- else:
- self._clslevel[target] = _empty_collection()
+ def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+ self._do_insert_or_append(event_key, is_append=False)
+
+ def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+ self._do_insert_or_append(event_key, is_append=True)
def update_subclass(self, target: Type[_ET]) -> None:
if target not in self._clslevel:
- self._assign_cls_collection(target)
+ if getattr(target, "_sa_propagate_class_events", True):
+ self._clslevel[target] = collections.deque()
+ else:
+ self._clslevel[target] = _empty_collection()
+
clslevel = self._clslevel[target]
+ cls: Type[_ET]
for cls in target.__mro__[1:]:
- cls = cast("Type[_ET]", cls)
if cls in self._clslevel:
clslevel.extend(
[fn for fn in self._clslevel[cls] if fn not in clslevel]
@@ -233,8 +224,8 @@ class _ClsLevelDispatch(RefCollection[_ET]):
def remove(self, event_key: _EventKey[_ET]) -> None:
target = event_key.dispatch_target
+ cls: Type[_ET]
for cls in util.walk_subclasses(target):
- cls = cast("Type[_ET]", cls)
if cls in self._clslevel:
self._clslevel[cls].remove(event_key._listen_fn)
registry._removed_from_collection(event_key, self)
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 95dc1d4d4..86b2952cb 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -559,7 +559,8 @@ def _new_annotation_type(
def _prepare_annotations(
- target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated]
+ target_hierarchy: Type[SupportsWrappingAnnotations],
+ base_cls: Type[Annotated],
) -> None:
for cls in util.walk_subclasses(target_hierarchy):
_new_annotation_type(cls, base_cls)
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 66354f6b6..70c9bba9f 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -152,7 +152,7 @@ class safe_reraise:
raise value.with_traceback(traceback)
-def walk_subclasses(cls: type) -> Iterator[type]:
+def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
seen: Set[Any] = set()
stack = [cls]
diff --git a/test/base/test_events.py b/test/base/test_events.py
index 7e978d23b..67933a5fe 100644
--- a/test/base/test_events.py
+++ b/test/base/test_events.py
@@ -677,6 +677,35 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
eq_(len(SubTarget().dispatch.event_one), 2)
+ @testing.combinations(True, False, argnames="m1")
+ @testing.combinations(True, False, argnames="m2")
+ @testing.combinations(True, False, argnames="m3")
+ @testing.combinations(True, False, argnames="use_insert")
+ def test_subclass_gen_after_clslisten(self, m1, m2, m3, use_insert):
+ """test #8467"""
+ m1 = Mock() if m1 else None
+ m2 = Mock() if m2 else None
+ m3 = Mock() if m3 else None
+
+ if m1:
+ event.listen(self.TargetOne, "event_one", m1, insert=use_insert)
+
+ class SubTarget(self.TargetOne):
+ pass
+
+ if m2:
+ event.listen(SubTarget, "event_one", m2, insert=use_insert)
+
+ if m3:
+ event.listen(self.TargetOne, "event_one", m3, insert=use_insert)
+
+ st = SubTarget()
+ st.dispatch.event_one()
+
+ for m in m1, m2, m3:
+ if m:
+ eq_(m.mock_calls, [call()])
+
def test_lis_multisub_lis(self):
@event.listens_for(self.TargetOne, "event_one")
def handler1(x, y):
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index 7e1b29cb1..24870e20f 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -2195,6 +2195,35 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
s = fixture_session()
assert my_listener in s.dispatch.before_flush
+ @testing.combinations(True, False, argnames="m1")
+ @testing.combinations(True, False, argnames="m2")
+ @testing.combinations(True, False, argnames="m3")
+ @testing.combinations(True, False, argnames="use_insert")
+ def test_sessionmaker_gen_after_session_listen(
+ self, m1, m2, m3, use_insert
+ ):
+ m1 = Mock() if m1 else None
+ m2 = Mock() if m2 else None
+ m3 = Mock() if m3 else None
+
+ if m1:
+ event.listen(Session, "before_flush", m1, insert=use_insert)
+
+ factory = sessionmaker()
+
+ if m2:
+ event.listen(factory, "before_flush", m2, insert=use_insert)
+
+ if m3:
+ event.listen(factory, "before_flush", m3, insert=use_insert)
+
+ st = factory()
+ st.dispatch.before_flush()
+
+ for m in m1, m2, m3:
+ if m:
+ eq_(m.mock_calls, [call()])
+
def test_sessionmaker_listen(self):
"""test that listen can be applied to individual
scoped_session() classes."""