summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-04-02 13:29:11 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2011-04-02 13:29:11 -0400
commit0b52a5ae744acad5c58f382049ecdbd954fc7ee6 (patch)
tree5988e2a0790a21ef5c4841374b741cddac2a5fcb
parenteb4a843318b4fa76d238c43a293d64af64ca1148 (diff)
downloadsqlalchemy-0b52a5ae744acad5c58f382049ecdbd954fc7ee6.tar.gz
- Added @event.listens_for() decorator, given
target + event name, applies the decorated function as a listener. [ticket:2106] - remove usage of globals from test.base.test_events
-rw-r--r--CHANGES5
-rw-r--r--doc/build/core/event.rst2
-rw-r--r--lib/sqlalchemy/event.py35
-rw-r--r--test/base/test_events.py112
4 files changed, 110 insertions, 44 deletions
diff --git a/CHANGES b/CHANGES
index 3765a6ed9..b4127beae 100644
--- a/CHANGES
+++ b/CHANGES
@@ -46,6 +46,11 @@ CHANGES
collection of Sequence objects, list
of schema names. [ticket:2104]
+-event
+ - Added @event.listens_for() decorator, given
+ target + event name, applies the decorated
+ function as a listener. [ticket:2106]
+
- pool
- AssertionPool now stores the traceback indicating
where the currently checked out connection was
diff --git a/doc/build/core/event.rst b/doc/build/core/event.rst
index 86cc7d968..68d4802bc 100644
--- a/doc/build/core/event.rst
+++ b/doc/build/core/event.rst
@@ -102,3 +102,5 @@ API Reference
.. autofunction:: sqlalchemy.event.listen
+.. autofunction:: sqlalchemy.event.listens_for
+
diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py
index b2e5cd00f..4be227c51 100644
--- a/lib/sqlalchemy/event.py
+++ b/lib/sqlalchemy/event.py
@@ -13,6 +13,21 @@ NO_RETVAL = util.symbol('NO_RETVAL')
def listen(target, identifier, fn, *args, **kw):
"""Register a listener function for the given target.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.schema import UniqueConstraint
+
+ def unique_constraint_name(const, table):
+ const.name = "uq_%s_%s" % (
+ table.name,
+ list(const.columns)[0].name
+ )
+ event.listen(
+ UniqueConstraint,
+ "after_parent_attach",
+ unique_constraint_name)
"""
@@ -24,6 +39,26 @@ def listen(target, identifier, fn, *args, **kw):
raise exc.InvalidRequestError("No such event '%s' for target '%s'" %
(identifier,target))
+def listens_for(target, identifier, *args, **kw):
+ """Decorate a function as a listener for the given target + identifier.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.schema import UniqueConstraint
+
+ @event.listens_for(UniqueConstraint, "after_parent_attach")
+ def unique_constraint_name(const, table):
+ const.name = "uq_%s_%s" % (
+ table.name,
+ list(const.columns)[0].name
+ )
+ """
+ def decorate(fn):
+ listen(target, identifier, fn, *args, **kw)
+ return fn
+ return decorate
+
def remove(target, identifier, fn):
"""Remove an event listener.
diff --git a/test/base/test_events.py b/test/base/test_events.py
index 96cda7cc9..94d3dad85 100644
--- a/test/base/test_events.py
+++ b/test/base/test_events.py
@@ -8,8 +8,6 @@ class TestEvents(fixtures.TestBase):
"""Test class- and instance-level event registration."""
def setUp(self):
- global Target
-
assert 'event_one' not in event._registrars
assert 'event_two' not in event._registrars
@@ -20,31 +18,35 @@ class TestEvents(fixtures.TestBase):
def event_two(self, x):
pass
+ def event_three(self, x):
+ pass
+
class Target(object):
dispatch = event.dispatcher(TargetEvents)
+ self.Target = Target
def tearDown(self):
- event._remove_dispatcher(Target.__dict__['dispatch'].events)
+ event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
def test_register_class(self):
def listen(x, y):
pass
- event.listen(Target, "event_one", listen)
+ event.listen(self.Target, "event_one", listen)
- eq_(len(Target().dispatch.event_one), 1)
- eq_(len(Target().dispatch.event_two), 0)
+ eq_(len(self.Target().dispatch.event_one), 1)
+ eq_(len(self.Target().dispatch.event_two), 0)
def test_register_instance(self):
def listen(x, y):
pass
- t1 = Target()
+ t1 = self.Target()
event.listen(t1, "event_one", listen)
- eq_(len(Target().dispatch.event_one), 0)
+ eq_(len(self.Target().dispatch.event_one), 0)
eq_(len(t1.dispatch.event_one), 1)
- eq_(len(Target().dispatch.event_two), 0)
+ eq_(len(self.Target().dispatch.event_two), 0)
eq_(len(t1.dispatch.event_two), 0)
def test_register_class_instance(self):
@@ -54,21 +56,21 @@ class TestEvents(fixtures.TestBase):
def listen_two(x, y):
pass
- event.listen(Target, "event_one", listen_one)
+ event.listen(self.Target, "event_one", listen_one)
- t1 = Target()
+ t1 = self.Target()
event.listen(t1, "event_one", listen_two)
- eq_(len(Target().dispatch.event_one), 1)
+ eq_(len(self.Target().dispatch.event_one), 1)
eq_(len(t1.dispatch.event_one), 2)
- eq_(len(Target().dispatch.event_two), 0)
+ eq_(len(self.Target().dispatch.event_two), 0)
eq_(len(t1.dispatch.event_two), 0)
def listen_three(x, y):
pass
- event.listen(Target, "event_one", listen_three)
- eq_(len(Target().dispatch.event_one), 2)
+ event.listen(self.Target, "event_one", listen_three)
+ eq_(len(self.Target().dispatch.event_one), 2)
eq_(len(t1.dispatch.event_one), 3)
def test_append_vs_insert(self):
@@ -81,21 +83,44 @@ class TestEvents(fixtures.TestBase):
def listen_three(x, y):
pass
- event.listen(Target, "event_one", listen_one)
- event.listen(Target, "event_one", listen_two)
- event.listen(Target, "event_one", listen_three, insert=True)
+ event.listen(self.Target, "event_one", listen_one)
+ event.listen(self.Target, "event_one", listen_two)
+ event.listen(self.Target, "event_one", listen_three, insert=True)
eq_(
- list(Target().dispatch.event_one),
+ list(self.Target().dispatch.event_one),
[listen_three, listen_one, listen_two]
)
+ def test_decorator(self):
+ @event.listens_for(self.Target, "event_one")
+ def listen_one(x, y):
+ pass
+
+ @event.listens_for(self.Target, "event_two")
+ @event.listens_for(self.Target, "event_three")
+ def listen_two(x, y):
+ pass
+
+ eq_(
+ list(self.Target().dispatch.event_one),
+ [listen_one]
+ )
+
+ eq_(
+ list(self.Target().dispatch.event_two),
+ [listen_two]
+ )
+
+ eq_(
+ list(self.Target().dispatch.event_three),
+ [listen_two]
+ )
+
class TestAcceptTargets(fixtures.TestBase):
"""Test default target acceptance."""
def setUp(self):
- global TargetOne, TargetTwo
-
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
@@ -109,10 +134,12 @@ class TestAcceptTargets(fixtures.TestBase):
class TargetTwo(object):
dispatch = event.dispatcher(TargetEventsTwo)
+ self.TargetOne = TargetOne
+ self.TargetTwo = TargetTwo
def tearDown(self):
- event._remove_dispatcher(TargetOne.__dict__['dispatch'].events)
- event._remove_dispatcher(TargetTwo.__dict__['dispatch'].events)
+ event._remove_dispatcher(self.TargetOne.__dict__['dispatch'].events)
+ event._remove_dispatcher(self.TargetTwo.__dict__['dispatch'].events)
def test_target_accept(self):
"""Test that events of the same name are routed to the correct
@@ -132,21 +159,21 @@ class TestAcceptTargets(fixtures.TestBase):
def listen_four(x, y):
pass
- event.listen(TargetOne, "event_one", listen_one)
- event.listen(TargetTwo, "event_one", listen_two)
+ event.listen(self.TargetOne, "event_one", listen_one)
+ event.listen(self.TargetTwo, "event_one", listen_two)
eq_(
- list(TargetOne().dispatch.event_one),
+ list(self.TargetOne().dispatch.event_one),
[listen_one]
)
eq_(
- list(TargetTwo().dispatch.event_one),
+ list(self.TargetTwo().dispatch.event_one),
[listen_two]
)
- t1 = TargetOne()
- t2 = TargetTwo()
+ t1 = self.TargetOne()
+ t2 = self.TargetTwo()
event.listen(t1, "event_one", listen_three)
event.listen(t2, "event_one", listen_four)
@@ -165,8 +192,6 @@ class TestCustomTargets(fixtures.TestBase):
"""Test custom target acceptance."""
def setUp(self):
- global Target
-
class TargetEvents(event.Events):
@classmethod
def _accept_with(cls, target):
@@ -180,9 +205,10 @@ class TestCustomTargets(fixtures.TestBase):
class Target(object):
dispatch = event.dispatcher(TargetEvents)
+ self.Target = Target
def tearDown(self):
- event._remove_dispatcher(Target.__dict__['dispatch'].events)
+ event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
def test_indirect(self):
def listen(x, y):
@@ -191,22 +217,20 @@ class TestCustomTargets(fixtures.TestBase):
event.listen("one", "event_one", listen)
eq_(
- list(Target().dispatch.event_one),
+ list(self.Target().dispatch.event_one),
[listen]
)
assert_raises(
exc.InvalidRequestError,
event.listen,
- listen, "event_one", Target
+ listen, "event_one", self.Target
)
class TestListenOverride(fixtures.TestBase):
"""Test custom listen functions which change the listener function signature."""
def setUp(self):
- global Target
-
class TargetEvents(event.Events):
@classmethod
def _listen(cls, target, identifier, fn, add=False):
@@ -223,9 +247,10 @@ class TestListenOverride(fixtures.TestBase):
class Target(object):
dispatch = event.dispatcher(TargetEvents)
+ self.Target = Target
def tearDown(self):
- event._remove_dispatcher(Target.__dict__['dispatch'].events)
+ event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
def test_listen_override(self):
result = []
@@ -235,10 +260,10 @@ class TestListenOverride(fixtures.TestBase):
def listen_two(x, y):
result.append((x, y))
- event.listen(Target, "event_one", listen_one, add=True)
- event.listen(Target, "event_one", listen_two)
+ event.listen(self.Target, "event_one", listen_one, add=True)
+ event.listen(self.Target, "event_one", listen_two)
- t1 = Target()
+ t1 = self.Target()
t1.dispatch.event_one(5, 7)
t1.dispatch.event_one(10, 5)
@@ -250,8 +275,6 @@ class TestListenOverride(fixtures.TestBase):
class TestPropagate(fixtures.TestBase):
def setUp(self):
- global Target
-
class TargetEvents(event.Events):
def event_one(self, arg):
pass
@@ -261,6 +284,7 @@ class TestPropagate(fixtures.TestBase):
class Target(object):
dispatch = event.dispatcher(TargetEvents)
+ self.Target = Target
def test_propagate(self):
@@ -271,12 +295,12 @@ class TestPropagate(fixtures.TestBase):
def listen_two(target, arg):
result.append((target, arg))
- t1 = Target()
+ t1 = self.Target()
event.listen(t1, "event_one", listen_one, propagate=True)
event.listen(t1, "event_two", listen_two)
- t2 = Target()
+ t2 = self.Target()
t2.dispatch._update(t1.dispatch)