diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/orm/session.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/test/orm/session.py b/test/orm/session.py index 09b9df05d..0282d28fd 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -911,6 +911,72 @@ class SessionTest(_fixtures.FixtureTest): assert log == ['after_begin'] @testing.resolve_artifact_names + def test_before_flush(self): + """test that the flush plan can be affected during before_flush()""" + + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(self, session, flush_context, objects): + for obj in list(session.new) + list(session.dirty): + if isinstance(obj, User): + session.add(User(name='another %s' % obj.name)) + for obj in list(session.deleted): + if isinstance(obj, User): + x = session.query(User).filter(User.name=='another %s' % obj.name).one() + session.delete(x) + + sess = create_session(extension = MyExt(), autoflush=True) + u = User(name='u1') + sess.add(u) + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='u1') + ] + ) + + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='u1') + ] + ) + + u.name='u2' + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='another u2'), + User(name='u2') + ] + ) + + sess.delete(u) + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + ] + ) + + @testing.resolve_artifact_names + def test_reentrant_flush(self): + + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(s, session, flush_context, objects): + session.flush() + + sess = create_session(extension=MyExt()) + sess.add(User(name='foo')) + self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush) + + @testing.resolve_artifact_names def test_pickled_update(self): mapper(User, users) sess1 = create_session() |
