from testbase import PersistTest, AssertMixin
import unittest
import sys
import os
from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy.orm.session import object_session, Session
from sqlalchemy import *
import testbase

# Schema shared by the tests in this module: a single ``users`` table.
metadata = MetaData()
users = Table(
    'users',
    metadata,
    Column(
        'user_id',
        Integer,
        Sequence('user_id_seq', optional=True),
        primary_key=True
    ),
    Column('user_name', String(40)),
    mysql_engine='innodb'
)


class SessionContextTest(AssertMixin):
    """Tests how a SessionContext associates sessions with new objects."""

    def setUp(self):
        # Calls clear_mappers() so the mapper set up by a test method is
        # created after any previously configured mappers were cleared.
        clear_mappers()

    def do_test(self, class_, context):
        """Check session assignment on object creation.

        ``class_`` is instantiated with no arguments and ``context`` is the
        object whose ``current`` attribute is read, replaced and deleted.
        The checks are, in order:

        * a new instance reports ``context.current`` as its session;
        * assigning a fresh ``Session()`` to ``context.current`` leaves the
          existing instance on the session it was created with;
        * after ``del context.current``, ``context.current`` no longer
          compares equal to the session that was just assigned, and the
          existing instance is still on its original session;
        * an instance created afterwards reports the then-current session.
        """
        first_instance = class_()
        assert context.current == object_session(first_instance)

        # keep a reference so the old session doesn't get gc'd
        original_session = context.current

        context.current = Session()
        assert context.current != object_session(first_instance)
        assert original_session == object_session(first_instance)

        replacement_session = context.current
        del context.current
        assert context.current != replacement_session
        assert original_session == object_session(first_instance)

        second_instance = class_()
        assert context.current == object_session(second_instance)

    def test_mapper_extension(self):
        # Runs do_test against a class mapped to ``users`` with the
        # context's mapper_extension passed as the mapper's extension.
        context = SessionContext(Session)

        class User(object):
            pass

        User.mapper = mapper(
            User, users, extension=context.mapper_extension
        )
        self.do_test(User, context)


if __name__ == "__main__":
    testbase.main()