summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/orm/session.py66
1 files changed, 66 insertions, 0 deletions
diff --git a/test/orm/session.py b/test/orm/session.py
index 09b9df05d..0282d28fd 100644
--- a/test/orm/session.py
+++ b/test/orm/session.py
@@ -911,6 +911,72 @@ class SessionTest(_fixtures.FixtureTest):
assert log == ['after_begin']
@testing.resolve_artifact_names
+ def test_before_flush(self):
+ """test that the flush plan can be affected during before_flush()"""
+
+ mapper(User, users)
+
+ class MyExt(sa.orm.session.SessionExtension):
+ def before_flush(self, session, flush_context, objects):
+ for obj in list(session.new) + list(session.dirty):
+ if isinstance(obj, User):
+ session.add(User(name='another %s' % obj.name))
+ for obj in list(session.deleted):
+ if isinstance(obj, User):
+ x = session.query(User).filter(User.name=='another %s' % obj.name).one()
+ session.delete(x)
+
+ sess = create_session(extension = MyExt(), autoflush=True)
+ u = User(name='u1')
+ sess.add(u)
+ sess.flush()
+ self.assertEquals(sess.query(User).order_by(User.name).all(),
+ [
+ User(name='another u1'),
+ User(name='u1')
+ ]
+ )
+
+ sess.flush()
+ self.assertEquals(sess.query(User).order_by(User.name).all(),
+ [
+ User(name='another u1'),
+ User(name='u1')
+ ]
+ )
+
+ u.name='u2'
+ sess.flush()
+ self.assertEquals(sess.query(User).order_by(User.name).all(),
+ [
+ User(name='another u1'),
+ User(name='another u2'),
+ User(name='u2')
+ ]
+ )
+
+ sess.delete(u)
+ sess.flush()
+ self.assertEquals(sess.query(User).order_by(User.name).all(),
+ [
+ User(name='another u1'),
+ ]
+ )
+
+ @testing.resolve_artifact_names
+ def test_reentrant_flush(self):
+
+ mapper(User, users)
+
+ class MyExt(sa.orm.session.SessionExtension):
+ def before_flush(s, session, flush_context, objects):
+ session.flush()
+
+ sess = create_session(extension=MyExt())
+ sess.add(User(name='foo'))
+ self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush)
+
+ @testing.resolve_artifact_names
def test_pickled_update(self):
mapper(User, users)
sess1 = create_session()