diff options
Diffstat (limited to 'test')
37 files changed, 2159 insertions, 1731 deletions
diff --git a/test/activemapper.py b/test/activemapper.py index 7edb41488..6a8b0904e 100644 --- a/test/activemapper.py +++ b/test/activemapper.py @@ -1,80 +1,84 @@ -from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one -from sqlalchemy.ext import activemapper -from sqlalchemy import objectstore, global_connect -from sqlalchemy import and_, or_ -from sqlalchemy import ForeignKey, String, Integer, DateTime -from datetime import datetime +from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, objectstore +from sqlalchemy import and_, or_, clear_mappers +from sqlalchemy import ForeignKey, String, Integer, DateTime +from datetime import datetime import unittest +import sqlalchemy.ext.activemapper as activemapper -# -# application-level model objects -# +import testbase -class Person(ActiveMapper): - class mapping: - id = column(Integer, primary_key=True) - full_name = column(String) - first_name = column(String) - middle_name = column(String) - last_name = column(String) - birth_date = column(DateTime) - ssn = column(String) - gender = column(String) - home_phone = column(String) - cell_phone = column(String) - work_phone = column(String) - prefs_id = column(Integer, foreign_key=ForeignKey('preferences.id')) - addresses = one_to_many('Address', colname='person_id', backref='person') - preferences = one_to_one('Preferences', colname='pref_id', backref='person') - - def __str__(self): - s = '%s\n' % self.full_name - s += ' * birthdate: %s\n' % (self.birth_date or 'not provided') - s += ' * fave color: %s\n' % (self.preferences.favorite_color or 'Unknown') - s += ' * personality: %s\n' % (self.preferences.personality_type or 'Unknown') - - for address in self.addresses: - s += ' * address: %s\n' % address.address_1 - s += ' %s, %s %s\n' % (address.city, address.state, address.postal_code) - - return s +class testcase(testbase.PersistTest): + def setUpAll(self): + global Person, Preferences, Address + + class Person(ActiveMapper): + class mapping: + id = column(Integer, primary_key=True) + full_name = column(String) + first_name = column(String) + middle_name = column(String) + last_name = column(String) + birth_date = column(DateTime) + ssn = column(String) + gender = column(String) + home_phone = column(String) + cell_phone = column(String) + work_phone = column(String) + prefs_id = column(Integer, foreign_key=ForeignKey('preferences.id')) + addresses = one_to_many('Address', colname='person_id', backref='person') + preferences = one_to_one('Preferences', colname='pref_id', backref='person') + def __str__(self): + s = '%s\n' % self.full_name + s += ' * birthdate: %s\n' % (self.birth_date or 'not provided') + s += ' * fave color: %s\n' % (self.preferences.favorite_color or 'Unknown') + s += ' * personality: %s\n' % (self.preferences.personality_type or 'Unknown') -class Preferences(ActiveMapper): - class mapping: - __table__ = 'preferences' - id = column(Integer, primary_key=True) - favorite_color = column(String) - personality_type = column(String) + for address in self.addresses: + s += ' * address: %s\n' % address.address_1 + s += ' %s, %s %s\n' % (address.city, address.state, address.postal_code) + return s -class Address(ActiveMapper): - class mapping: - id = column(Integer, primary_key=True) - type = column(String) - address_1 = column(String) - city = column(String) - state = column(String) - postal_code = column(String) - person_id = column(Integer, foreign_key=ForeignKey('person.id')) + class Preferences(ActiveMapper): + class mapping: + __table__ = 'preferences' + id = column(Integer, primary_key=True) + favorite_color = column(String) + personality_type = column(String) + class Address(ActiveMapper): + class mapping: + id = column(Integer, primary_key=True) + type = column(String) + address_1 = column(String) + city = column(String) + state = column(String) + postal_code = column(String) + person_id = column(Integer, foreign_key=ForeignKey('person.id')) + activemapper.metadata.connect(testbase.db) + activemapper.create_tables() -class testcase(unittest.TestCase): - + def tearDownAll(self): + clear_mappers() + activemapper.drop_tables() + def tearDown(self): - people = Person.select() - for person in people: person.delete() + for t in activemapper.metadata.table_iterator(reverse=True): + t.delete().execute() + #people = Person.select() + #for person in people: person.delete() - addresses = Address.select() - for address in addresses: address.delete() + #addresses = Address.select() + #for address in addresses: address.delete() - preferences = Preferences.select() - for preference in preferences: preference.delete() + #preferences = Preferences.select() + #for preference in preferences: preference.delete() - objectstore.commit() - objectstore.clear() + #objectstore.flush() + #objectstore.clear() def create_person_one(self): # create a person @@ -130,12 +134,8 @@ class testcase(unittest.TestCase): def test_create(self): - global_connect('sqlite:///', echo=False) - activemapper.create_tables() - p1 = self.create_person_one() - - objectstore.commit() + objectstore.flush() objectstore.clear() results = Person.select() @@ -151,14 +151,14 @@ class testcase(unittest.TestCase): def test_delete(self): p1 = self.create_person_one() - objectstore.commit() + objectstore.flush() objectstore.clear() results = Person.select() self.assertEquals(len(results), 1) results[0].delete() - objectstore.commit() + objectstore.flush() objectstore.clear() results = Person.select() @@ -169,7 +169,7 @@ class testcase(unittest.TestCase): p1 = self.create_person_one() p2 = self.create_person_two() - objectstore.commit() + objectstore.flush() objectstore.clear() # select and make sure we get back two results @@ -200,29 +200,38 @@ class testcase(unittest.TestCase): # FIXME: I don't know why, but it seems that my backwards relationship # on preferences still ends up being a list even though I pass # in uselist=False... + # FIXED: the backref is a new PropertyLoader which needs its own "uselist". + # uses a function which I dont think existed when you first wrote ActiveMapper. p1 = self.create_person_one() self.assertEquals(p1.preferences.person, p1) p1.delete() - objectstore.commit() + objectstore.flush() objectstore.clear() def test_select_by(self): # FIXME: either I don't understand select_by, or it doesn't work. + # FIXED (as good as we can for now): yup....everyone thinks it works that way....it only + # generates joins for keyword arguments, not ColumnClause args. would need a new layer of + # "MapperClause" objects to use properties in expressions. (MB) p1 = self.create_person_one() p2 = self.create_person_two() - objectstore.commit() + objectstore.flush() objectstore.clear() - results = Person.select_by( - Address.c.postal_code.like('30075') + results = Person.select( + Address.c.postal_code.like('30075') & + Person.join_to('addresses') ) self.assertEquals(len(results), 1) if __name__ == '__main__': + # go ahead and setup the database connection, and create the tables + + # launch the unit tests unittest.main()
\ No newline at end of file diff --git a/test/alltests.py b/test/alltests.py index 3595edd7e..c1662bce7 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -1,20 +1,16 @@ import testbase import unittest -testbase.echo = False - -#test - def suite(): modules_to_test = ( # core utilities - 'historyarray', + 'historyarray', 'attributes', 'dependency', - + # connectivity, execution 'pool', - 'engine', + 'transaction', # schema/tables 'reflection', @@ -40,7 +36,9 @@ def suite(): 'eagertest2', # ORM persistence + 'sessioncontext', 'objectstore', + 'cascade', 'relationships', # cyclical ORM persistence @@ -51,9 +49,11 @@ def suite(): 'manytomany', 'onetoone', 'inheritance', + 'polymorph', # extensions 'proxy_engine', + 'activemapper' #'wsgi_test', ) @@ -62,8 +62,6 @@ def suite(): alltests.addTest(unittest.findTestCases(module, suiteClass=None)) return alltests -import sys -sys.stdout = sys.stderr if __name__ == '__main__': testbase.runTests(suite()) diff --git a/test/attributes.py b/test/attributes.py index 126f50456..bff864fa6 100644 --- a/test/attributes.py +++ b/test/attributes.py @@ -45,8 +45,7 @@ class AttributesTest(PersistTest): manager.register_attribute(MyTest, 'email_address', uselist = False) x = MyTest() x.user_id=7 - s = pickle.dumps(x) - y = pickle.loads(s) + pickle.dumps(x) def testlist(self): class User(object):pass diff --git a/test/cascade.py b/test/cascade.py new file mode 100644 index 000000000..4a997d67f --- /dev/null +++ b/test/cascade.py @@ -0,0 +1,173 @@ +import testbase, tables +import unittest, sys, datetime + +from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy import * + +class O2MCascadeTest(testbase.AssertMixin): + def tearDown(self): + ctx.current.clear() + tables.delete() + + def tearDownAll(self): + clear_mappers() + tables.drop() + + def setUpAll(self): + global ctx, data + ctx = SessionContext(lambda: create_session(echo_uow=True)) + tables.create() + mapper(tables.User, tables.users, properties = dict( + address = relation(mapper(tables.Address, tables.addresses), lazy = False, uselist = False, private = True), + orders = relation( + mapper(tables.Order, tables.orders, properties = dict ( + items = relation(mapper(tables.Item, tables.orderitems), lazy = False, uselist =True, private = True) + )), + lazy = True, uselist = True, private = True) + )) + + def setUp(self): + global data + data = [tables.User, + {'user_name' : 'ed', + 'address' : (tables.Address, {'email_address' : 'foo@bar.com'}), + 'orders' : (tables.Order, [ + {'description' : 'eds 1st order', 'items' : (tables.Item, [{'item_name' : 'eds o1 item'}, {'item_name' : 'eds other o1 item'}])}, + {'description' : 'eds 2nd order', 'items' : (tables.Item, [{'item_name' : 'eds o2 item'}, {'item_name' : 'eds other o2 item'}])} + ]) + }, + {'user_name' : 'jack', + 'address' : (tables.Address, {'email_address' : 'jack@jack.com'}), + 'orders' : (tables.Order, [ + {'description' : 'jacks 1st order', 'items' : (tables.Item, [{'item_name' : 'im a lumberjack'}, {'item_name' : 'and im ok'}])} + ]) + }, + {'user_name' : 'foo', + 'address' : (tables.Address, {'email_address': 'hi@lala.com'}), + 'orders' : (tables.Order, [ + {'description' : 'foo order', 'items' : (tables.Item, [])}, + {'description' : 'foo order 2', 'items' : (tables.Item, [{'item_name' : 'hi'}])}, + {'description' : 'foo order three', 'items' : (tables.Item, [{'item_name' : 'there'}])} + ]) + } + ] + + for elem in data[1:]: + u = tables.User() + ctx.current.save(u) + u.user_name = elem['user_name'] + u.address = tables.Address() + u.address.email_address = elem['address'][1]['email_address'] + u.orders = [] + for order in elem['orders'][1]: + o = tables.Order() + o.isopen = None + o.description = order['description'] + u.orders.append(o) + o.items = [] + for item in order['items'][1]: + i = tables.Item() + i.item_name = item['item_name'] + o.items.append(i) + + ctx.current.flush() + ctx.current.clear() + + + def testdelete(self): + l = ctx.current.query(tables.User).select() + for u in l: + self.echo( repr(u.orders)) + self.assert_result(l, data[0], *data[1:]) + + self.echo("\n\n\n") + ids = (l[0].user_id, l[2].user_id) + ctx.current.delete(l[0]) + ctx.current.delete(l[2]) + + ctx.current.flush() + self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 0) + self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0) + self.assert_(tables.addresses.count(tables.addresses.c.user_id.in_(*ids)).scalar() == 0) + self.assert_(tables.users.count(tables.users.c.user_id.in_(*ids)).scalar() == 0) + + + def testorphan(self): + l = ctx.current.query(tables.User).select() + jack = l[1] + jack.orders[:] = [] + + ids = [jack.user_id] + self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 1) + self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 2) + + ctx.current.flush() + + self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 0) + self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0) + + +class M2OCascadeTest(testbase.AssertMixin): + def tearDown(self): + ctx.current.clear() + for t in metadata.table_iterator(reverse=True): + t.delete().execute() + + def tearDownAll(self): + clear_mappers() + metadata.drop_all() + + def setUpAll(self): + global ctx, data, metadata, User, Pref + ctx = SessionContext(create_session) + metadata = BoundMetaData(testbase.db) + prefs = Table('prefs', metadata, + Column('prefs_id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True), + Column('prefs_data', String(40))) + + users = Table('users', metadata, + Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('user_name', String(40)), + Column('pref_id', Integer, ForeignKey('prefs.prefs_id')) + ) + class User(object): + pass + class Pref(object): + pass + metadata.create_all() + mapper(User, users, properties = dict( + pref = relation(mapper(Pref, prefs), lazy=False, cascade="all, delete-orphan") + )) + + def setUp(self): + global data + data = [User, + {'user_name' : 'ed', + 'pref' : (Pref, {'prefs_data' : 'pref 1'}), + }, + {'user_name' : 'jack', + 'pref' : (Pref, {'prefs_data' : 'pref 2'}), + }, + {'user_name' : 'foo', + 'pref' : (Pref, {'prefs_data' : 'pref 3'}), + } + ] + + for elem in data[1:]: + u = User() + ctx.current.save(u) + u.user_name = elem['user_name'] + u.pref = Pref() + u.pref.prefs_data = elem['pref'][1]['prefs_data'] + + ctx.current.flush() + ctx.current.clear() + + def testorphan(self): + l = ctx.current.query(User).select() + jack = l[1] + jack.pref = None + ctx.current.flush() + +if __name__ == "__main__": + testbase.main() diff --git a/test/cycles.py b/test/cycles.py index 02d34bbb0..fb051f47a 100644 --- a/test/cycles.py +++ b/test/cycles.py @@ -22,26 +22,22 @@ class Tester(object): class SelfReferentialTest(AssertMixin): """tests a self-referential mapper, with an additional list of child objects.""" def setUpAll(self): - testbase.db.tables.clear() - global t1 - global t2 - t1 = Table('t1', testbase.db, + global t1, t2, metadata + metadata = BoundMetaData(testbase.db) + t1 = Table('t1', metadata, Column('c1', Integer, primary_key=True), Column('parent_c1', Integer, ForeignKey('t1.c1')), Column('data', String(20)) ) - t2 = Table('t2', testbase.db, + t2 = Table('t2', metadata, Column('c1', Integer, primary_key=True), Column('c1id', Integer, ForeignKey('t1.c1')), Column('data', String(20)) ) - t1.create() - t2.create() + metadata.create_all() def tearDownAll(self): - t2.drop() - t1.drop() + metadata.drop_all() def setUp(self): - objectstore.clear() clear_mappers() def testsingle(self): @@ -53,9 +49,11 @@ class SelfReferentialTest(AssertMixin): }) a = C1('head c1') a.c1s.append(C1('another c1')) - objectstore.commit() - objectstore.delete(a) - objectstore.commit() + sess = create_session(echo_uow=False) + sess.save(a) + sess.flush() + sess.delete(a) + sess.flush() def testcycle(self): class C1(Tester): @@ -75,36 +73,34 @@ class SelfReferentialTest(AssertMixin): a.c1s[0].c1s.append(C1('subchild2')) a.c1s[1].c2s.append(C2('child2 data1')) a.c1s[1].c2s.append(C2('child2 data2')) - objectstore.commit() + sess = create_session(echo_uow=False) + sess.save(a) + sess.flush() - objectstore.delete(a) - objectstore.commit() + sess.delete(a) + sess.flush() class BiDirectionalOneToManyTest(AssertMixin): """tests two mappers with a one-to-many relation to each other.""" def setUpAll(self): - testbase.db.tables.clear() - global t1 - global t2 - t1 = Table('t1', testbase.db, + global t1, t2, metadata + metadata = BoundMetaData(testbase.db) + t1 = Table('t1', metadata, Column('c1', Integer, primary_key=True), Column('c2', Integer, ForeignKey('t2.c1')) ) - t2 = Table('t2', testbase.db, + t2 = Table('t2', metadata, Column('c1', Integer, primary_key=True), Column('c2', Integer) ) - t2.create() - t1.create() + metadata.create_all() t2.c.c2.append_item(ForeignKey('t1.c1')) def tearDownAll(self): - t1.drop() + t1.drop() t2.drop() - def setUp(self): - objectstore.clear() - #objectstore.LOG = True + #metadata.drop_all() + def tearDown(self): clear_mappers() - def testcycle(self): class C1(object):pass class C2(object):pass @@ -123,35 +119,33 @@ class BiDirectionalOneToManyTest(AssertMixin): a.c2s.append(b) d.c1s.append(c) b.c1s.append(c) - objectstore.commit() + sess = create_session() + [sess.save(x) for x in [a,b,c,d,e,f]] + sess.flush() class BiDirectionalOneToManyTest2(AssertMixin): """tests two mappers with a one-to-many relation to each other, with a second one-to-many on one of the mappers""" def setUpAll(self): - testbase.db.tables.clear() - global t1 - global t2 - global t3 - t1 = Table('t1', testbase.db, + global t1, t2, t3, metadata + metadata = BoundMetaData(testbase.db) + t1 = Table('t1', metadata, Column('c1', Integer, primary_key=True), Column('c2', Integer, ForeignKey('t2.c1')), ) - t2 = Table('t2', testbase.db, + t2 = Table('t2', metadata, Column('c1', Integer, primary_key=True), Column('c2', Integer), ) t2.create() t1.create() t2.c.c2.append_item(ForeignKey('t1.c1')) - t3 = Table('t1_data', testbase.db, + t3 = Table('t1_data', metadata, Column('c1', Integer, primary_key=True), Column('t1id', Integer, ForeignKey('t1.c1')), Column('data', String(20))) t3.create() - def setUp(self): - objectstore.clear() - #objectstore.LOG = True + def tearDown(self): clear_mappers() def tearDownAll(self): @@ -185,25 +179,28 @@ class BiDirectionalOneToManyTest2(AssertMixin): a.data.append(C1Data('c1data1')) a.data.append(C1Data('c1data2')) c.data.append(C1Data('c1data3')) - objectstore.commit() + sess = create_session() + [sess.save(x) for x in [a,b,c,d,e,f]] + sess.flush() - objectstore.delete(d) - objectstore.delete(c) - objectstore.commit() + sess.delete(d) + sess.delete(c) + sess.flush() class OneToManyManyToOneTest(AssertMixin): """tests two mappers, one has a one-to-many on the other mapper, the other has a separate many-to-one relationship to the first. two tests will have a row for each item that is dependent on the other. without the "post_update" flag, such relationships raise an exception when dependencies are sorted.""" def setUpAll(self): - testbase.db.tables.clear() + global metadata + metadata = BoundMetaData(testbase.db) global person global ball - ball = Table('ball', db, + ball = Table('ball', metadata, Column('id', Integer, Sequence('ball_id_seq', optional=True), primary_key=True), Column('person_id', Integer), ) - person = Table('person', db, + person = Table('person', metadata, Column('id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('favoriteBall_id', Integer, ForeignKey('ball.id')), # Column('favoriteBall_id', Integer), @@ -223,9 +220,7 @@ class OneToManyManyToOneTest(AssertMixin): person.drop() ball.drop() - def setUp(self): - objectstore.clear() - #objectstore.LOG = True + def tearDown(self): clear_mappers() def testcycle(self): @@ -249,7 +244,10 @@ class OneToManyManyToOneTest(AssertMixin): b = Ball() p = Person() p.balls.append(b) - objectstore.commit() + sess = create_session() + sess.save(b) + sess.save(b) + sess.flush() def testpostupdate_m2o(self): """tests a cycle between two rows, with a post_update on the many-to-one""" @@ -275,76 +273,79 @@ class OneToManyManyToOneTest(AssertMixin): p.balls.append(Ball()) p.balls.append(Ball()) p.favorateBall = b - - self.assert_sql(db, lambda: objectstore.uow().commit(), [ + sess = create_session() + sess.save(b) + sess.save(p) + + self.assert_sql(db, lambda: sess.flush(), [ ( "INSERT INTO person (favoriteBall_id) VALUES (:favoriteBall_id)", {'favoriteBall_id': None} ), ( "INSERT INTO ball (person_id) VALUES (:person_id)", - lambda:{'person_id':p.id} + lambda ctx:{'person_id':p.id} ), ( "INSERT INTO ball (person_id) VALUES (:person_id)", - lambda:{'person_id':p.id} + lambda ctx:{'person_id':p.id} ), ( "INSERT INTO ball (person_id) VALUES (:person_id)", - lambda:{'person_id':p.id} + lambda ctx:{'person_id':p.id} ), ( "INSERT INTO ball (person_id) VALUES (:person_id)", - lambda:{'person_id':p.id} + lambda ctx:{'person_id':p.id} ), ( "UPDATE person SET favoriteBall_id=:favoriteBall_id WHERE person.id = :person_id", - lambda:[{'favoriteBall_id':p.favorateBall.id,'person_id':p.id}] + lambda ctx:{'favoriteBall_id':p.favorateBall.id,'person_id':p.id} ) ], with_sequences= [ ( "INSERT INTO person (id, favoriteBall_id) VALUES (:id, :favoriteBall_id)", - lambda:{'id':db.last_inserted_ids()[0], 'favoriteBall_id': None} + lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favoriteBall_id': None} ), ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id} + lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id} ), ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id} + lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id} ), ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id} + lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id} ), ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id} + lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id} ), # heres the post update ( "UPDATE person SET favoriteBall_id=:favoriteBall_id WHERE person.id = :person_id", - lambda:[{'favoriteBall_id':p.favorateBall.id,'person_id':p.id}] + lambda ctx:{'favoriteBall_id':p.favorateBall.id,'person_id':p.id} ) ]) - objectstore.delete(p) - self.assert_sql(db, lambda: objectstore.uow().commit(), [ + sess.delete(p) + self.assert_sql(db, lambda: sess.flush(), [ # heres the post update (which is a pre-update with deletes) ( "UPDATE person SET favoriteBall_id=:favoriteBall_id WHERE person.id = :person_id", - lambda:[{'person_id': p.id, 'favoriteBall_id': None}] + lambda ctx:{'person_id': p.id, 'favoriteBall_id': None} ), ( "DELETE FROM ball WHERE ball.id = :id", None # order cant be predicted, but something like: - #lambda:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}] + #lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}] ), ( "DELETE FROM person WHERE person.id = :id", - lambda:[{'id': p.id}] + lambda ctx:[{'id': p.id}] ) @@ -377,8 +378,10 @@ class OneToManyManyToOneTest(AssertMixin): b4 = Ball() p.balls.append(b4) p.favorateBall = b -# objectstore.commit() - self.assert_sql(db, lambda: objectstore.uow().commit(), [ + sess = create_session() + [sess.save(x) for x in [b,p,b2,b3,b4]] + + self.assert_sql(db, lambda: sess.flush(), [ ( "INSERT INTO ball (person_id) VALUES (:person_id)", {'person_id':None} @@ -397,92 +400,92 @@ class OneToManyManyToOneTest(AssertMixin): ), ( "INSERT INTO person (favoriteBall_id) VALUES (:favoriteBall_id)", - lambda:{'favoriteBall_id':b.id} + lambda ctx:{'favoriteBall_id':b.id} ), # heres the post update on each one-to-many item ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b.id}] + lambda ctx:{'person_id':p.id,'ball_id':b.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b2.id}] + lambda ctx:{'person_id':p.id,'ball_id':b2.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b3.id}] + lambda ctx:{'person_id':p.id,'ball_id':b3.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b4.id}] + lambda ctx:{'person_id':p.id,'ball_id':b4.id} ), ], with_sequences=[ ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0], 'person_id':None} + lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None} ), ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0], 'person_id':None} + lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None} ), ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0], 'person_id':None} + lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None} ), ( "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)", - lambda:{'id':db.last_inserted_ids()[0], 'person_id':None} + lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None} ), ( "INSERT INTO person (id, favoriteBall_id) VALUES (:id, :favoriteBall_id)", - lambda:{'id':db.last_inserted_ids()[0], 'favoriteBall_id':b.id} + lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favoriteBall_id':b.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b.id}] + lambda ctx:{'person_id':p.id,'ball_id':b.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b2.id}] + lambda ctx:{'person_id':p.id,'ball_id':b2.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b3.id}] + lambda ctx:{'person_id':p.id,'ball_id':b3.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id':p.id,'ball_id':b4.id}] + lambda ctx:{'person_id':p.id,'ball_id':b4.id} ), ]) - objectstore.delete(p) - self.assert_sql(db, lambda: objectstore.uow().commit(), [ + sess.delete(p) + self.assert_sql(db, lambda: sess.flush(), [ ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id': None, 'ball_id': b.id}] + lambda ctx:{'person_id': None, 'ball_id': b.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id': None, 'ball_id': b2.id}] + lambda ctx:{'person_id': None, 'ball_id': b2.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id': None, 'ball_id': b3.id}] + lambda ctx:{'person_id': None, 'ball_id': b3.id} ), ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", - lambda:[{'person_id': None, 'ball_id': b4.id}] + lambda ctx:{'person_id': None, 'ball_id': b4.id} ), ( "DELETE FROM person WHERE person.id = :id", - lambda:[{'id':p.id}] + lambda ctx:[{'id':p.id}] ), ( "DELETE FROM ball WHERE ball.id = :id", None # the order of deletion is not predictable, but its roughly: - # lambda:[{'id': b.id}, {'id': b2.id}, {'id': b3.id}, {'id': b4.id}] + # lambda ctx:[{'id': b.id}, {'id': b2.id}, {'id': b3.id}, {'id': b4.id}] ) ]) diff --git a/test/defaults.py b/test/defaults.py index 8d848f4c5..a271cbcb5 100644 --- a/test/defaults.py +++ b/test/defaults.py @@ -7,7 +7,7 @@ from sqlalchemy import * import sqlalchemy db = testbase.db -testbase.echo=False + class DefaultTest(PersistTest): def setUpAll(self): @@ -20,15 +20,17 @@ class DefaultTest(PersistTest): use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle' is_oracle = db.engine.name == 'oracle' - # select "count(1)" from the DB which returns different results - # on different DBs - currenttime = db.func.current_date(type=Date); + # select "count(1)" returns different results on different DBs + # also correct for "current_date" compatible as column default, value differences + currenttime = func.current_date(type=Date, engine=db); if is_oracle: - ts = db.func.sysdate().scalar() + ts = db.func.trunc(func.sysdate(), column("'DAY'")).scalar() f = select([func.count(1) + 5], engine=db).scalar() f2 = select([func.count(1) + 14], engine=db).scalar() + # TODO: engine propigation across nested functions not working + currenttime = func.trunc(currenttime, column("'DAY'"), engine=db) def1 = currenttime - def2 = text("sysdate") + def2 = func.trunc(text("sysdate"), column("'DAY'")) deftype = Date elif use_function_defaults: f = select([func.count(1) + 5], engine=db).scalar() @@ -72,9 +74,10 @@ class DefaultTest(PersistTest): t.delete().execute() def teststandalone(self): - x = t.c.col1.default.execute() + c = db.engine.contextual_connect() + x = c.execute(t.c.col1.default) y = t.c.col2.default.execute() - z = t.c.col3.default.execute() + z = c.execute(t.c.col3.default) self.assert_(50 <= x <= 57) self.assert_(y == 'imthedefault') self.assert_(z == f) @@ -82,8 +85,8 @@ class DefaultTest(PersistTest): self.assert_(5 <= z <= 6) def testinsert(self): - t.insert().execute() - self.assert_(t.engine.lastrow_has_defaults()) + r = t.insert().execute() + self.assert_(r.lastrow_has_defaults()) t.insert().execute() t.insert().execute() @@ -99,8 +102,8 @@ class DefaultTest(PersistTest): def testupdate(self): - t.insert().execute() - pk = t.engine.last_inserted_ids()[0] + r = t.insert().execute() + pk = r.last_inserted_ids()[0] t.update(t.c.col1==pk).execute(col4=None, col5=None) ctexec = currenttime.scalar() self.echo("Currenttime "+ repr(ctexec)) @@ -111,8 +114,8 @@ class DefaultTest(PersistTest): self.assert_(14 <= f2 <= 15) def testupdatevalues(self): - t.insert().execute() - pk = t.engine.last_inserted_ids()[0] + r = t.insert().execute() + pk = r.last_inserted_ids()[0] t.update(t.c.col1==pk, values={'col3': 55}).execute() l = t.select(t.c.col1==pk).execute() l = l.fetchone() @@ -143,10 +146,10 @@ class SequenceTest(PersistTest): @testbase.supported('postgres', 'oracle') def teststandalone(self): - s = Sequence("my_sequence", engine=db) + s = Sequence("my_sequence", metadata=testbase.db) s.create() try: - x =s.execute() + x = s.execute() self.assert_(x == 1) finally: s.drop() diff --git a/test/dependency.py b/test/dependency.py index 81165dc6d..d2b5bd698 100644 --- a/test/dependency.py +++ b/test/dependency.py @@ -1,5 +1,5 @@ from testbase import PersistTest -import sqlalchemy.mapping.topological as topological +import sqlalchemy.orm.topological as topological import unittest, sys, os @@ -17,26 +17,6 @@ class thingy(object): return repr(self) class DependencySortTest(PersistTest): - - def _assert_sort(self, tuples, allnodes, **kwargs): - - head = DependencySorter(tuples, allnodes).sort(**kwargs) - - print "\n" + str(head) - def findnode(t, n, parent=False): - if n.item is t[0] or (n.cycles is not None and t[0] in [c.item for c in n.cycles]): - parent=True - elif n.item is t[1]: - if not parent and (n.cycles is None or t[0] not in [c.item for c in n.cycles]): - self.assert_(False, "Node " + str(t[1]) + " not a child of " +str(t[0])) - else: - return - for c in n.children: - findnode(t, c, parent) - - for t in tuples: - findnode(t, head) - def testsort(self): rootnode = thingy('root') node2 = thingy('node2') @@ -47,7 +27,6 @@ class DependencySortTest(PersistTest): subnode3 = thingy('subnode3') subnode4 = thingy('subnode4') subsubnode1 = thingy('subsubnode1') - allnodes = [rootnode, node2,node3,node4,subnode1,subnode2,subnode3,subnode4,subsubnode1] tuples = [ (subnode3, subsubnode1), (node2, subnode1), @@ -58,8 +37,8 @@ class DependencySortTest(PersistTest): (node4, subnode3), (node4, subnode4) ] - - self._assert_sort(tuples, allnodes) + head = DependencySorter(tuples, []).sort() + print "\n" + str(head) def testsort2(self): node1 = thingy('node1') @@ -76,7 +55,8 @@ class DependencySortTest(PersistTest): (node5, node6), (node6, node2) ] - self._assert_sort(tuples, [node1,node2,node3,node4,node5,node6,node7]) + head = DependencySorter(tuples, [node7]).sort() + print "\n" + str(head) def testsort3(self): ['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords'] @@ -88,10 +68,15 @@ class DependencySortTest(PersistTest): (node3, node2), (node1,node3) ] - self._assert_sort(tuples, [node1, node2, node3]) - self._assert_sort(tuples, [node3, node1, node2]) - self._assert_sort(tuples, [node3, node2, node1]) + head1 = DependencySorter(tuples, [node1, node2, node3]).sort() + head2 = DependencySorter(tuples, [node3, node1, node2]).sort() + head3 = DependencySorter(tuples, [node3, node2, node1]).sort() + # TODO: figure out a "node == node2" function + #self.assert_(str(head1) == str(head2) == str(head3)) + print "\n" + str(head1) + print "\n" + str(head2) + print "\n" + str(head3) def testsort4(self): node1 = thingy('keywords') @@ -104,7 +89,8 @@ class DependencySortTest(PersistTest): (node1, node3), (node3, node2) ] - self._assert_sort(tuples, [node1,node2,node3,node4]) + head = DependencySorter(tuples, []).sort() + print "\n" + str(head) def testsort5(self): # this one, depenending on the weather, @@ -131,24 +117,10 @@ class DependencySortTest(PersistTest): node3, node4 ] - self._assert_sort(tuples, allitems) - - def testsort6(self): - #('tbl_c', 'tbl_d'), ('tbl_a', 'tbl_c'), ('tbl_b', 'tbl_d') - nodea = thingy('tbl_a') - nodeb = thingy('tbl_b') - nodec = thingy('tbl_c') - noded = thingy('tbl_d') - tuples = [ - (nodec, noded), - (nodea, nodec), - (nodeb, noded) - ] - allitems = [nodea,nodeb,nodec,noded] - self._assert_sort(tuples, allitems) + head = DependencySorter(tuples, allitems).sort() + print "\n" + str(head) def testcircular(self): - #print "TESTCIRCULAR" node1 = thingy('node1') node2 = thingy('node2') node3 = thingy('node3') @@ -162,8 +134,8 @@ class DependencySortTest(PersistTest): (node3, node1), (node4, node1) ] - self._assert_sort(tuples, [node1,node2,node3,node4,node5], allow_all_cycles=True) - #print "TESTCIRCULAR DONE" + head = DependencySorter(tuples, []).sort(allow_all_cycles=True) + print "\n" + str(head) if __name__ == "__main__": diff --git a/test/eagertest1.py b/test/eagertest1.py index 5897e4016..9765379f4 100644 --- a/test/eagertest1.py +++ b/test/eagertest1.py @@ -7,65 +7,60 @@ import datetime class EagerTest(AssertMixin): def setUpAll(self): global designType, design, part, inheritedPart - - designType = Table('design_types', testbase.db, + designType = Table('design_types', testbase.metadata, Column('design_type_id', Integer, primary_key=True), ) - design =Table('design', testbase.db, + design =Table('design', testbase.metadata, Column('design_id', Integer, primary_key=True), Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) - part = Table('parts', testbase.db, + part = Table('parts', testbase.metadata, Column('part_id', Integer, primary_key=True), Column('design_id', Integer, ForeignKey('design.design_id')), Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) - inheritedPart = Table('inherited_part', testbase.db, + inheritedPart = Table('inherited_part', testbase.metadata, Column('ip_id', Integer, primary_key=True), Column('part_id', Integer, ForeignKey('parts.part_id')), Column('design_id', Integer, ForeignKey('design.design_id')), ) - designType.create() - design.create() - part.create() - inheritedPart.create() + testbase.metadata.create_all() def tearDownAll(self): - inheritedPart.drop() - part.drop() - design.drop() - designType.drop() - + testbase.metadata.drop_all() + testbase.metadata.clear() def testone(self): class Part(object):pass class Design(object):pass class DesignType(object):pass class InheritedPart(object):pass - assign_mapper(Part, part) + mapper(Part, part) - assign_mapper(InheritedPart, inheritedPart, properties=dict( + mapper(InheritedPart, inheritedPart, properties=dict( part=relation(Part, lazy=False) )) - assign_mapper(Design, design, properties=dict( + mapper(Design, design, properties=dict( parts=relation(Part, private=True, backref="design"), inheritedParts=relation(InheritedPart, private=True, backref="design"), )) - assign_mapper(DesignType, designType, properties=dict( + mapper(DesignType, designType, properties=dict( # designs=relation(Design, private=True, backref="type"), )) - Design.mapper.add_property("type", relation(DesignType, lazy=False, backref="designs")) - Part.mapper.add_property("design", relation(Design, lazy=False, backref="parts")) + class_mapper(Design).add_property("type", relation(DesignType, lazy=False, backref="designs")) + class_mapper(Part).add_property("design", relation(Design, lazy=False, backref="parts")) #Part.mapper.add_property("designType", relation(DesignType)) d = Design() - objectstore.commit() - objectstore.clear() - x = Design.get(1) + sess = create_session() + sess.save(d) + sess.flush() + sess.clear() + x = sess.query(Design).get(1) x.inheritedParts if __name__ == "__main__": diff --git a/test/eagertest2.py b/test/eagertest2.py index 430e12ba7..ef385df16 100644 --- a/test/eagertest2.py +++ b/test/eagertest2.py @@ -3,70 +3,56 @@ import testbase import unittest, sys, os from sqlalchemy import * import datetime - -db = testbase.db +from sqlalchemy.ext.sessioncontext import SessionContext class EagerTest(AssertMixin): def setUpAll(self): - objectstore.clear() - clear_mappers() - testbase.db.tables.clear() - - global companies_table, addresses_table, invoice_table, phones_table, items_table + global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx, metadata - companies_table = Table('companies', db, + metadata = BoundMetaData(testbase.db) + ctx = SessionContext(create_session) + + companies_table = Table('companies', metadata, Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True), Column('company_name', String(40)), ) - addresses_table = Table('addresses', db, + addresses_table = Table('addresses', metadata, Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), Column('company_id', Integer, ForeignKey("companies.company_id")), Column('address', String(40)), ) - phones_table = Table('phone_numbers', db, + phones_table = Table('phone_numbers', metadata, Column('phone_id', Integer, Sequence('phone_id_seq', optional=True), primary_key = True), Column('address_id', Integer, ForeignKey('addresses.address_id')), Column('type', String(20)), Column('number', String(10)), ) - invoice_table = Table('invoices', db, + invoice_table = Table('invoices', metadata, Column('invoice_id', Integer, Sequence('invoice_id_seq', optional=True), primary_key = True), Column('company_id', Integer, ForeignKey("companies.company_id")), Column('date', DateTime), ) - items_table = Table('items', db, + items_table = Table('items', metadata, Column('item_id', Integer, Sequence('item_id_seq', optional=True), primary_key = True), Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')), Column('code', String(20)), Column('qty', Integer), ) - companies_table.create() - addresses_table.create() - phones_table.create() - invoice_table.create() - items_table.create() + metadata.create_all() def tearDownAll(self): - items_table.drop() - invoice_table.drop() - phones_table.drop() - addresses_table.drop() - companies_table.drop() + metadata.drop_all() def tearDown(self): - objectstore.clear() clear_mappers() - items_table.delete().execute() - invoice_table.delete().execute() - phones_table.delete().execute() - addresses_table.delete().execute() - companies_table.delete().execute() + for t in metadata.table_iterator(reverse=True): + t.delete().execute() def testone(self): """tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated @@ -88,14 +74,14 @@ class EagerTest(AssertMixin): def __repr__(self): return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None)) + " " + repr(self.company) - Address.mapper = mapper(Address, addresses_table, properties={ - }) - Company.mapper = mapper(Company, companies_table, properties={ - 'addresses' : relation(Address.mapper, lazy=False), - }) - Invoice.mapper = mapper(Invoice, invoice_table, properties={ - 'company': relation(Company.mapper, lazy=False, ) - }) + mapper(Address, addresses_table, properties={ + }, extension=ctx.mapper_extension) + mapper(Company, companies_table, properties={ + 'addresses' : relation(Address, lazy=False), + }, extension=ctx.mapper_extension) + mapper(Invoice, invoice_table, properties={ + 'company': relation(Company, lazy=False, ) + }, extension=ctx.mapper_extension) c1 = Company() c1.company_name = 'company 1' @@ -109,19 +95,18 @@ class EagerTest(AssertMixin): i1.date = datetime.datetime.now() i1.company = c1 - - objectstore.commit() + ctx.current.flush() company_id = c1.company_id invoice_id = i1.invoice_id - objectstore.clear() + ctx.current.clear() - c = Company.mapper.get(company_id) + c = ctx.current.query(Company).get(company_id) - objectstore.clear() + ctx.current.clear() - i = Invoice.mapper.get(invoice_id) + i = ctx.current.query(Invoice).get(invoice_id) self.echo(repr(c)) self.echo(repr(i.company)) @@ -153,24 +138,24 @@ class EagerTest(AssertMixin): def __repr__(self): return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty) - Phone.mapper = mapper(Phone, phones_table, is_primary=True) + mapper(Phone, phones_table, extension=ctx.mapper_extension) - Address.mapper = mapper(Address, addresses_table, properties={ - 'phones': relation(Phone.mapper, lazy=False, backref='address') - }) + mapper(Address, addresses_table, properties={ + 'phones': relation(Phone, lazy=False, backref='address') + }, extension=ctx.mapper_extension) - Company.mapper = mapper(Company, companies_table, properties={ - 'addresses' : relation(Address.mapper, lazy=False, backref='company'), - }) + mapper(Company, companies_table, properties={ + 'addresses' : relation(Address, lazy=False, backref='company'), + }, extension=ctx.mapper_extension) - Item.mapper = mapper(Item, items_table, is_primary=True) + mapper(Item, items_table, extension=ctx.mapper_extension) - Invoice.mapper = mapper(Invoice, invoice_table, properties={ - 'items': relation(Item.mapper, lazy=False, backref='invoice'), - 'company': relation(Company.mapper, lazy=False, backref='invoices') - }) + mapper(Invoice, invoice_table, properties={ + 'items': relation(Item, lazy=False, backref='invoice'), + 'company': relation(Company, lazy=False, backref='invoices') + }, extension=ctx.mapper_extension) - objectstore.clear() + ctx.current.clear() c1 = Company() c1.company_name = 'company 1' @@ -205,13 +190,13 @@ class EagerTest(AssertMixin): c1.addresses.append(a2) - objectstore.commit() + ctx.current.flush() company_id = c1.company_id - objectstore.clear() + ctx.current.clear() - a = Company.mapper.get(company_id) + a = ctx.current.query(Company).get(company_id) self.echo(repr(a)) # set up an invoice @@ -234,18 +219,18 @@ class EagerTest(AssertMixin): item3.qty = 3 item3.invoice = i1 - objectstore.commit() + ctx.current.flush() invoice_id = i1.invoice_id - objectstore.clear() + ctx.current.clear() - c = Company.mapper.get(company_id) + c = ctx.current.query(Company).get(company_id) self.echo(repr(c)) - objectstore.clear() + ctx.current.clear() - i = Invoice.mapper.get(invoice_id) + i = ctx.current.query(Invoice).get(invoice_id) self.echo(repr(i)) self.assert_(repr(i.company) == repr(c)) diff --git a/test/engine.py b/test/engine.py deleted file mode 100644 index 193201fa1..000000000 --- a/test/engine.py +++ /dev/null @@ -1,64 +0,0 @@ -from sqlalchemy import * - -from testbase import PersistTest -import testbase -import unittest, re -import tables - -class TransactionTest(PersistTest): - def setUpAll(self): - tables.create() - def tearDownAll(self): - tables.drop() - def tearDown(self): - tables.delete() - - def testbasic(self): - testbase.db.begin() - tables.users.insert().execute(user_name='jack') - tables.users.insert().execute(user_name='fred') - testbase.db.commit() - l = tables.users.select().execute().fetchall() - print l - self.assert_(len(l) == 2) - - def testrollback(self): - testbase.db.begin() - tables.users.insert().execute(user_name='jack') - tables.users.insert().execute(user_name='fred') - testbase.db.rollback() - l = tables.users.select().execute().fetchall() - print l - self.assert_(len(l) == 0) - - @testbase.unsupported('sqlite') - def testnested(self): - """tests nested sessions. SQLite should raise an error.""" - testbase.db.begin() - tables.users.insert().execute(user_name='jack') - tables.users.insert().execute(user_name='fred') - testbase.db.push_session() - tables.users.insert().execute(user_name='ed') - tables.users.insert().execute(user_name='wendy') - testbase.db.pop_session() - testbase.db.rollback() - l = tables.users.select().execute().fetchall() - print l - self.assert_(len(l) == 2) - - def testtwo(self): - testbase.db.begin() - tables.users.insert().execute(user_name='jack') - tables.users.insert().execute(user_name='fred') - testbase.db.commit() - testbase.db.begin() - tables.users.insert().execute(user_name='ed') - tables.users.insert().execute(user_name='wendy') - testbase.db.commit() - testbase.db.rollback() - l = tables.users.select().execute().fetchall() - print l - self.assert_(len(l) == 4) - -if __name__ == "__main__": - testbase.main() diff --git a/test/entity.py b/test/entity.py index 591cff7ea..22e74f171 100644 --- a/test/entity.py +++ b/test/entity.py @@ -2,6 +2,7 @@ from testbase import PersistTest, AssertMixin import unittest from sqlalchemy import * import testbase +from sqlalchemy.ext.sessioncontext import SessionContext from tables import * import tables @@ -10,38 +11,35 @@ class EntityTest(AssertMixin): """tests mappers that are constructed based on "entity names", which allows the same class to have multiple primary mappers """ def setUpAll(self): - global user1, user2, address1, address2 - db = testbase.db - user1 = Table('user1', db, + global user1, user2, address1, address2, metadata, ctx + metadata = BoundMetaData(testbase.db) + ctx = SessionContext(create_session) + + user1 = Table('user1', metadata, Column('user_id', Integer, Sequence('user1_id_seq'), primary_key=True), Column('name', String(60), nullable=False) - ).create() - user2 = Table('user2', db, + ) + user2 = Table('user2', metadata, Column('user_id', Integer, Sequence('user2_id_seq'), primary_key=True), Column('name', String(60), nullable=False) - ).create() - address1 = Table('address1', db, + ) + address1 = Table('address1', metadata, Column('address_id', Integer, Sequence('address1_id_seq'), primary_key=True), Column('user_id', Integer, ForeignKey(user1.c.user_id), nullable=False), Column('email', String(100), nullable=False) - ).create() - address2 = Table('address2', db, + ) + address2 = Table('address2', metadata, Column('address_id', Integer, Sequence('address2_id_seq'), primary_key=True), Column('user_id', Integer, ForeignKey(user2.c.user_id), nullable=False), Column('email', String(100), nullable=False) - ).create() + ) + metadata.create_all() def tearDownAll(self): - address1.drop() - address2.drop() - user1.drop() - user2.drop() + metadata.drop_all() def tearDown(self): - address1.delete().execute() - address2.delete().execute() - user1.delete().execute() - user2.delete().execute() - objectstore.clear() clear_mappers() + for t in metadata.table_iterator(reverse=True): + t.delete().execute() def testbasic(self): """tests a pair of one-to-many mapper structures, establishing that both @@ -50,14 +48,14 @@ class EntityTest(AssertMixin): class User(object):pass class Address(object):pass - a1mapper = mapper(Address, address1, entity_name='address1') - a2mapper = mapper(Address, address2, entity_name='address2') + a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension) + a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension) u1mapper = mapper(User, user1, entity_name='user1', properties ={ 'addresses':relation(a1mapper) - }) + }, extension=ctx.mapper_extension) u2mapper =mapper(User, user2, entity_name='user2', properties={ 'addresses':relation(a2mapper) - }) + }, extension=ctx.mapper_extension) u1 = User(_sa_entity_name='user1') u1.name = 'this is user 1' @@ -71,15 +69,15 @@ class EntityTest(AssertMixin): a2.email='a2@foo.com' u2.addresses.append(a2) - objectstore.commit() + ctx.current.flush() assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] - objectstore.clear() - u1list = u1mapper.select() - u2list = u2mapper.select() + ctx.current.clear() + u1list = ctx.current.query(User, entity_name='user1').select() + u2list = ctx.current.query(User, entity_name='user2').select() assert len(u1list) == len(u2list) == 1 assert u1list[0] is not u2list[0] assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 @@ -90,14 +88,14 @@ class EntityTest(AssertMixin): class Address1(object):pass class Address2(object):pass - a1mapper = mapper(Address1, address1) - a2mapper = mapper(Address2, address2) + a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension) + a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension) u1mapper = mapper(User, user1, entity_name='user1', properties ={ 'addresses':relation(a1mapper) - }) + }, extension=ctx.mapper_extension) u2mapper =mapper(User, user2, entity_name='user2', properties={ 'addresses':relation(a2mapper) - }) + }, extension=ctx.mapper_extension) u1 = User(_sa_entity_name='user1') u1.name = 'this is user 1' @@ -111,15 +109,15 @@ class EntityTest(AssertMixin): a2.email='a2@foo.com' u2.addresses.append(a2) - objectstore.commit() + ctx.current.flush() assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] - objectstore.clear() - u1list = u1mapper.select() - u2list = u2mapper.select() + ctx.current.clear() + u1list = ctx.current.query(User, entity_name='user1').select() + u2list = ctx.current.query(User, entity_name='user2').select() assert len(u1list) == len(u2list) == 1 assert u1list[0] is not u2list[0] assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 diff --git a/test/indexes.py b/test/indexes.py index fbf1a2c81..a111b34e5 100644 --- a/test/indexes.py +++ b/test/indexes.py @@ -5,59 +5,51 @@ import testbase class IndexTest(testbase.AssertMixin): def setUp(self): - self.created = [] + global metadata + metadata = BoundMetaData(testbase.db) self.echo = testbase.db.echo self.logger = testbase.db.logger def tearDown(self): testbase.db.echo = self.echo testbase.db.logger = testbase.db.engine.logger = self.logger - if self.created: - self.created.reverse() - for entity in self.created: - entity.drop() + metadata.drop_all() def test_index_create(self): - employees = Table('employees', testbase.db, + employees = Table('employees', metadata, Column('id', Integer, primary_key=True), Column('first_name', String(30)), Column('last_name', String(30)), Column('email_address', String(30))) employees.create() - self.created.append(employees) i = Index('employee_name_index', employees.c.last_name, employees.c.first_name) i.create() - self.created.append(i) assert employees.indexes['employee_name_index'] is i i2 = Index('employee_email_index', employees.c.email_address, unique=True) i2.create() - self.created.append(i2) assert employees.indexes['employee_email_index'] is i2 def test_index_create_camelcase(self): """test that mixed-case index identifiers are legal""" - employees = Table('companyEmployees', testbase.db, + employees = Table('companyEmployees', metadata, Column('id', Integer, primary_key=True), Column('firstName', String(30)), Column('lastName', String(30)), Column('emailAddress', String(30))) employees.create() - self.created.append(employees) i = Index('employeeNameIndex', employees.c.lastName, employees.c.firstName) i.create() - self.created.append(i) i = Index('employeeEmailIndex', employees.c.emailAddress, unique=True) i.create() - self.created.append(i) # Check that the table is useable. This is mostly for pg, # which can be somewhat sticky with mixed-case identifiers @@ -75,8 +67,7 @@ class IndexTest(testbase.AssertMixin): stream = dummy() stream.write = capt.append testbase.db.logger = testbase.db.engine.logger = stream - - events = Table('events', testbase.db, + events = Table('events', metadata, Column('id', Integer, primary_key=True), Column('name', String(30), unique=True), Column('location', String(30), index=True), @@ -94,22 +85,21 @@ class IndexTest(testbase.AssertMixin): assert len(index_names) == 4 events.create() - self.created.append(events) # verify that the table is functional events.insert().execute(id=1, name='hockey finals', location='rink', sport='hockey', announcer='some canadian', winner='sweden') ss = events.select().execute().fetchall() - + assert capt[0].strip().startswith('CREATE TABLE events') - assert capt[2].strip() == \ + assert capt[3].strip() == \ 'CREATE UNIQUE INDEX ux_events_name ON events (name)' - assert capt[4].strip() == \ - 'CREATE INDEX ix_events_location ON events (location)' assert capt[6].strip() == \ + 'CREATE INDEX ix_events_location ON events (location)' + assert capt[9].strip() == \ 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)' - assert capt[8].strip() == \ + assert capt[12].strip() == \ 'CREATE INDEX idx_winners ON events (winner)' if __name__ == "__main__": diff --git a/test/inheritance.py b/test/inheritance.py index 8b683b8ff..265134173 100644 --- a/test/inheritance.py +++ b/test/inheritance.py @@ -1,11 +1,13 @@ -from sqlalchemy import * import testbase +from sqlalchemy import * import string import sqlalchemy.attributes as attr import sys class Principal( object ): - pass + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) class User( Principal ): pass @@ -20,16 +22,18 @@ class InheritTest(testbase.AssertMixin): global users global groups global user_group_map + global metadata + metadata = BoundMetaData(testbase.db) principals = Table( 'principals', - testbase.db, + metadata, Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True), Column('name', String(50), nullable=False), ) users = Table( 'prin_users', - testbase.db, + metadata, Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), Column('password', String(50), nullable=False), Column('email', String(50), nullable=False), @@ -39,14 +43,14 @@ class InheritTest(testbase.AssertMixin): groups = Table( 'prin_groups', - testbase.db, + metadata, Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), ) user_group_map = Table( 'prin_user_group_map', - testbase.db, + metadata, Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ), Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ), #Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), ), @@ -54,65 +58,57 @@ class InheritTest(testbase.AssertMixin): ) - principals.create() - users.create() - groups.create() - user_group_map.create() + metadata.create_all() + def tearDownAll(self): - user_group_map.drop() - groups.drop() - users.drop() - principals.drop() - testbase.db.tables.clear() + metadata.drop_all() + def setUp(self): - objectstore.clear() clear_mappers() def testbasic(self): - assign_mapper( Principal, principals ) - assign_mapper( + mapper( Principal, principals ) + mapper( User, users, - inherits=Principal.mapper + inherits=Principal ) - assign_mapper( + mapper( Group, groups, - inherits=Principal.mapper, - properties=dict( users = relation(User.mapper, user_group_map, lazy=True, backref="groups") ) + inherits=Principal, + properties=dict( users = relation(User, secondary=user_group_map, lazy=True, backref="groups") ) ) g = Group(name="group1") g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1")) - - objectstore.commit() + sess = create_session() + sess.save(g) + sess.flush() # TODO: put an assertion class InheritTest2(testbase.AssertMixin): """deals with inheritance and many-to-many relationships""" def setUpAll(self): - engine = testbase.db - global foo, bar, foo_bar - foo = Table('foo', engine, + global foo, bar, foo_bar, metadata + metadata = BoundMetaData(testbase.db) + foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_id_seq'), primary_key=True), Column('data', String(20)), ).create() - bar = Table('bar', engine, + bar = Table('bar', metadata, Column('bid', Integer, ForeignKey('foo.id'), primary_key=True), #Column('fid', Integer, ForeignKey('foo.id'), ) ).create() - foo_bar = Table('foo_bar', engine, + foo_bar = Table('foo_bar', metadata, Column('foo_id', Integer, ForeignKey('foo.id')), Column('bar_id', Integer, ForeignKey('bar.bid'))).create() - + metadata.create_all() def tearDownAll(self): - foo_bar.drop() - bar.drop() - foo.drop() - testbase.db.tables.clear() + metadata.drop_all() def testbasic(self): class Foo(object): @@ -123,33 +119,28 @@ class InheritTest2(testbase.AssertMixin): def __repr__(self): return str(self) - Foo.mapper = mapper(Foo, foo) + mapper(Foo, foo) class Bar(Foo): def __str__(self): return "Bar(%s)" % self.data - Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties = { - # the old way, you needed to explicitly set up a compound - # column like this. but now the mapper uses SyncRules to match up - # the parent/child inherited columns - #'id':[bar.c.bid, foo.c.id] - }) - - #Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, primaryjoin=bar.c.bid==foo_bar.c.bar_id, secondaryjoin=foo_bar.c.foo_id==foo.c.id, lazy=False)) - Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, lazy=False)) - - b = Bar('barfoo') - objectstore.commit() + mapper(Bar, bar, inherits=Foo, properties={ + 'foos': relation(Foo, secondary=foo_bar, lazy=False) + }) + + sess = create_session() + b = Bar('barfoo', _sa_session=sess) + sess.flush() f1 = Foo('subfoo1') f2 = Foo('subfoo2') b.foos.append(f1) b.foos.append(f2) - objectstore.commit() - objectstore.clear() + sess.flush() + sess.clear() - l =b.mapper.select() + l = sess.query(Bar).select() print l[0] print l[0].foos self.assert_result(l, Bar, @@ -160,44 +151,38 @@ class InheritTest2(testbase.AssertMixin): class InheritTest3(testbase.AssertMixin): """deals with inheritance and many-to-many relationships""" def setUpAll(self): - engine = testbase.db - global foo, bar, blub, bar_foo, blub_bar, blub_foo,tables - engine.engine.echo = 'debug' + global foo, bar, blub, bar_foo, blub_bar, blub_foo,metadata + metadata = BoundMetaData(testbase.db) # the 'data' columns are to appease SQLite which cant handle a blank INSERT - foo = Table('foo', engine, + foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True), Column('data', String(20))) - bar = Table('bar', engine, + bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), Column('data', String(20))) - blub = Table('blub', engine, + blub = Table('blub', metadata, Column('id', Integer, ForeignKey('bar.id'), primary_key=True), Column('data', String(20))) - bar_foo = Table('bar_foo', engine, + bar_foo = Table('bar_foo', metadata, Column('bar_id', Integer, ForeignKey('bar.id')), Column('foo_id', Integer, ForeignKey('foo.id'))) - blub_bar = Table('bar_blub', engine, + blub_bar = Table('bar_blub', metadata, Column('blub_id', Integer, ForeignKey('blub.id')), Column('bar_id', Integer, ForeignKey('bar.id'))) - blub_foo = Table('blub_foo', engine, + blub_foo = Table('blub_foo', metadata, Column('blub_id', Integer, ForeignKey('blub.id')), Column('foo_id', Integer, ForeignKey('foo.id'))) - - tables = [foo, bar, blub, bar_foo, blub_bar, blub_foo] - for table in tables: - table.create() + metadata.create_all() def tearDownAll(self): - for table in reversed(tables): - table.drop() - testbase.db.tables.clear() + metadata.drop_all() def tearDown(self): - for table in reversed(tables): + for table in metadata.table_iterator(): table.delete().execute() def testbasic(self): @@ -206,24 +191,24 @@ class InheritTest3(testbase.AssertMixin): self.data = data def __repr__(self): return "Foo id %d, data %s" % (self.id, self.data) - Foo.mapper = mapper(Foo, foo) + mapper(Foo, foo) class Bar(Foo): def __repr__(self): return "Bar id %d, data %s" % (self.id, self.data) - Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties={ - #'foos' :relation(Foo.mapper, bar_foo, primaryjoin=bar.c.id==bar_foo.c.bar_id, lazy=False) - 'foos' :relation(Foo.mapper, bar_foo, lazy=True) + mapper(Bar, bar, inherits=Foo, properties={ + 'foos' :relation(Foo, secondary=bar_foo, lazy=True) }) - b = Bar('bar #1') + sess = create_session() + b = Bar('bar #1', _sa_session=sess) b.foos.append(Foo("foo #1")) b.foos.append(Foo("foo #2")) - objectstore.commit() + sess.flush() compare = repr(b) + repr(b.foos) - objectstore.clear() - l = Bar.mapper.select() + sess.clear() + l = sess.query(Bar).select() self.echo(repr(l[0]) + repr(l[0].foos)) self.assert_(repr(l[0]) + repr(l[0].foos) == compare) @@ -233,83 +218,66 @@ class InheritTest3(testbase.AssertMixin): self.data = data def __repr__(self): return "Foo id %d, data %s" % (self.id, self.data) - Foo.mapper = mapper(Foo, foo) + mapper(Foo, foo) class Bar(Foo): def __repr__(self): return "Bar id %d, data %s" % (self.id, self.data) - Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper) + mapper(Bar, bar, inherits=Foo) class Blub(Bar): def __repr__(self): return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos])) - Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={ -# 'bars':relation(Bar.mapper, blub_bar, primaryjoin=blub.c.id==blub_bar.c.blub_id, lazy=False), -# 'foos':relation(Foo.mapper, blub_foo, primaryjoin=blub.c.id==blub_foo.c.blub_id, lazy=False), - 'bars':relation(Bar.mapper, blub_bar, lazy=False), - 'foos':relation(Foo.mapper, blub_foo, lazy=False), + mapper(Blub, blub, inherits=Bar, properties={ + 'bars':relation(Bar, secondary=blub_bar, lazy=False), + 'foos':relation(Foo, secondary=blub_foo, lazy=False), }) - useobjects = True - if (useobjects): - f1 = Foo("foo #1") - b1 = Bar("bar #1") - b2 = Bar("bar #2") - bl1 = Blub("blub #1") - bl1.foos.append(f1) - bl1.bars.append(b2) - objectstore.commit() - compare = repr(bl1) - blubid = bl1.id - objectstore.clear() - else: - foo.insert().execute(data='foo #1') - foo.insert().execute(data='foo #2') - bar.insert().execute(id=1, data="bar #1") - bar.insert().execute(id=2, data="bar #2") - blub.insert().execute(id=1, data="blub #1") - blub_bar.insert().execute(blub_id=1, bar_id=2) - blub_foo.insert().execute(blub_id=1, foo_id=2) - - l = Blub.mapper.select() + sess = create_session() + f1 = Foo("foo #1", _sa_session=sess) + b1 = Bar("bar #1", _sa_session=sess) + b2 = Bar("bar #2", _sa_session=sess) + bl1 = Blub("blub #1", _sa_session=sess) + bl1.foos.append(f1) + bl1.bars.append(b2) + sess.flush() + compare = repr(bl1) + blubid = bl1.id + sess.clear() + + l = sess.query(Blub).select() self.echo(l) self.assert_(repr(l[0]) == compare) - objectstore.clear() - x = Blub.mapper.get_by(id=blubid) #traceback 2 + sess.clear() + x = sess.query(Blub).get_by(id=blubid) self.echo(x) self.assert_(repr(x) == compare) class InheritTest4(testbase.AssertMixin): """deals with inheritance and one-to-many relationships""" def setUpAll(self): - engine = testbase.db - global foo, bar, blub, tables - engine.engine.echo = 'debug' + global foo, bar, blub, metadata + metadata = BoundMetaData(testbase.db) # the 'data' columns are to appease SQLite which cant handle a blank INSERT - foo = Table('foo', engine, + foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True), Column('data', String(20))) - bar = Table('bar', engine, + bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), Column('data', String(20))) - blub = Table('blub', engine, + blub = Table('blub', metadata, Column('id', Integer, ForeignKey('bar.id'), primary_key=True), Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False), Column('data', String(20))) - - tables = [foo, bar, blub] - for table in tables: - table.create() + metadata.create_all() def tearDownAll(self): - for table in reversed(tables): - table.drop() - testbase.db.tables.clear() + metadata.drop_all() def tearDown(self): - for table in reversed(tables): + for table in metadata.table_iterator(): table.delete().execute() def testbasic(self): @@ -318,56 +286,55 @@ class InheritTest4(testbase.AssertMixin): self.data = data def __repr__(self): return "Foo id %d, data %s" % (self.id, self.data) - Foo.mapper = mapper(Foo, foo) + mapper(Foo, foo) class Bar(Foo): def __repr__(self): return "Bar id %d, data %s" % (self.id, self.data) - Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper) + mapper(Bar, bar, inherits=Foo) class Blub(Bar): def __repr__(self): return "Blub id %d, data %s" % (self.id, self.data) - Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={ - # bug was raised specifically based on the order of cols in the join.... -# 'parent_foo':relation(Foo.mapper, primaryjoin=blub.c.foo_id==foo.c.id) -# 'parent_foo':relation(Foo.mapper, primaryjoin=foo.c.id==blub.c.foo_id) - 'parent_foo':relation(Foo.mapper) + mapper(Blub, blub, inherits=Bar, properties={ + 'parent_foo':relation(Foo) }) - b1 = Blub("blub #1") - b2 = Blub("blub #2") - f = Foo("foo #1") + sess = create_session() + b1 = Blub("blub #1", _sa_session=sess) + b2 = Blub("blub #2", _sa_session=sess) + f = Foo("foo #1", _sa_session=sess) b1.parent_foo = f b2.parent_foo = f - objectstore.commit() + sess.flush() compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo) - objectstore.clear() - l = Blub.mapper.select() + sess.clear() + l = sess.query(Blub).select() result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo) self.echo(result) self.assert_(compare == result) self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') -class InheritTest5(testbase.AssertMixin): +class InheritTest5(testbase.AssertMixin): + """testing that construction of inheriting mappers works regardless of when extra properties + are added to the superclass mapper""" def setUpAll(self): - engine = testbase.db - global content_type, content, product - content_type = Table('content_type', engine, + global content_type, content, product, metadata + metadata = BoundMetaData(testbase.db) + content_type = Table('content_type', metadata, Column('id', Integer, primary_key=True) ) - content = Table('content', engine, + content = Table('content', metadata, Column('id', Integer, primary_key=True), Column('content_type_id', Integer, ForeignKey('content_type.id')) ) - product = Table('product', engine, + product = Table('product', metadata, Column('id', Integer, ForeignKey('content.id'), primary_key=True) ) def tearDownAll(self): - testbase.db.tables.clear() - + pass def tearDown(self): pass @@ -384,14 +351,12 @@ class InheritTest5(testbase.AssertMixin): # shouldnt throw exception products = mapper(Product, product, inherits=contents) - def testbackref(self): - """this test is currently known to fail in the 0.1 series of SQLAlchemy, pending the resolution of [ticket:154]""" + """tests adding a property to the superclass mapper""" class ContentType(object): pass class Content(object): pass class Product(Content): pass - # this test fails currently contents = mapper(Content, content) products = mapper(Product, product, inherits=contents) content_types = mapper(ContentType, content_type, properties={ @@ -400,25 +365,24 @@ class InheritTest5(testbase.AssertMixin): p = Product() p.contenttype = ContentType() - class InheritTest6(testbase.AssertMixin): """tests eager load/lazy load of child items off inheritance mappers, tests that LazyLoader constructs the right query condition.""" def setUpAll(self): - global foo, bar, bar_foo - foo = Table('foo', testbase.db, Column('id', Integer, Sequence('foo_seq'), primary_key=True), - Column('data', String(30))).create() - bar = Table('bar', testbase.db, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - Column('data', String(30))).create() - - bar_foo = Table('bar_foo', testbase.db, + global foo, bar, bar_foo, metadata + metadata=BoundMetaData(testbase.db) + foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('data', String(30))) + bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(30))) + + bar_foo = Table('bar_foo', metadata, Column('bar_id', Integer, ForeignKey('bar.id')), Column('foo_id', Integer, ForeignKey('foo.id')) - ).create() + ) + metadata.create_all() def tearDownAll(self): - bar_foo.drop() - bar.drop() - foo.drop() + metadata.drop_all() def testbasic(self): class Foo(object): pass @@ -427,6 +391,7 @@ class InheritTest6(testbase.AssertMixin): foos = mapper(Foo, foo) bars = mapper(Bar, bar, inherits=foos) bars.add_property('lazy', relation(foos, bar_foo, lazy=True)) + print bars.props['lazy'].primaryjoin, bars.props['lazy'].secondaryjoin bars.add_property('eager', relation(foos, bar_foo, lazy=False)) foo.insert().execute(data='foo1') @@ -440,9 +405,11 @@ class InheritTest6(testbase.AssertMixin): bar_foo.insert().execute(bar_id=1, foo_id=3) bar_foo.insert().execute(bar_id=2, foo_id=4) - - self.assert_(len(bars.selectfirst().lazy) == 1) - self.assert_(len(bars.selectfirst().eager) == 1) + + sess = create_session() + q = sess.query(Bar) + self.assert_(len(q.selectfirst().lazy) == 1) + self.assert_(len(q.selectfirst().eager) == 1) if __name__ == "__main__": testbase.main() diff --git a/test/lazytest1.py b/test/lazytest1.py index 986f067aa..eb1310d66 100644 --- a/test/lazytest1.py +++ b/test/lazytest1.py @@ -6,28 +6,25 @@ import datetime class LazyTest(AssertMixin): def setUpAll(self): - global info_table, data_table, rel_table - engine = testbase.db - info_table = Table('infos', engine, + global info_table, data_table, rel_table, metadata + metadata = BoundMetaData(testbase.db) + info_table = Table('infos', metadata, Column('pk', Integer, primary_key=True), Column('info', String)) - data_table = Table('data', engine, + data_table = Table('data', metadata, Column('data_pk', Integer, primary_key=True), Column('info_pk', Integer, ForeignKey(info_table.c.pk)), Column('timeval', Integer), Column('data_val', String)) - rel_table = Table('rels', engine, + rel_table = Table('rels', metadata, Column('rel_pk', Integer, primary_key=True), Column('info_pk', Integer, ForeignKey(info_table.c.pk)), Column('start', Integer), Column('finish', Integer)) - - info_table.create() - rel_table.create() - data_table.create() + metadata.create_all() info_table.insert().execute( {'pk':1, 'info':'pk_1_info'}, {'pk':2, 'info':'pk_2_info'}, @@ -52,13 +49,8 @@ class LazyTest(AssertMixin): def tearDownAll(self): - data_table.drop() - rel_table.drop() - info_table.drop() + metadata.drop_all() - def setUp(self): - clear_mappers() - def testone(self): """tests a lazy load which has multiple join conditions, including two that are against the same column in the child table""" @@ -71,58 +63,28 @@ class LazyTest(AssertMixin): class Data(object): pass - # Create the basic mappers, with no frills or modifications - Information.mapper = mapper(Information, info_table) - Data.mapper = mapper(Data, data_table) - Relation.mapper = mapper(Relation, rel_table) - - Relation.mapper.add_property('datas', relation(Data.mapper, - primaryjoin=and_(Relation.c.info_pk==Data.c.info_pk, - Data.c.timeval >= Relation.c.start, - Data.c.timeval <= Relation.c.finish - ), - foreignkey=Data.c.info_pk)) - - Information.mapper.add_property('rels', relation(Relation.mapper)) - - info = Information.mapper.get(1) - assert info - assert len(info.rels) == 2 - assert len(info.rels[0].datas) == 3 - - def testtwo(self): - """same thing, but reversing the order of the cols in the join""" - class Information(object): - pass - - class Relation(object): - pass - - class Data(object): - pass - - # Create the basic mappers, with no frills or modifications - Information.mapper = mapper(Information, info_table) - Data.mapper = mapper(Data, data_table) - Relation.mapper = mapper(Relation, rel_table) - - Relation.mapper.add_property('datas', relation(Data.mapper, - primaryjoin=and_(Relation.c.info_pk==Data.c.info_pk, - Relation.c.start <= Data.c.timeval, - Relation.c.finish >= Data.c.timeval, - # Data.c.timeval >= Relation.c.start, - # Data.c.timeval <= Relation.c.finish - ), - foreignkey=Data.c.info_pk)) - - Information.mapper.add_property('rels', relation(Relation.mapper)) - - info = Information.mapper.get(1) + session = create_session() + + mapper(Data, data_table) + mapper(Relation, rel_table, properties={ + + 'datas': relation(Data, + primaryjoin=and_(rel_table.c.info_pk==Data.c.info_pk, + Data.c.timeval >= rel_table.c.start, + Data.c.timeval <= rel_table.c.finish), + foreignkey=Data.c.info_pk) + } + + ) + mapper(Information, info_table, properties={ + 'rels': relation(Relation) + }) + + info = session.query(Information).get(1) assert info assert len(info.rels) == 2 assert len(info.rels[0].datas) == 3 - if __name__ == "__main__": testbase.main() diff --git a/test/legacy_objectstore.py b/test/legacy_objectstore.py new file mode 100644 index 000000000..3aa99a1ae --- /dev/null +++ b/test/legacy_objectstore.py @@ -0,0 +1,113 @@ +from testbase import PersistTest, AssertMixin +import unittest, sys, os +from sqlalchemy import * +import StringIO +import testbase + +from tables import * +import tables + +install_mods('legacy_session') + + +class LegacySessionTest(AssertMixin): + def setUpAll(self): + db.echo = False + users.create() + db.echo = testbase.echo + def tearDownAll(self): + db.echo = False + users.drop() + db.echo = testbase.echo + def setUp(self): + objectstore.get_session().clear() + clear_mappers() + tables.user_data() + #db.echo = "debug" + def tearDown(self): + tables.delete_user_data() + + def test_nested_begin_commit(self): + """tests that nesting objectstore transactions with multiple commits + affects only the outermost transaction""" + class User(object):pass + m = mapper(User, users) + def name_of(id): + return users.select(users.c.user_id == id).execute().fetchone().user_name + name1 = "Oliver Twist" + name2 = 'Mr. Bumble' + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + s = objectstore.get_session() + trans = s.begin() + trans2 = s.begin() + m.get(7).user_name = name1 + trans3 = s.begin() + m.get(8).user_name = name2 + trans3.commit() + s.commit() # should do nothing + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + trans2.commit() + s.commit() # should do nothing + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + trans.commit() + self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1) + self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2) + + def test_nested_rollback(self): + """tests that nesting objectstore transactions with a rollback inside + affects only the outermost transaction""" + class User(object):pass + m = mapper(User, users) + def name_of(id): + return users.select(users.c.user_id == id).execute().fetchone().user_name + name1 = "Oliver Twist" + name2 = 'Mr. Bumble' + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + s = objectstore.get_session() + trans = s.begin() + trans2 = s.begin() + m.get(7).user_name = name1 + trans3 = s.begin() + m.get(8).user_name = name2 + trans3.rollback() + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + trans2.commit() + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + trans.commit() + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + + def test_true_nested(self): + """tests creating a new Session inside a database transaction, in + conjunction with an engine-level nested transaction, which uses + a second connection in order to achieve a nested transaction that commits, inside + of another engine session that rolls back.""" +# testbase.db.echo='debug' + class User(object): + pass + testbase.db.begin() + try: + m = mapper(User, users) + name1 = "Oliver Twist" + name2 = 'Mr. Bumble' + m.get(7).user_name = name1 + s = objectstore.Session(nest_on=testbase.db) + m.using(s).get(8).user_name = name2 + s.commit() + objectstore.commit() + testbase.db.rollback() + except: + testbase.db.rollback() + raise + objectstore.clear() + self.assert_(m.get(8).user_name == name2) + self.assert_(m.get(7).user_name != name1) + +if __name__ == "__main__": + testbase.main() diff --git a/test/manytomany.py b/test/manytomany.py index 12f62dbdb..92b7efb26 100644 --- a/test/manytomany.py +++ b/test/manytomany.py @@ -1,5 +1,5 @@ -from sqlalchemy import * import testbase +from sqlalchemy import * import string import sqlalchemy.attributes as attr @@ -28,21 +28,22 @@ class Transition(object): class M2MTest(testbase.AssertMixin): def setUpAll(self): - db = testbase.db + self.install_threadlocal() + metadata = testbase.metadata global place - place = Table('place', db, + place = Table('place', metadata, Column('place_id', Integer, Sequence('pid_seq', optional=True), primary_key=True), Column('name', String(30), nullable=False), ) global transition - transition = Table('transition', db, + transition = Table('transition', metadata, Column('transition_id', Integer, Sequence('tid_seq', optional=True), primary_key=True), Column('name', String(30), nullable=False), ) global place_thingy - place_thingy = Table('place_thingy', db, + place_thingy = Table('place_thingy', metadata, Column('thingy_id', Integer, Sequence('thid_seq', optional=True), primary_key=True), Column('place_id', Integer, ForeignKey('place.place_id'), nullable=False), Column('name', String(30), nullable=False) @@ -50,20 +51,20 @@ class M2MTest(testbase.AssertMixin): # association table #1 global place_input - place_input = Table('place_input', db, + place_input = Table('place_input', metadata, Column('place_id', Integer, ForeignKey('place.place_id')), Column('transition_id', Integer, ForeignKey('transition.transition_id')), ) # association table #2 global place_output - place_output = Table('place_output', db, + place_output = Table('place_output', metadata, Column('place_id', Integer, ForeignKey('place.place_id')), Column('transition_id', Integer, ForeignKey('transition.transition_id')), ) global place_place - place_place = Table('place_place', db, + place_place = Table('place_place', metadata, Column('pl1_id', Integer, ForeignKey('place.place_id')), Column('pl2_id', Integer, ForeignKey('place.place_id')), ) @@ -83,7 +84,8 @@ class M2MTest(testbase.AssertMixin): place.drop() transition.drop() #testbase.db.tables.clear() - + self.uninstall_threadlocal() + def setUp(self): objectstore.clear() clear_mappers() @@ -140,7 +142,7 @@ class M2MTest(testbase.AssertMixin): pp = p.places self.echo("Place " + str(p) +" places " + repr(pp)) - objectstore.delete(p1,p2,p3,p4,p5,p6,p7) + [objectstore.delete(p) for p in p1,p2,p3,p4,p5,p6,p7] objectstore.flush() def testdouble(self): @@ -152,8 +154,8 @@ class M2MTest(testbase.AssertMixin): }) Transition.mapper = mapper(Transition, transition, properties = dict( - inputs = relation(Place.mapper, place_output, lazy=False, selectalias='op_alias'), - outputs = relation(Place.mapper, place_input, lazy=False, selectalias='ip_alias'), + inputs = relation(Place.mapper, place_output, lazy=False), + outputs = relation(Place.mapper, place_input, lazy=False), ) ) @@ -161,7 +163,7 @@ class M2MTest(testbase.AssertMixin): tran.inputs.append(Place('place1')) tran.outputs.append(Place('place2')) tran.outputs.append(Place('place3')) - objectstore.commit() + objectstore.flush() objectstore.clear() r = Transition.mapper.select() @@ -201,20 +203,21 @@ class M2MTest(testbase.AssertMixin): p3.inputs.append(t2) p1.outputs.append(t1) - objectstore.commit() + objectstore.flush() self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])}) self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])}) class M2MTest2(testbase.AssertMixin): def setUpAll(self): - db = testbase.db + self.install_threadlocal() + metadata = testbase.metadata global studentTbl - studentTbl = Table('student', db, Column('name', String(20), primary_key=True)) + studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True)) global courseTbl - courseTbl = Table('course', db, Column('name', String(20), primary_key=True)) + courseTbl = Table('course', metadata, Column('name', String(20), primary_key=True)) global enrolTbl - enrolTbl = Table('enrol', db, + enrolTbl = Table('enrol', metadata, Column('student_id', String(20), ForeignKey('student.name'),primary_key=True), Column('course_id', String(20), ForeignKey('course.name'), primary_key=True)) @@ -227,7 +230,8 @@ class M2MTest2(testbase.AssertMixin): studentTbl.drop() courseTbl.drop() #testbase.db.tables.clear() - + self.uninstall_threadlocal() + def setUp(self): objectstore.clear() clear_mappers() @@ -258,7 +262,7 @@ class M2MTest2(testbase.AssertMixin): c3.students.append(s1) self.assert_(len(s1.courses) == 3) self.assert_(len(c1.students) == 1) - objectstore.commit() + objectstore.flush() objectstore.clear() s = Student.mapper.get_by(name='Student1') c = Course.mapper.get_by(name='Course3') @@ -267,65 +271,66 @@ class M2MTest2(testbase.AssertMixin): self.assert_(len(s.courses) == 2) class M2MTest3(testbase.AssertMixin): - def setUpAll(self): - e = testbase.db - global c, c2a1, c2a2, b, a - c = Table('c', e, - Column('c1', Integer, primary_key = True), - Column('c2', String(20)), - ).create() - - a = Table('a', e, - Column('a1', Integer, primary_key=True), - Column('a2', String(20)), - Column('c1', Integer, ForeignKey('c.c1')) - ).create() - - c2a1 = Table('ctoaone', e, - Column('c1', Integer, ForeignKey('c.c1')), - Column('a1', Integer, ForeignKey('a.a1')) - ).create() - c2a2 = Table('ctoatwo', e, - Column('c1', Integer, ForeignKey('c.c1')), - Column('a1', Integer, ForeignKey('a.a1')) - ).create() - - b = Table('b', e, - Column('b1', Integer, primary_key=True), - Column('a1', Integer, ForeignKey('a.a1')), - Column('b2', Boolean) - ).create() - - def tearDownAll(self): - b.drop() - c2a2.drop() - c2a1.drop() - a.drop() - c.drop() - #testbase.db.tables.clear() - - def testbasic(self): - class C(object):pass - class A(object):pass - class B(object):pass + def setUpAll(self): + self.install_threadlocal() + metadata = testbase.metadata + global c, c2a1, c2a2, b, a + c = Table('c', metadata, + Column('c1', Integer, primary_key = True), + Column('c2', String(20)), + ).create() + + a = Table('a', metadata, + Column('a1', Integer, primary_key=True), + Column('a2', String(20)), + Column('c1', Integer, ForeignKey('c.c1')) + ).create() + + c2a1 = Table('ctoaone', metadata, + Column('c1', Integer, ForeignKey('c.c1')), + Column('a1', Integer, ForeignKey('a.a1')) + ).create() + c2a2 = Table('ctoatwo', metadata, + Column('c1', Integer, ForeignKey('c.c1')), + Column('a1', Integer, ForeignKey('a.a1')) + ).create() + + b = Table('b', metadata, + Column('b1', Integer, primary_key=True), + Column('a1', Integer, ForeignKey('a.a1')), + Column('b2', Boolean) + ).create() - assign_mapper(B, b) + def tearDownAll(self): + b.drop() + c2a2.drop() + c2a1.drop() + a.drop() + c.drop() + #testbase.db.tables.clear() + self.uninstall_threadlocal() + + def testbasic(self): + class C(object):pass + class A(object):pass + class B(object):pass - assign_mapper(A, a, - properties = { - 'tbs' : relation(B, primaryjoin=and_(b.c.a1==a.c.a1, b.c.b2 == True), lazy=False), - } - ) + assign_mapper(B, b) - assign_mapper(C, c, - properties = { - 'a1s' : relation(A, secondary=c2a1, lazy=False), - 'a2s' : relation(A, secondary=c2a2, lazy=False) - } - ) + assign_mapper(A, a, + properties = { + 'tbs' : relation(B, primaryjoin=and_(b.c.a1==a.c.a1, b.c.b2 == True), lazy=False), + } + ) - o1 = C.get(1) + assign_mapper(C, c, + properties = { + 'a1s' : relation(A, secondary=c2a1, lazy=False), + 'a2s' : relation(A, secondary=c2a2, lazy=False) + } + ) + o1 = C.get(1) if __name__ == "__main__": diff --git a/test/mapper.py b/test/mapper.py index 28982f756..546ed1a87 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -2,7 +2,7 @@ from testbase import PersistTest, AssertMixin import testbase import unittest, sys, os from sqlalchemy import * - +import sqlalchemy.exceptions as exceptions from tables import * import tables @@ -68,35 +68,38 @@ class MapperSuperTest(AssertMixin): tables.drop() db.echo = testbase.echo def tearDown(self): - objectstore.clear() clear_mappers() def setUp(self): pass class MapperTest(MapperSuperTest): def testget(self): - m = mapper(User, users) - self.assert_(m.get(19) is None) - u = m.get(7) - u2 = m.get(7) + s = create_session() + mapper(User, users) + self.assert_(s.get(User, 19) is None) + u = s.get(User, 7) + u2 = s.get(User, 7) self.assert_(u is u2) - objectstore.clear() - u2 = m.get(7) + s.clear() + u2 = s.get(User, 7) self.assert_(u is not u2) def testrefresh(self): - m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) - u = m.get(7) + mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) + s = create_session() + u = s.get(User, 7) u.user_name = 'foo' a = Address() + import sqlalchemy.orm.session + assert sqlalchemy.orm.session.object_session(a) is None u.addresses.append(a) self.assert_(a in u.addresses) - objectstore.refresh(u) + s.refresh(u) # its refreshed, so not dirty - self.assert_(u not in objectstore.get_session().uow.dirty) + self.assert_(u not in s.dirty) # username is back to the DB self.assert_(u.user_name == 'jack') @@ -106,10 +109,10 @@ class MapperTest(MapperSuperTest): u.user_name = 'foo' u.addresses.append(a) # now its dirty - self.assert_(u in objectstore.get_session().uow.dirty) + self.assert_(u in s.dirty) self.assert_(u.user_name == 'foo') self.assert_(a in u.addresses) - objectstore.expire(u) + s.expire(u) # get the attribute, it refreshes self.assert_(u.user_name == 'jack') @@ -117,41 +120,22 @@ class MapperTest(MapperSuperTest): def testrefresh_lazy(self): """tests that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems""" - m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) - m2 = m.options(lazyload('addresses')) - u = m2.selectfirst(users.c.user_id==8) + s = create_session() + mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) + q2 = s.query(User).options(lazyload('addresses')) + u = q2.selectfirst(users.c.user_id==8) def go(): - objectstore.refresh(u) + s.refresh(u) self.assert_sql_count(db, go, 1) - def testexpire_eager(self): - """tests that an eager load will populate expire()'d objects""" - m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) - [u1, u2, u3] = m.select(users.c.user_id.in_(7, 8, 9)) - self.echo([repr(x.addresses) for x in [u1, u2, u3]]) - [objectstore.expire(u) for u in [u1, u2, u3]] - m2 = m.options(eagerload('addresses')) - l = m2.select(users.c.user_id.in_(7,8,9)) - def go(): - u1.addresses - u2.addresses - u3.addresses - self.assert_sql_count(db, go, 0) - - def testsessionpropigation(self): - sess = objectstore.Session() - m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=True)}) - u = m.using(sess).get(7) - assert objectstore.get_session(u) is sess - assert objectstore.get_session(u.addresses[0]) is sess - def testexpire(self): - m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) - u = m.get(7) + s = create_session() + mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) + u = s.get(User, 7) assert(len(u.addresses) == 1) u.user_name = 'foo' del u.addresses[0] - objectstore.expire(u) + s.expire(u) # test plain expire self.assert_(u.user_name =='jack') self.assert_(len(u.addresses) == 1) @@ -159,14 +143,14 @@ class MapperTest(MapperSuperTest): # we're changing the database here, so if this test fails in the middle, # it'll screw up the other tests which are hardcoded to 7/'jack' u.user_name = 'foo' - objectstore.commit() + s.flush() # change the value in the DB users.update(users.c.user_id==7, values=dict(user_name='jack')).execute() - objectstore.expire(u) + s.expire(u) # object isnt refreshed yet, using dict to bypass trigger self.assert_(u.__dict__['user_name'] != 'jack') # do a select - m.select() + s.query(User).select() # test that it refreshed self.assert_(u.__dict__['user_name'] == 'jack') @@ -175,41 +159,43 @@ class MapperTest(MapperSuperTest): self.assert_(u.user_name =='jack') def testrefresh2(self): - assign_mapper(Address, addresses) + s = create_session() + mapper(Address, addresses) - assign_mapper(User, users, properties = dict(addresses=relation(Address.mapper,private=True,lazy=False)) ) + mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) ) u=User() u.user_name='Justin' a = Address() a.address_id=17 # to work around the hardcoded IDs in this test suite.... u.addresses.append(a) - objectstore.commit() - objectstore.clear() - u = User.mapper.selectfirst() + s.flush() + s.clear() + u = s.query(User).selectfirst() print u.user_name #ok so far - u.expire() #hangs when + s.expire(u) #hangs when print u.user_name #this line runs - u.refresh() #hangs + s.refresh(u) #hangs def testmagic(self): - m = mapper(User, users, properties = { + mapper(User, users, properties = { 'addresses' : relation(mapper(Address, addresses)) }) - l = m.select_by(user_name='fred') + sess = create_session() + l = sess.query(User).select_by(user_name='fred') self.assert_result(l, User, *[{'user_id':9}]) u = l[0] - u2 = m.get_by_user_name('fred') + u2 = sess.query(User).get_by_user_name('fred') self.assert_(u is u2) - l = m.select_by(email_address='ed@bettyboop.com') + l = sess.query(User).select_by(email_address='ed@bettyboop.com') self.assert_result(l, User, *[{'user_id':8}]) - l = m.select_by(User.c.user_name=='fred', addresses.c.email_address!='ed@bettyboop.com', user_id=9) + l = sess.query(User).select_by(User.c.user_name=='fred', addresses.c.email_address!='ed@bettyboop.com', user_id=9) def testprops(self): """tests the various attributes of the properties attached to classes""" @@ -220,46 +206,69 @@ class MapperTest(MapperSuperTest): def testload(self): """tests loading rows with a mapper and producing object instances""" - m = mapper(User, users) - l = m.select() + mapper(User, users) + l = create_session().query(User).select() self.assert_result(l, User, *user_result) - l = m.select(users.c.user_name.endswith('ed')) + l = create_session().query(User).select(users.c.user_name.endswith('ed')) self.assert_result(l, User, *user_result[1:3]) + def testjoinvia(self): + m = mapper(User, users, properties={ + 'orders':relation(mapper(Order, orders, properties={ + 'items':relation(mapper(Item, orderitems)) + })) + }) + + q = create_session().query(m) + + l = q.select((orderitems.c.item_name=='item 4') & q.join_via(['orders', 'items'])) + self.assert_result(l, User, user_result[0]) + + l = q.select_by(item_name='item 4') + self.assert_result(l, User, user_result[0]) + + l = q.select((orderitems.c.item_name=='item 4') & q.join_to('item_name')) + self.assert_result(l, User, user_result[0]) + + l = q.select((orderitems.c.item_name=='item 4') & q.join_to('items')) + self.assert_result(l, User, user_result[0]) + def testorderby(self): # TODO: make a unit test out of these various combinations # m = mapper(User, users, order_by=desc(users.c.user_name)) - m = mapper(User, users, order_by=None) -# m = mapper(User, users) + mapper(User, users, order_by=None) +# mapper(User, users) -# l = m.select(order_by=[desc(users.c.user_name), asc(users.c.user_id)]) - l = m.select() -# l = m.select(order_by=[]) -# l = m.select(order_by=None) +# l = create_session().query(User).select(order_by=[desc(users.c.user_name), asc(users.c.user_id)]) + l = create_session().query(User).select() +# l = create_session().query(User).select(order_by=[]) +# l = create_session().query(User).select(order_by=None) def testfunction(self): """tests mapping to a SELECT statement that has functions in it.""" s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')], users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c]).alias('myselect') - m = mapper(User, s, primarytable=users) - print [c.key for c in m.c] - l = m.select() + mapper(User, s) + sess = create_session() + l = sess.query(User).select() for u in l: print "User", u.user_id, u.user_name, u.concat, u.count - #l[1].user_name='asdf' - #objectstore.commit() - + assert l[0].concat == l[0].user_id * 2 == 14 + assert l[1].concat == l[1].user_id * 2 == 16 + def testcount(self): - m = mapper(User, users) - self.assert_(m.count()==3) - self.assert_(m.count(users.c.user_id.in_(8,9))==2) - self.assert_(m.count_by(user_name='fred')==1) + mapper(User, users) + q = create_session().query(User) + self.assert_(q.count()==3) + self.assert_(q.count(users.c.user_id.in_(8,9))==2) + self.assert_(q.count_by(user_name='fred')==1) def testmultitable(self): usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) - m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id]) - l = m.select() + m = mapper(User, usersaddresses, primary_key=[users.c.user_id]) + q = create_session().query(m) + l = q.select() self.assert_result(l, User, *user_result[0:2]) def testoverride(self): @@ -269,14 +278,16 @@ class MapperTest(MapperSuperTest): 'user_name' : relation(mapper(Address, addresses)), }) self.assert_(False, "should have raised ArgumentError") - except ArgumentError, e: + except exceptions.ArgumentError, e: self.assert_(True) - + + clear_mappers() # assert that allow_column_override cancels the error m = mapper(User, users, properties = { 'user_name' : relation(mapper(Address, addresses)) }, allow_column_override=True) + clear_mappers() # assert that the column being named else where also cancels the error m = mapper(User, users, properties = { 'user_name' : relation(mapper(Address, addresses)), @@ -285,27 +296,50 @@ class MapperTest(MapperSuperTest): def testeageroptions(self): """tests that a lazy relation can be upgraded to an eager relation via the options method""" - m = mapper(User, users, properties = dict( + sess = create_session() + mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) )) - l = m.options(eagerload('addresses')).select() + l = sess.query(User).options(eagerload('addresses')).select() def go(): self.assert_result(l, User, *user_address_result) self.assert_sql_count(db, go, 0) + def testeagerdegrade(self): + """tests that an eager relation automatically degrades to a lazy relation if eager columns are not available""" + sess = create_session() + usermapper = mapper(User, users, properties = dict( + addresses = relation(mapper(Address, addresses), lazy = False) + )) + + # first test straight eager load, 1 statement + def go(): + l = usermapper.query(sess).select() + self.assert_result(l, User, *user_address_result) + self.assert_sql_count(db, go, 1) + + # then select just from users. run it into instances. + # then assert the data, which will launch 3 more lazy loads + def go(): + r = users.select().execute() + l = usermapper.instances(r, sess) + self.assert_result(l, User, *user_address_result) + self.assert_sql_count(db, go, 4) + def testlazyoptions(self): """tests that an eager relation can be upgraded to a lazy relation via the options method""" - m = mapper(User, users, properties = dict( + sess = create_session() + mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) )) - l = m.options(lazyload('addresses')).select() + l = sess.query(User).options(lazyload('addresses')).select() def go(): self.assert_result(l, User, *user_address_result) self.assert_sql_count(db, go, 3) def testdeepoptions(self): - m = mapper(User, users, + mapper(User, users, properties = { 'orders': relation(mapper(Order, orders, properties = { 'items' : relation(mapper(Item, orderitems, properties = { @@ -314,55 +348,43 @@ class MapperTest(MapperSuperTest): })) }) - m2 = m.options(eagerload('orders.items.keywords')) - u = m.select() + sess = create_session() + q2 = sess.query(User).options(eagerload('orders.items.keywords')) + u = sess.query(User).select() def go(): print u[0].orders[1].items[0].keywords[1] self.assert_sql_count(db, go, 3) - objectstore.clear() - u = m2.select() + sess.clear() + u = q2.select() self.assert_sql_count(db, go, 2) -class PropertyTest(MapperSuperTest): - def testbasic(self): - """tests that you can create mappers inline with class definitions""" - class _Address(object): - pass - assign_mapper(_Address, addresses) - - class _User(object): - pass - assign_mapper(_User, users, properties = dict( - addresses = relation(_Address.mapper, lazy = False) - ), is_primary = True) - - l = _User.mapper.select(_User.c.user_name == 'fred') - self.echo(repr(l)) - +class InheritanceTest(MapperSuperTest): def testinherits(self): class _Order(object): pass - assign_mapper(_Order, orders) + ordermapper = mapper(_Order, orders) class _User(object): pass - assign_mapper(_User, users, properties = dict( - orders = relation(_Order.mapper, lazy = False) + usermapper = mapper(_User, users, properties = dict( + orders = relation(ordermapper, lazy = False) )) class AddressUser(_User): pass - assign_mapper(AddressUser, addresses, inherits = _User.mapper) - - l = AddressUser.mapper.select() + mapper(AddressUser, addresses, inherits = usermapper) + + sess = create_session() + q = sess.query(AddressUser) + l = q.select() jack = l[0] self.assert_(jack.user_name=='jack') jack.email_address = 'jack@gmail.com' - objectstore.commit() - objectstore.clear() - au = AddressUser.mapper.get_by(user_name='jack') + sess.flush() + sess.clear() + au = q.get_by(user_name='jack') self.assert_(au.email_address == 'jack@gmail.com') def testinherits2(self): @@ -372,19 +394,20 @@ class PropertyTest(MapperSuperTest): pass class AddressUser(_Address): pass - assign_mapper(_Order, orders) - assign_mapper(_Address, addresses) - assign_mapper(AddressUser, users, inherits = _Address.mapper, + ordermapper = mapper(_Order, orders) + addressmapper = mapper(_Address, addresses) + usermapper = mapper(AddressUser, users, inherits = addressmapper, properties = { - 'orders' : relation(_Order.mapper, lazy=False) + 'orders' : relation(ordermapper, lazy=False) }) - l = AddressUser.mapper.select() + sess = create_session() + l = sess.query(usermapper).select() jack = l[0] self.assert_(jack.user_name=='jack') jack.email_address = 'jack@gmail.com' - objectstore.commit() - objectstore.clear() - au = AddressUser.mapper.get_by(user_name='jack') + sess.flush() + sess.clear() + au = sess.query(usermapper).get_by(user_name='jack') self.assert_(au.email_address == 'jack@gmail.com') @@ -400,13 +423,15 @@ class DeferredTest(MapperSuperTest): o = Order() self.assert_(o.description is None) + q = create_session().query(m) def go(): - l = m.select() + l = q.select() o2 = l[2] print o2.description + orderby = str(orders.default_order_by()[0].compile(engine=db)) self.assert_sql(db, go, [ - ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), + ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) ]) @@ -415,10 +440,12 @@ class DeferredTest(MapperSuperTest): 'description':deferred(orders.c.description) }) - l = m.select() + sess = create_session() + q = sess.query(m) + l = q.select() o2 = l[2] o2.isopen = 1 - objectstore.commit() + sess.flush() def testgroup(self): """tests deferred load with a group""" @@ -428,36 +455,43 @@ class DeferredTest(MapperSuperTest): 'description':deferred(orders.c.description, group='primary'), 'opened':deferred(orders.c.isopen, group='primary') }) - + q = create_session().query(m) def go(): - l = m.select() + l = q.select() o2 = l[2] print o2.opened, o2.description, o2.userident + + orderby = str(orders.default_order_by()[0].compile(db)) self.assert_sql(db, go, [ - ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), + ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}), ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) ]) def testoptions(self): """tests using options on a mapper to create deferred and undeferred columns""" m = mapper(Order, orders) - m2 = m.options(defer('user_id')) + sess = create_session() + q = sess.query(m) + q2 = q.options(defer('user_id')) def go(): - l = m2.select() + l = q2.select() print l[2].user_id + + orderby = str(orders.default_order_by()[0].compile(db)) self.assert_sql(db, go, [ - ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), + ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) ]) - objectstore.clear() - m3 = m2.options(undefer('user_id')) + sess.clear() + q3 = q2.options(undefer('user_id')) def go(): - l = m3.select() + l = q3.select() print l[3].user_id self.assert_sql(db, go, [ - ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), + ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ]) + def testdeepoptions(self): m = mapper(User, users, properties={ 'orders':relation(mapper(Order, orders, properties={ @@ -466,15 +500,17 @@ class DeferredTest(MapperSuperTest): })) })) }) - l = m.select() + sess = create_session() + q = sess.query(m) + l = q.select() item = l[0].orders[1].items[1] def go(): print item.item_name self.assert_sql_count(db, go, 1) self.assert_(item.item_name == 'item 4') - objectstore.clear() - m2 = m.options(undefer('orders.items.item_name')) - l = m2.select() + sess.clear() + q2 = q.options(undefer('orders.items.item_name')) + l = q2.select() item = l[0].orders[1].items[1] def go(): print item.item_name @@ -489,7 +525,8 @@ class LazyTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) )) - l = m.select(users.c.user_id == 7) + q = create_session().query(m) + l = q.select(users.c.user_id == 7) self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, ) @@ -500,7 +537,8 @@ class LazyTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(m, lazy = True, order_by=addresses.c.email_address), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, @@ -508,13 +546,29 @@ class LazyTest(MapperSuperTest): {'user_id' : 9, 'addresses' : (Address, [])} ) + def testorderby_select(self): + """tests that a regular mapper select on a single table can order by a relation to a second table""" + m = mapper(Address, addresses) + + m = mapper(User, users, properties = dict( + addresses = relation(m, lazy = True), + )) + q = create_session().query(m) + l = q.select(users.c.user_id==addresses.c.user_id, order_by=addresses.c.email_address) + + self.assert_result(l, User, + {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, ])}, + {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, + ) + def testorderby_desc(self): m = mapper(Address, addresses) m = mapper(User, users, properties = dict( addresses = relation(m, lazy = True, order_by=[desc(addresses.c.email_address)]), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, @@ -531,35 +585,39 @@ class LazyTest(MapperSuperTest): addresses = relation(mapper(Address, addresses), lazy = True), orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = True), )) - l = m.select(limit=2, offset=1) + sess= create_session() + q = sess.query(m) + l = q.select(limit=2, offset=1) self.assert_result(l, User, *user_all_result[1:3]) # use a union all to get a lot of rows to join against u2 = users.alias('u2') s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') print [key for key in s.c.keys()] - l = m.select(s.c.u2_user_id==User.c.user_id, distinct=True) + l = q.select(s.c.u2_user_id==User.c.user_id, distinct=True) self.assert_result(l, User, *user_all_result) - objectstore.clear() + sess.clear() m = mapper(Item, orderitems, is_primary=True, properties = dict( keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True), )) - l = m.select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2) + + l = sess.query(m).select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2) self.assert_result(l, Item, *[item_keyword_result[1], item_keyword_result[2]]) def testonetoone(self): m = mapper(User, users, properties = dict( address = relation(mapper(Address, addresses), lazy = True, uselist = False) )) - l = m.select(users.c.user_id == 7) - self.echo(repr(l)) - self.echo(repr(l[0].address)) + q = create_session().query(m) + l = q.select(users.c.user_id == 7) + self.assert_result(l, User, {'user_id':7, 'address' : (Address, {'address_id':1})}) def testbackwardsonetoone(self): m = mapper(Address, addresses, properties = dict( - user = relation(mapper(User, users, properties = {'id':users.c.user_id}), lazy = True) + user = relation(mapper(User, users), lazy = True) )) - l = m.select(addresses.c.address_id == 1) + q = create_session().query(m) + l = q.select(addresses.c.address_id == 1) self.echo(repr(l)) print repr(l[0].__dict__) self.echo(repr(l[0].user)) @@ -572,10 +630,11 @@ class LazyTest(MapperSuperTest): closedorders = alias(orders, 'closedorders') m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False), - open_orders = relation(mapper(Order, openorders), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = True), - closed_orders = relation(mapper(Order, closedorders), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = True) + open_orders = relation(mapper(Order, openorders, entity_name='open'), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = True), + closed_orders = relation(mapper(Order, closedorders,entity_name='closed'), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = True) )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}]), @@ -596,10 +655,11 @@ class LazyTest(MapperSuperTest): def testmanytomany(self): """tests a many-to-many lazy load""" - assign_mapper(Item, orderitems, properties = dict( + mapper(Item, orderitems, properties = dict( keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True), )) - l = Item.mapper.select() + q = create_session().query(Item) + l = q.select() self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, @@ -607,7 +667,7 @@ class LazyTest(MapperSuperTest): {'item_id' : 4, 'keywords' : (Keyword, [])}, {'item_id' : 5, 'keywords' : (Keyword, [])} ) - l = Item.mapper.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, Item.c.item_id==itemkeywords.c.item_id)) + l = q.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, Item.c.item_id==itemkeywords.c.item_id)) self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, @@ -616,13 +676,13 @@ class LazyTest(MapperSuperTest): class EagerTest(MapperSuperTest): def testbasic(self): """tests a basic one-to-many eager load""" - m = mapper(Address, addresses) m = mapper(User, users, properties = dict( addresses = relation(m, lazy = False), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, *user_address_result) def testorderby(self): @@ -631,7 +691,8 @@ class EagerTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(m, lazy = False, order_by=addresses.c.email_address), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])}, @@ -644,7 +705,8 @@ class EagerTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(m, lazy = False, order_by=[desc(addresses.c.email_address)]), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, @@ -661,20 +723,24 @@ class EagerTest(MapperSuperTest): addresses = relation(mapper(Address, addresses), lazy = False), orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = False), )) - l = m.select(limit=2, offset=1) + sess = create_session() + q = sess.query(m) + + l = q.select(limit=2, offset=1) self.assert_result(l, User, *user_all_result[1:3]) # this is an involved 3x union of the users table to get a lot of rows. # then see if the "distinct" works its way out. you actually get the same # result with or without the distinct, just via less or more rows. u2 = users.alias('u2') s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') - l = m.select(s.c.u2_user_id==User.c.user_id, distinct=True) + l = q.select(s.c.u2_user_id==User.c.user_id, distinct=True) self.assert_result(l, User, *user_all_result) - objectstore.clear() + sess.clear() m = mapper(Item, orderitems, is_primary=True, properties = dict( keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = False, order_by=[keywords.c.keyword_id]), )) - l = m.select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2) + q = sess.query(m) + l = q.select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2) self.assert_result(l, Item, *[item_keyword_result[1], item_keyword_result[2]]) @@ -683,7 +749,8 @@ class EagerTest(MapperSuperTest): m = mapper(User, users, properties = dict( address = relation(mapper(Address, addresses), lazy = False, uselist = False) )) - l = m.select(users.c.user_id == 7) + q = create_session().query(m) + l = q.select(users.c.user_id == 7) self.assert_result(l, User, {'user_id' : 7, 'address' : (Address, {'address_id' : 1, 'email_address': 'jack@bean.com'})}, ) @@ -693,7 +760,8 @@ class EagerTest(MapperSuperTest): user = relation(mapper(User, users), lazy = False) )) self.echo(repr(m.props['user'].uselist)) - l = m.select(addresses.c.address_id == 1) + q = create_session().query(m) + l = q.select(addresses.c.address_id == 1) self.assert_result(l, Address, {'address_id' : 1, 'email_address' : 'jack@bean.com', 'user' : (User, {'user_id' : 7, 'user_name' : 'jack'}) @@ -707,7 +775,8 @@ class EagerTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), primaryjoin = users.c.user_id==addresses.c.user_id, lazy = False) )) - l = m.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id)) + q = create_session().query(m) + l = q.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id)) self.assert_result(l, User, {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2, 'email_address':'ed@wood.com'}, {'address_id':3, 'email_address':'ed@bettyboop.com'}, {'address_id':4, 'email_address':'ed@lala.com'}])}, ) @@ -715,6 +784,7 @@ class EagerTest(MapperSuperTest): def testcompile(self): """tests deferred operation of a pre-compiled mapper statement""" + session = create_session() m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) )) @@ -722,7 +792,7 @@ class EagerTest(MapperSuperTest): c = s.compile() self.echo("\n" + str(c) + repr(c.get_params())) - l = m.instances(s.execute(emailad = 'jack@bean.com')) + l = m.instances(s.execute(emailad = 'jack@bean.com'), session) self.echo(repr(l)) def testmulti(self): @@ -731,7 +801,8 @@ class EagerTest(MapperSuperTest): addresses = relation(mapper(Address, addresses), primaryjoin = users.c.user_id==addresses.c.user_id, lazy = False), orders = relation(mapper(Order, orders), lazy = False), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}]), @@ -754,10 +825,11 @@ class EagerTest(MapperSuperTest): ordermapper = mapper(Order, orders) m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False), - open_orders = relation(mapper(Order, openorders), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = False), - closed_orders = relation(mapper(Order, closedorders), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = False) + open_orders = relation(mapper(Order, openorders, non_primary=True), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = False), + closed_orders = relation(mapper(Order, closedorders, non_primary=True), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = False) )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}]), @@ -787,7 +859,8 @@ class EagerTest(MapperSuperTest): addresses = relation(mapper(Address, addresses), lazy = False), orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = False), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, User, *user_all_result) def testmanytomany(self): @@ -796,16 +869,37 @@ class EagerTest(MapperSuperTest): m = mapper(Item, items, properties = dict( keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy=False, order_by=[keywords.c.keyword_id]), )) - l = m.select() + q = create_session().query(m) + l = q.select() self.assert_result(l, Item, *item_keyword_result) -# l = m.select() - l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id)) + l = q.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id)) self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, ) + def testmanytomanyoptions(self): + items = orderitems + + m = mapper(Item, items, properties = dict( + keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy=True, order_by=[keywords.c.keyword_id]), + )) + m2 = m.options(eagerload('keywords')) + q = create_session().query(m2) + def go(): + l = q.select() + self.assert_result(l, Item, *item_keyword_result) + self.assert_sql_count(db, go, 1) + + def go(): + l = q.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id)) + self.assert_result(l, Item, + {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, + {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, + ) + self.assert_sql_count(db, go, 1) + def testoneandmany(self): """tests eager load for a parent object with a child object that contains a many-to-many relationship to a third object.""" @@ -819,7 +913,8 @@ class EagerTest(MapperSuperTest): m = mapper(Order, orders, properties = dict( items = relation(m, lazy = False) )) - l = m.select("orders.order_id in (1,2,3)") + q = create_session().query(m) + l = q.select("orders.order_id in (1,2,3)") self.assert_result(l, Order, {'order_id' : 1, 'items': (Item, [])}, {'order_id' : 2, 'items': (Item, [ diff --git a/test/masscreate.py b/test/masscreate.py index 885c1f653..5f99bb6c1 100644 --- a/test/masscreate.py +++ b/test/masscreate.py @@ -16,7 +16,7 @@ attr_manager = AttributeManager() if manage_attributes: attr_manager.register_attribute(User, 'id', uselist=False) attr_manager.register_attribute(User, 'name', uselist=False) - attr_manager.register_attribute(User, 'addresses', uselist=True) + attr_manager.register_attribute(User, 'addresses', uselist=True, trackparent=True) attr_manager.register_attribute(Address, 'email', uselist=False) now = time.time() @@ -35,7 +35,7 @@ for i in range(0,130): a.email = 'foo@bar.com' u.addresses.append(a) # gc.collect() - print len(managed_attributes) +# print len(managed_attributes) # managed_attributes.clear() total = time.time() - now print "Total time", total diff --git a/test/massload.py b/test/massload.py index ab47415e8..d36746968 100644 --- a/test/massload.py +++ b/test/massload.py @@ -35,7 +35,7 @@ class LoadTest(AssertMixin): clear_mappers() for x in range(1,NUM/500+1): l = [] - for y in range(x*500-499, x*500 + 1): + for y in range(x*500-500, x*500): l.append({'item_id':y, 'value':'this is item #%d' % y}) items.insert().execute(*l) @@ -47,7 +47,7 @@ class LoadTest(AssertMixin): for x in range (1,NUM/100): # this is not needed with cpython which clears non-circular refs immediately #gc.collect() - l = m.select(items.c.item_id.between(x*100 - 99, x*100 )) + l = m.select(items.c.item_id.between(x*100 - 100, x*100 - 1)) assert len(l) == 100 print "loaded ", len(l), " items " # modifying each object will insure that the objects get placed in the "dirty" list @@ -56,10 +56,10 @@ class LoadTest(AssertMixin): a.value = 'changed...' assert len(objectstore.get_session().dirty) == len(l) assert len(objectstore.get_session().identity_map) == len(l) - #assert len(attributes.managed_attributes) == len(l) + assert len(attributes.managed_attributes) == len(l) print len(objectstore.get_session().dirty) print len(objectstore.get_session().identity_map) - objectstore.expunge(*l) + #objectstore.expunge(*l) if __name__ == "__main__": testbase.main() diff --git a/test/objectstore.py b/test/objectstore.py index 9ec6ca7e2..404f9ff94 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -3,12 +3,28 @@ import unittest, sys, os from sqlalchemy import * import StringIO import testbase - +from sqlalchemy.orm.mapper import global_extensions +from sqlalchemy.ext.sessioncontext import SessionContext +import sqlalchemy.ext.assignmapper as assignmapper from tables import * import tables -class HistoryTest(AssertMixin): +class SessionTest(AssertMixin): def setUpAll(self): + global ctx, assign_mapper + ctx = SessionContext(create_session) + def assign_mapper(*args, **kwargs): + return assignmapper.assign_mapper(ctx, *args, **kwargs) + global_extensions.append(ctx.mapper_extension) + def tearDownAll(self): + global_extensions.remove(ctx.mapper_extension) + def tearDown(self): + ctx.current.clear() + clear_mappers() + +class HistoryTest(SessionTest): + def setUpAll(self): + SessionTest.setUpAll(self) db.echo = False users.create() addresses.create() @@ -18,9 +34,7 @@ class HistoryTest(AssertMixin): addresses.drop() users.drop() db.echo = testbase.echo - def setUp(self): - objectstore.clear() - clear_mappers() + SessionTest.tearDownAll(self) def testattr(self): """tests the rolling back of scalar and list attributes. this kind of thing @@ -43,7 +57,7 @@ class HistoryTest(AssertMixin): self.assert_result([u], data[0], *data[1:]) self.echo(repr(u.addresses)) - objectstore.uow().rollback_object(u) + ctx.current.uow.rollback_object(u) data = [User, {'user_name' : None, 'addresses' : (Address, []) @@ -52,6 +66,7 @@ class HistoryTest(AssertMixin): self.assert_result([u], data[0], *data[1:]) def testbackref(self): + s = create_session() class User(object):pass class Address(object):pass am = mapper(Address, addresses) @@ -59,177 +74,82 @@ class HistoryTest(AssertMixin): addresses = relation(am, backref='user', lazy=False)) ) - u = User() - a = Address() + u = User(_sa_session=s) + a = Address(_sa_session=s) a.user = u #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items()) #print repr(u.addresses.added_items()) self.assert_(u.addresses == [a]) - objectstore.commit() + s.flush() - objectstore.clear() - u = m.select()[0] + s.clear() + u = s.query(m).select()[0] print u.addresses[0].user -class SessionTest(AssertMixin): - def setUpAll(self): - db.echo = False - users.create() - db.echo = testbase.echo - def tearDownAll(self): - db.echo = False - users.drop() - db.echo = testbase.echo - def setUp(self): - objectstore.get_session().clear() - clear_mappers() - tables.user_data() - #db.echo = "debug" - def tearDown(self): - tables.delete_user_data() - - def test_nested_begin_commit(self): - """tests that nesting objectstore transactions with multiple commits - affects only the outermost transaction""" - class User(object):pass - m = mapper(User, users) - def name_of(id): - return users.select(users.c.user_id == id).execute().fetchone().user_name - name1 = "Oliver Twist" - name2 = 'Mr. Bumble' - self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) - self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) - s = objectstore.get_session() - trans = s.begin() - trans2 = s.begin() - m.get(7).user_name = name1 - trans3 = s.begin() - m.get(8).user_name = name2 - trans3.commit() - s.commit() # should do nothing - self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) - self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) - trans2.commit() - s.commit() # should do nothing - self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) - self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) - trans.commit() - self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1) - self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2) - - def test_nested_rollback(self): - """tests that nesting objectstore transactions with a rollback inside - affects only the outermost transaction""" - class User(object):pass - m = mapper(User, users) - def name_of(id): - return users.select(users.c.user_id == id).execute().fetchone().user_name - name1 = "Oliver Twist" - name2 = 'Mr. Bumble' - self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) - self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) - s = objectstore.get_session() - trans = s.begin() - trans2 = s.begin() - m.get(7).user_name = name1 - trans3 = s.begin() - m.get(8).user_name = name2 - trans3.rollback() - self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) - self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) - trans2.commit() - self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) - self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) - trans.commit() - self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) - self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) - - @testbase.unsupported('sqlite') - def test_true_nested(self): - """tests creating a new Session inside a database transaction, in - conjunction with an engine-level nested transaction, which uses - a second connection in order to achieve a nested transaction that commits, inside - of another engine session that rolls back.""" -# testbase.db.echo='debug' - class User(object): - pass - testbase.db.begin() - try: - m = mapper(User, users) - name1 = "Oliver Twist" - name2 = 'Mr. Bumble' - m.get(7).user_name = name1 - s = objectstore.Session(nest_on=testbase.db) - m.using(s).get(8).user_name = name2 - s.commit() - objectstore.commit() - testbase.db.rollback() - except: - testbase.db.rollback() - raise - objectstore.clear() - self.assert_(m.get(8).user_name == name2) - self.assert_(m.get(7).user_name != name1) -class VersioningTest(AssertMixin): +class VersioningTest(SessionTest): def setUpAll(self): - objectstore.clear() + SessionTest.setUpAll(self) + ctx.current.clear() global version_table version_table = Table('version_test', db, - Column('id', Integer, primary_key=True), + Column('id', Integer, Sequence('version_test_seq'), primary_key=True ), Column('version_id', Integer, nullable=False), Column('value', String(40), nullable=False) ).create() def tearDownAll(self): version_table.drop() + SessionTest.tearDownAll(self) def tearDown(self): version_table.delete().execute() - objectstore.clear() - clear_mappers() + SessionTest.tearDown(self) @testbase.unsupported('mysql') def testbasic(self): + s = create_session() class Foo(object):pass assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id) - f1 =Foo(value='f1') - f2 = Foo(value='f2') - objectstore.commit() + f1 =Foo(value='f1', _sa_session=s) + f2 = Foo(value='f2', _sa_session=s) + s.flush() f1.value='f1rev2' - objectstore.commit() - s = objectstore.Session() - f1_s = Foo.mapper.using(s).get(f1.id) + s.flush() + s2 = create_session() + f1_s = Foo.mapper.using(s2).get(f1.id) f1_s.value='f1rev3' - s.commit() + s2.flush() f1.value='f1rev3mine' success = False try: # a concurrent session has modified this, should throw # an exception - objectstore.commit() - except SQLAlchemyError: + s.flush() + except exceptions.SQLAlchemyError, e: + #print e success = True assert success - objectstore.clear() - f1 = Foo.mapper.get(f1.id) - f2 = Foo.mapper.get(f2.id) + s.clear() + f1 = s.query(Foo).get(f1.id) + f2 = s.query(Foo).get(f2.id) f1_s.value='f1rev4' - s.commit() + s2.flush() - objectstore.delete(f1, f2) + s.delete(f1, f2) success = False try: - objectstore.commit() - except SQLAlchemyError: + s.flush() + except exceptions.SQLAlchemyError, e: + #print e success = True assert success -class UnicodeTest(AssertMixin): +class UnicodeTest(SessionTest): def setUpAll(self): - objectstore.clear() + SessionTest.setUpAll(self) global uni_table uni_table = Table('uni_test', db, Column('id', Integer, primary_key=True), @@ -237,22 +157,25 @@ class UnicodeTest(AssertMixin): def tearDownAll(self): uni_table.drop() - uni_table.deregister() + SessionTest.tearDownAll(self) def testbasic(self): class Test(object): - pass - assign_mapper(Test, uni_table) + def __init__(self, id, txt): + self.id = id + self.txt = txt + mapper(Test, uni_table) txt = u"\u0160\u0110\u0106\u010c\u017d" t1 = Test(id=1, txt = txt) self.assert_(t1.txt == txt) - objectstore.commit() + ctx.current.flush() self.assert_(t1.txt == txt) -class PKTest(AssertMixin): +class PKTest(SessionTest): def setUpAll(self): + SessionTest.setUpAll(self) db.echo = False global table global table2 @@ -286,9 +209,7 @@ class PKTest(AssertMixin): table2.drop() table3.drop() db.echo = testbase.echo - def setUp(self): - objectstore.clear() - clear_mappers() + SessionTest.tearDownAll(self) @testbase.unsupported('sqlite') def testprimarykey(self): @@ -299,9 +220,9 @@ class PKTest(AssertMixin): e.name = 'entry1' e.value = 'this is entry 1' e.multi_rev = 2 - objectstore.commit() - objectstore.clear() - e2 = Entry.mapper.get(e.multi_id, 2) + ctx.current.flush() + ctx.current.clear() + e2 = Entry.mapper.get((e.multi_id, 2)) self.assert_(e is not e2 and e._instance_key == e2._instance_key) def testmanualpk(self): class Entry(object): @@ -311,7 +232,7 @@ class PKTest(AssertMixin): e.pk_col_1 = 'pk1' e.pk_col_2 = 'pk1_related' e.data = 'im the data' - objectstore.commit() + ctx.current.flush() def testkeypks(self): import datetime class Entity(object): @@ -322,11 +243,12 @@ class PKTest(AssertMixin): e.secondary = 'pk2' e.assigned = datetime.date.today() e.data = 'some more data' - objectstore.commit() + ctx.current.flush() -class PrivateAttrTest(AssertMixin): +class PrivateAttrTest(SessionTest): """tests various things to do with private=True mappers""" def setUpAll(self): + SessionTest.setUpAll(self) global a_table, b_table a_table = Table('a',testbase.db, Column('a_id', Integer, Sequence('next_a_id'), primary_key=True), @@ -340,9 +262,7 @@ class PrivateAttrTest(AssertMixin): def tearDownAll(self): b_table.drop() a_table.drop() - def setUp(self): - objectstore.clear() - clear_mappers() + SessionTest.tearDownAll(self) def testsinglecommit(self): """tests that a commit of a single object deletes private relationships""" @@ -350,8 +270,7 @@ class PrivateAttrTest(AssertMixin): class B(object):pass assign_mapper(B,b_table) - assign_mapper(A,a_table,properties= {'bs' : relation - (B.mapper,private=True)}) + assign_mapper(A,a_table,properties= {'bs' : relation(B.mapper,private=True)}) # create some objects a = A(data='a1') @@ -366,10 +285,10 @@ class PrivateAttrTest(AssertMixin): a.bs.append(b2) # inserts both A and Bs - objectstore.commit(a) + ctx.current.flush([a]) - objectstore.delete(a) - objectstore.commit(a) + ctx.current.delete(a) + ctx.current.flush([a]) assert b_table.count().scalar() == 0 @@ -386,23 +305,24 @@ class PrivateAttrTest(AssertMixin): a2 = A(data='testa2') b = B(data='testb') b.a = a1 - objectstore.commit() - objectstore.clear() - sess = objectstore.get_session() + ctx.current.flush() + ctx.current.clear() + sess = ctx.current a1 = A.mapper.get(a1.a_id) a2 = A.mapper.get(a2.a_id) assert a1.bs[0].a is a1 b = a1.bs[0] b.a = a2 assert b not in sess.deleted - objectstore.commit() + ctx.current.flush() assert b in sess.identity_map.values() -class DefaultTest(AssertMixin): +class DefaultTest(SessionTest): """tests that when saving objects whose table contains DefaultGenerators, either python-side, preexec or database-side, the newly saved instances receive all the default values either through a post-fetch or getting the pre-exec'ed defaults back from the engine.""" def setUpAll(self): + SessionTest.setUpAll(self) #db.echo = 'debug' use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite') @@ -423,6 +343,7 @@ class DefaultTest(AssertMixin): self.table.create() def tearDownAll(self): self.table.drop() + SessionTest.tearDownAll(self) def setUp(self): self.table = Table('default_test', db) def testinsert(self): @@ -433,7 +354,7 @@ class DefaultTest(AssertMixin): h3 = Hoho(hoho=self.althohoval, counter=12) h4 = Hoho() h5 = Hoho(foober='im the new foober') - objectstore.commit() + ctx.current.flush() self.assert_(h1.hoho==self.althohoval) self.assert_(h3.hoho==self.althohoval) self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval) @@ -441,7 +362,7 @@ class DefaultTest(AssertMixin): self.assert_(h1.counter == h4.counter==h5.counter==7) self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') self.assert_(h5.foober=='im the new foober') - objectstore.clear() + ctx.current.clear() l = Hoho.mapper.select() (h1, h2, h3, h4, h5) = l self.assert_(h1.hoho==self.althohoval) @@ -457,7 +378,7 @@ class DefaultTest(AssertMixin): class Hoho(object):pass assign_mapper(Hoho, self.table) h1 = Hoho(hoho="15", counter="15") - objectstore.commit() + ctx.current.flush() self.assert_(h1.hoho=="15") self.assert_(h1.counter=="15") self.assert_(h1.foober=="im foober") @@ -466,30 +387,27 @@ class DefaultTest(AssertMixin): class Hoho(object):pass assign_mapper(Hoho, self.table) h1 = Hoho() - objectstore.commit() + ctx.current.flush() self.assert_(h1.foober == 'im foober') h1.counter = 19 - objectstore.commit() + ctx.current.flush() self.assert_(h1.foober == 'im the update') -class SaveTest(AssertMixin): +class SaveTest(SessionTest): def setUpAll(self): + SessionTest.setUpAll(self) db.echo = False tables.create() db.echo = testbase.echo def tearDownAll(self): db.echo = False - db.commit() tables.drop() db.echo = testbase.echo + SessionTest.tearDownAll(self) def setUp(self): db.echo = False - # remove all history/identity maps etc. - objectstore.clear() - # remove all mapperes - clear_mappers() keywords.insert().execute( dict(name='blue'), dict(name='red'), @@ -499,32 +417,29 @@ class SaveTest(AssertMixin): dict(name='round'), dict(name='square') ) - db.commit() db.echo = testbase.echo def tearDown(self): db.echo = False - db.commit() tables.delete() db.echo = testbase.echo - self.assert_(len(objectstore.uow().new) == 0) - self.assert_(len(objectstore.uow().dirty) == 0) - self.assert_(len(objectstore.uow().modified_lists) == 0) - + #self.assert_(len(ctx.current.new) == 0) + #self.assert_(len(ctx.current.dirty) == 0) + SessionTest.tearDown(self) + def testbasic(self): # save two users u = User() u.user_name = 'savetester' - m = mapper(User, users) u2 = User() u2.user_name = 'savetester2' - objectstore.uow().register_new(u) + ctx.current.save(u) - objectstore.uow().commit(u) - objectstore.uow().commit() + ctx.current.flush([u]) + ctx.current.flush() # assert the first one retreives the same from the identity map nu = m.get(u.user_id) @@ -532,18 +447,20 @@ class SaveTest(AssertMixin): self.assert_(u is nu) # clear out the identity map, so next get forces a SELECT - objectstore.clear() + ctx.current.clear() # check it again, identity should be different but ids the same nu = m.get(u.user_id) self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester') # change first users name and save + ctx.current.update(u) u.user_name = 'modifiedname' - objectstore.uow().commit() + assert u in ctx.current.dirty + ctx.current.flush() # select both - #objectstore.clear() + #ctx.current.clear() userlist = m.select(users.c.user_id.in_(u.user_id, u2.user_id), order_by=[users.c.user_name]) print repr(u.user_id), repr(userlist[0].user_id), repr(userlist[0].user_name) self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname') @@ -561,12 +478,12 @@ class SaveTest(AssertMixin): u.addresses.append(Address()) u.addresses.append(Address()) u.addresses.append(Address()) - objectstore.commit() - objectstore.clear() + ctx.current.flush() + ctx.current.clear() ulist = m1.select() u1 = ulist[0] u1.user_name = 'newname' - objectstore.commit() + ctx.current.flush() self.assert_(len(u1.addresses) == 4) def testinherits(self): @@ -583,8 +500,8 @@ class SaveTest(AssertMixin): ) au = AddressUser() - objectstore.commit() - objectstore.clear() + ctx.current.flush() + ctx.current.clear() l = AddressUser.mapper.selectone() self.assert_(l.user_id == au.user_id and l.address_id == au.address_id) @@ -592,9 +509,9 @@ class SaveTest(AssertMixin): """tests a save of an object where each instance spans two tables. also tests redefinition of the keynames for the column properties.""" usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) - print usersaddresses._get_col_by_original(users.c.user_id) + print usersaddresses.corresponding_column(users.c.user_id) print repr(usersaddresses._orig_cols) - m = mapper(User, usersaddresses, primarytable = users, + m = mapper(User, usersaddresses, properties = dict( email = addresses.c.email_address, foo_id = [users.c.user_id, addresses.c.user_id], @@ -605,7 +522,7 @@ class SaveTest(AssertMixin): u.user_name = 'multitester' u.email = 'multi@test.org' - objectstore.uow().commit() + ctx.current.flush() usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall() self.assertEqual(usertable[0].values(), [u.foo_id, 'multitester']) @@ -614,7 +531,7 @@ class SaveTest(AssertMixin): u.email = 'lala@hey.com' u.user_name = 'imnew' - objectstore.uow().commit() + ctx.current.flush() usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall() self.assertEqual(usertable[0].values(), [u.foo_id, 'imnew']) @@ -632,12 +549,34 @@ class SaveTest(AssertMixin): u.user_name = 'one2onetester' u.address = Address() u.address.email_address = 'myonlyaddress@foo.com' - objectstore.uow().commit() + ctx.current.flush() u.user_name = 'imnew' - objectstore.uow().commit() + ctx.current.flush() u.address.email_address = 'imnew@foo.com' - objectstore.uow().commit() + ctx.current.flush() + def testchildmove(self): + """tests moving a child from one parent to the other, then deleting the first parent, properly + updates the child with the new parent. this tests the 'trackparent' option in the attributes module.""" + m = mapper(User, users, properties = dict( + addresses = relation(mapper(Address, addresses), lazy = True, private = False) + )) + u1 = User() + u1.user_name = 'user1' + u2 = User() + u2.user_name = 'user2' + a = Address() + a.email_address = 'address1' + u1.addresses.append(a) + ctx.current.flush() + del u1.addresses[0] + u2.addresses.append(a) + ctx.current.delete(u1) + ctx.current.flush() + ctx.current.clear() + u2 = m.get(u2.user_id) + assert len(u2.addresses) == 1 + def testdelete(self): m = mapper(User, users, properties = dict( address = relation(mapper(Address, addresses), lazy = True, uselist = False, private = False) @@ -647,89 +586,11 @@ class SaveTest(AssertMixin): u.user_name = 'one2onetester' u.address = a u.address.email_address = 'myonlyaddress@foo.com' - objectstore.uow().commit() + ctx.current.flush() self.echo("\n\n\n") - objectstore.uow().register_deleted(u) - objectstore.uow().commit() - self.assert_(a.address_id is not None and a.user_id is None and not objectstore.uow().identity_map.has_key(u._instance_key) and objectstore.uow().identity_map.has_key(a._instance_key)) - - def testcascadingdelete(self): - m = mapper(User, users, properties = dict( - address = relation(mapper(Address, addresses), lazy = False, uselist = False, private = True), - orders = relation( - mapper(Order, orders, properties = dict ( - items = relation(mapper(Item, orderitems), lazy = False, uselist =True, private = True) - )), - lazy = True, uselist = True, private = True) - )) - - data = [User, - {'user_name' : 'ed', - 'address' : (Address, {'email_address' : 'foo@bar.com'}), - 'orders' : (Order, [ - {'description' : 'eds 1st order', 'items' : (Item, [{'item_name' : 'eds o1 item'}, {'item_name' : 'eds other o1 item'}])}, - {'description' : 'eds 2nd order', 'items' : (Item, [{'item_name' : 'eds o2 item'}, {'item_name' : 'eds other o2 item'}])} - ]) - }, - {'user_name' : 'jack', - 'address' : (Address, {'email_address' : 'jack@jack.com'}), - 'orders' : (Order, [ - {'description' : 'jacks 1st order', 'items' : (Item, [{'item_name' : 'im a lumberjack'}, {'item_name' : 'and im ok'}])} - ]) - }, - {'user_name' : 'foo', - 'address' : (Address, {'email_address': 'hi@lala.com'}), - 'orders' : (Order, [ - {'description' : 'foo order', 'items' : (Item, [])}, - {'description' : 'foo order 2', 'items' : (Item, [{'item_name' : 'hi'}])}, - {'description' : 'foo order three', 'items' : (Item, [{'item_name' : 'there'}])} - ]) - } - ] - - for elem in data[1:]: - u = User() - u.user_name = elem['user_name'] - u.address = Address() - u.address.email_address = elem['address'][1]['email_address'] - u.orders = [] - for order in elem['orders'][1]: - o = Order() - o.isopen = None - o.description = order['description'] - u.orders.append(o) - o.items = [] - for item in order['items'][1]: - i = Item() - i.item_name = item['item_name'] - o.items.append(i) - - objectstore.uow().commit() - objectstore.clear() - - l = m.select() - for u in l: - self.echo( repr(u.orders)) - self.assert_result(l, data[0], *data[1:]) - - self.echo("\n\n\n") - objectstore.uow().register_deleted(l[0]) - objectstore.uow().register_deleted(l[2]) - objectstore.commit() - return - res = self.capture_exec(db, lambda: objectstore.uow().commit()) - state = None - - for line in res.split('\n'): - if line == "DELETE FROM items WHERE items.item_id = :item_id": - self.assert_(state is None or state == 'addresses') - elif line == "DELETE FROM orders WHERE orders.order_id = :order_id": - state = 'orders' - elif line == "DELETE FROM email_addresses WHERE email_addresses.address_id = :address_id": - if state is None: - state = 'addresses' - elif line == "DELETE FROM users WHERE users.user_id = :user_id": - self.assert_(state is not None) + ctx.current.delete(u) + ctx.current.flush() + self.assert_(a.address_id is not None and a.user_id is None and not ctx.current.identity_map.has_key(u._instance_key) and ctx.current.identity_map.has_key(a._instance_key)) def testbackwardsonetoone(self): # test 'backwards' @@ -755,37 +616,37 @@ class SaveTest(AssertMixin): a.user.user_name = elem['user_name'] objects.append(a) - objectstore.uow().commit() + ctx.current.flush() objects[2].email_address = 'imnew@foo.bar' objects[3].user = User() objects[3].user.user_name = 'imnewlyadded' - self.assert_sql(db, lambda: objectstore.uow().commit(), [ + self.assert_sql(db, lambda: ctx.current.flush(), [ ( "INSERT INTO users (user_name) VALUES (:user_name)", {'user_name': 'imnewlyadded'} ), { "UPDATE email_addresses SET email_address=:email_address WHERE email_addresses.address_id = :email_addresses_address_id": - lambda: [{'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id}] + lambda ctx: {'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id} , "UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id": - lambda: [{'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}] + lambda ctx: {'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id} }, ], with_sequences=[ ( "INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)", - lambda:{'user_name': 'imnewlyadded', 'user_id':db.last_inserted_ids()[0]} + lambda ctx:{'user_name': 'imnewlyadded', 'user_id':ctx.last_inserted_ids()[0]} ), { "UPDATE email_addresses SET email_address=:email_address WHERE email_addresses.address_id = :email_addresses_address_id": - lambda: [{'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id}] + lambda ctx: {'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id} , "UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id": - lambda: [{'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}] + lambda ctx: {'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id} }, ]) @@ -810,7 +671,7 @@ class SaveTest(AssertMixin): u.addresses.append(a2) self.echo( repr(u.addresses)) self.echo( repr(u.addresses.added_items())) - objectstore.uow().commit() + ctx.current.flush() usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall() self.assertEqual(usertable[0].values(), [u.user_id, 'one2manytester']) @@ -823,7 +684,7 @@ class SaveTest(AssertMixin): a2.email_address = 'somethingnew@foo.com' - objectstore.uow().commit() + ctx.current.flush() addresstable = addresses.select(addresses.c.address_id == addressid).execute().fetchall() @@ -837,7 +698,6 @@ class SaveTest(AssertMixin): dict(user_id = 8, user_name = 'ed'), dict(user_id = 9, user_name = 'fred') ) - db.commit() # mapper with just users table assign_mapper(User, users) @@ -853,7 +713,7 @@ class SaveTest(AssertMixin): u[0].addresses[0].email_address='hi' # insure that upon commit, the new mapper with the address relation is used - self.assert_sql(db, lambda: objectstore.commit(), + self.assert_sql(db, lambda: ctx.current.flush(), [ ( "INSERT INTO email_addresses (user_id, email_address) VALUES (:user_id, :email_address)", @@ -863,7 +723,7 @@ class SaveTest(AssertMixin): with_sequences=[ ( "INSERT INTO email_addresses (address_id, user_id, email_address) VALUES (:address_id, :user_id, :email_address)", - lambda:{'email_address': 'hi', 'user_id': 7, 'address_id':db.last_inserted_ids()[0]} + lambda ctx:{'email_address': 'hi', 'user_id': 7, 'address_id':ctx.last_inserted_ids()[0]} ), ] ) @@ -890,7 +750,7 @@ class SaveTest(AssertMixin): a3 = Address() a3.email_address = 'emailaddress3' - objectstore.commit() + ctx.current.flush() self.echo("\n\n\n") # modify user2 directly, append an address to user1. @@ -899,18 +759,18 @@ class SaveTest(AssertMixin): u2.user_name = 'user2modified' u1.addresses.append(a3) del u1.addresses[0] - self.assert_sql(db, lambda: objectstore.commit(), + self.assert_sql(db, lambda: ctx.current.flush(), [ ( "UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id", - [{'users_user_id': u2.user_id, 'user_name': 'user2modified'}] + {'users_user_id': u2.user_id, 'user_name': 'user2modified'} ), ( "UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id", - [{'user_id': u1.user_id, 'email_addresses_address_id': a3.address_id}] + {'user_id': u1.user_id, 'email_addresses_address_id': a3.address_id} ), ("UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id", - [{'user_id': None, 'email_addresses_address_id': a1.address_id}] + {'user_id': None, 'email_addresses_address_id': a1.address_id} ) ]) @@ -924,12 +784,12 @@ class SaveTest(AssertMixin): u1.user_name='user1' a1.user = u1 - objectstore.commit() + ctx.current.flush() self.echo("\n\n\n") - objectstore.delete(u1) + ctx.current.delete(u1) a1.user = None - objectstore.commit() + ctx.current.flush() def _testalias(self): """tests that an alias of a table can be used in a mapper. @@ -971,12 +831,13 @@ class SaveTest(AssertMixin): def testmanytomany(self): items = orderitems + keywordmapper = mapper(Keyword, keywords) + items.select().execute() m = mapper(Item, items, properties = dict( - keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = False), + keywords = relation(keywordmapper, itemkeywords, lazy = False), )) - keywordmapper = mapper(Keyword, keywords) data = [Item, {'item_name': 'mm_item1', 'keywords' : (Keyword,[{'name': 'big'},{'name': 'green'}, {'name': 'purple'},{'name': 'round'}])}, @@ -1007,7 +868,7 @@ class SaveTest(AssertMixin): k.name = kname item.keywords.append(k) - objectstore.uow().commit() + ctx.current.flush() l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name]) self.assert_result(l, *data) @@ -1016,48 +877,48 @@ class SaveTest(AssertMixin): k = Keyword() k.name = 'yellow' objects[5].keywords.append(k) - self.assert_sql(db, lambda:objectstore.commit(), [ + self.assert_sql(db, lambda:ctx.current.flush(), [ { "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id": - [{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}] + {'item_name': 'item4updated', 'items_item_id': objects[4].item_id} , "INSERT INTO keywords (name) VALUES (:name)": {'name': 'yellow'} }, ("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)", - lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}] + lambda ctx: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}] ) ], with_sequences = [ { "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id": - [{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}] + {'item_name': 'item4updated', 'items_item_id': objects[4].item_id} , "INSERT INTO keywords (keyword_id, name) VALUES (:keyword_id, :name)": - lambda: {'name': 'yellow', 'keyword_id':db.last_inserted_ids()[0]} + lambda ctx: {'name': 'yellow', 'keyword_id':ctx.last_inserted_ids()[0]} }, ("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)", - lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}] + lambda ctx: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}] ) ] ) objects[2].keywords.append(k) dkid = objects[5].keywords[1].keyword_id del objects[5].keywords[1] - self.assert_sql(db, lambda:objectstore.commit(), [ + self.assert_sql(db, lambda:ctx.current.flush(), [ ( "DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id", [{'item_id': objects[5].item_id, 'keyword_id': dkid}] ), ( "INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)", - lambda: [{'item_id': objects[2].item_id, 'keyword_id': k.keyword_id}] + lambda ctx: [{'item_id': objects[2].item_id, 'keyword_id': k.keyword_id}] ) ]) - objectstore.delete(objects[3]) - objectstore.commit() + ctx.current.delete(objects[3]) + ctx.current.flush() def testassociation(self): class IKAssociation(object): @@ -1072,7 +933,7 @@ class SaveTest(AssertMixin): # the reorganization of mapper construction affected this, but was fixed again m = mapper(Item, items, properties = dict( keywords = relation(mapper(IKAssociation, itemkeywords, properties = dict( - keyword = relation(mapper(Keyword, keywords), lazy = False, uselist = False) + keyword = relation(mapper(Keyword, keywords, non_primary=True), lazy = False, uselist = False) ), primary_key = [itemkeywords.c.item_id, itemkeywords.c.keyword_id]), lazy = False) )) @@ -1117,8 +978,8 @@ class SaveTest(AssertMixin): ik.keyword = k item.keywords.append(ik) - objectstore.uow().commit() - objectstore.clear() + ctx.current.flush() + ctx.current.clear() l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name]) self.assert_result(l, *data) @@ -1136,7 +997,7 @@ class SaveTest(AssertMixin): a = Address() a.email_address = 'testaddress' a.user = u - objectstore.commit() + ctx.current.flush() print repr(u.addresses) x = False try: @@ -1148,8 +1009,8 @@ class SaveTest(AssertMixin): if x: self.assert_(False, "User addresses element should be scalar based") - objectstore.delete(u) - objectstore.commit() + ctx.current.delete(u) + ctx.current.flush() def testdoublerelation(self): m2 = mapper(Address, addresses) @@ -1169,13 +1030,13 @@ class SaveTest(AssertMixin): u.boston_addresses.append(a) u.newyork_addresses.append(b) - objectstore.commit() + ctx.current.flush() -class SaveTest2(AssertMixin): +class SaveTest2(SessionTest): def setUp(self): db.echo = False - objectstore.clear() + ctx.current.clear() clear_mappers() self.users = Table('users', db, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), @@ -1201,6 +1062,7 @@ class SaveTest2(AssertMixin): self.addresses.drop() self.users.drop() db.echo = testbase.echo + SessionTest.tearDown(self) def testbackwardsnonmatch(self): m = mapper(Address, self.addresses, properties = dict( @@ -1217,7 +1079,7 @@ class SaveTest2(AssertMixin): a.user = User() a.user.user_name = elem['user_name'] objects.append(a) - self.assert_sql(db, lambda: objectstore.commit(), [ + self.assert_sql(db, lambda: ctx.current.flush(), [ ( "INSERT INTO users (user_name) VALUES (:user_name)", {'user_name': 'thesub'} @@ -1239,19 +1101,19 @@ class SaveTest2(AssertMixin): with_sequences = [ ( "INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)", - lambda: {'user_name': 'thesub', 'user_id':db.last_inserted_ids()[0]} + lambda ctx: {'user_name': 'thesub', 'user_id':ctx.last_inserted_ids()[0]} ), ( "INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)", - lambda: {'user_name': 'assdkfj', 'user_id':db.last_inserted_ids()[0]} + lambda ctx: {'user_name': 'assdkfj', 'user_id':ctx.last_inserted_ids()[0]} ), ( "INSERT INTO email_addresses (address_id, rel_user_id, email_address) VALUES (:address_id, :rel_user_id, :email_address)", - lambda:{'rel_user_id': 1, 'email_address': 'bar@foo.com', 'address_id':db.last_inserted_ids()[0]} + lambda ctx:{'rel_user_id': 1, 'email_address': 'bar@foo.com', 'address_id':ctx.last_inserted_ids()[0]} ), ( "INSERT INTO email_addresses (address_id, rel_user_id, email_address) VALUES (:address_id, :rel_user_id, :email_address)", - lambda:{'rel_user_id': 2, 'email_address': 'thesdf@asdf.com', 'address_id':db.last_inserted_ids()[0]} + lambda ctx:{'rel_user_id': 2, 'email_address': 'thesdf@asdf.com', 'address_id':ctx.last_inserted_ids()[0]} ) ] ) diff --git a/test/onetoone.py b/test/onetoone.py index 9ff330c92..5dc5b1204 100644 --- a/test/onetoone.py +++ b/test/onetoone.py @@ -1,5 +1,6 @@ from sqlalchemy import * import testbase +from sqlalchemy.ext.sessioncontext import SessionContext class Jack(object): def __repr__(self): @@ -23,8 +24,10 @@ class Port(object): class O2OTest(testbase.AssertMixin): def setUpAll(self): - global jack, port - jack = Table('jack', testbase.db, + global jack, port, metadata, ctx + metadata = BoundMetaData(testbase.db) + ctx = SessionContext(create_session) + jack = Table('jack', metadata, Column('id', Integer, primary_key=True), #Column('room_id', Integer, ForeignKey("room.id")), Column('number', String(50)), @@ -33,54 +36,54 @@ class O2OTest(testbase.AssertMixin): ) - port = Table('port', testbase.db, + port = Table('port', metadata, Column('id', Integer, primary_key=True), #Column('device_id', Integer, ForeignKey("device.id")), Column('name', String(30)), Column('description', String(100)), Column('jack_id', Integer, ForeignKey("jack.id")), ) - jack.create() - port.create() + metadata.create_all() def setUp(self): - objectstore.clear() + pass def tearDown(self): clear_mappers() def tearDownAll(self): - port.drop() - jack.drop() + metadata.drop_all() def test1(self): - assign_mapper(Port, port) - assign_mapper(Jack, jack, order_by=[jack.c.number],properties = { - 'port': relation(Port.mapper, backref='jack', uselist=False, lazy=True), - }) + mapper(Port, port, extension=ctx.mapper_extension) + mapper(Jack, jack, order_by=[jack.c.number],properties = { + 'port': relation(Port, backref='jack', uselist=False, lazy=True), + }, extension=ctx.mapper_extension) j=Jack(number='101') p=Port(name='fa0/1') j.port=p - objectstore.commit() + ctx.current.flush() jid = j.id pid = p.id - j=Jack.get(jid) - p=Port.get(pid) + j=ctx.current.query(Jack).get(jid) + p=ctx.current.query(Port).get(pid) print p.jack - print j.port + assert p.jack is not None + assert p.jack is j + assert j.port is not None p.jack=None assert j.port is None #works - objectstore.clear() + ctx.current.clear() - j=Jack.get(jid) - p=Port.get(pid) + j=ctx.current.query(Jack).get(jid) + p=ctx.current.query(Port).get(pid) j.port=None self.assert_(p.jack is None) - objectstore.commit() + ctx.current.flush() - j.delete() - objectstore.commit() + ctx.current.delete(j) + ctx.current.flush() if __name__ == "__main__": testbase.main() diff --git a/test/parseconnect.py b/test/parseconnect.py new file mode 100644 index 000000000..e1f50e8c9 --- /dev/null +++ b/test/parseconnect.py @@ -0,0 +1,29 @@ +from testbase import PersistTest +import sqlalchemy.engine.url as url +import unittest + +class ParseConnectTest(PersistTest): + def testrfc1738(self): + for text in ( + 'dbtype://username:password@hostspec:110//usr/db_file.db', + 'dbtype://username:password@hostspec/database', + 'dbtype://username:password@hostspec', + 'dbtype://username:password@/database', + 'dbtype://username@hostspec', + 'dbtype://username:password@127.0.0.1:1521', + 'dbtype://hostspec/database', + 'dbtype://hostspec', + 'dbtype:///database', + 'dbtype:///:memory:', + 'dbtype:///foo/bar/im/a/file', + 'dbtype:///E:/work/src/LEM/db/hello.db', + 'dbtype://' + ): + u = url.make_url(text) + print u, text + assert str(u) == text + + +if __name__ == "__main__": + unittest.main() +
\ No newline at end of file diff --git a/test/polymorph.py b/test/polymorph.py new file mode 100644 index 000000000..ec7e95a0a --- /dev/null +++ b/test/polymorph.py @@ -0,0 +1,169 @@ +import testbase +from sqlalchemy import * +import sets + +# test classes +class Person(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + def get_name(self): + try: + return getattr(self, 'person_name') + except AttributeError: + return getattr(self, 'name') + def __repr__(self): + return "Ordinary person %s" % self.get_name() +class Engineer(Person): + def __repr__(self): + return "Engineer %s, status %s, engineer_name %s, primary_language %s" % (self.get_name(), self.status, self.engineer_name, self.primary_language) +class Manager(Person): + def __repr__(self): + return "Manager %s, status %s, manager_name %s" % (self.get_name(), self.status, self.manager_name) +class Company(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + def __repr__(self): + return "Company %s" % self.name + +class MultipleTableTest(testbase.PersistTest): + def setUpAll(self, use_person_column=False): + global companies, people, engineers, managers, metadata + metadata = BoundMetaData(testbase.db) + + # a table to store companies + companies = Table('companies', metadata, + Column('company_id', Integer, primary_key=True), + Column('name', String(50))) + + # we will define an inheritance relationship between the table "people" and "engineers", + # and a second inheritance relationship between the table "people" and "managers" + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True), + Column('company_id', Integer, ForeignKey('companies.company_id')), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('engineer_name', String(50)), + Column('primary_language', String(50)), + ) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('manager_name', String(50)) + ) + + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() + + def tearDown(self): + clear_mappers() + for t in metadata.table_iterator(reverse=True): + t.delete().execute() + + def test_f_f_f(self): + self.do_test(False, False, False) + def test_f_f_t(self): + self.do_test(False, False, True) + def test_f_t_f(self): + self.do_test(False, True, False) + def test_f_t_t(self): + self.do_test(False, True, True) + def test_t_f_f(self): + self.do_test(True, False, False) + def test_t_f_t(self): + self.do_test(True, False, True) + def test_t_t_f(self): + self.do_test(True, True, False) + def test_t_t_t(self): + self.do_test(True, True, True) + + + def do_test(self, include_base=False, lazy_relation=True, redefine_colprop=False): + """tests the polymorph.py example, with several options: + + include_base - whether or not to include the base 'person' type in the union. + lazy_relation - whether or not the Company relation to People is lazy or eager. + redefine_colprop - if we redefine the 'name' column to be 'people_name' on the base Person class + """ + # create a union that represents both types of joins. + if include_base: + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + 'person':people.select(people.c.type=='person'), + }, None, 'pjoin') + else: + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, None, 'pjoin') + + if redefine_colprop: + person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name}) + else: + person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person') + + mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + + mapper(Company, companies, properties={ + 'employees': relation(Person, lazy=lazy_relation, private=True, backref='company') + }) + + if redefine_colprop: + person_attribute_name = 'person_name' + else: + person_attribute_name = 'name' + + session = create_session() + c = Company(name='company1') + c.employees.append(Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'})) + c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'})) + if include_base: + c.employees.append(Person(status='HHH', **{person_attribute_name:'joesmith'})) + c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'})) + c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'})) + session.save(c) + print session.new + session.flush() + session.clear() + id = c.company_id + c = session.query(Company).get(id) + for e in c.employees: + print e, e._instance_key, e.company + if include_base: + assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']) + else: + assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'wally', 'jsmith']) + print "\n" + + + dilbert = session.query(Person).selectfirst(person_join.c.name=='dilbert') + dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert') + assert dilbert is dilbert2 + + dilbert.engineer_name = 'hes dibert!' + + session.flush() + session.clear() + + c = session.query(Company).get(id) + for e in c.employees: + print e, e._instance_key + + session.delete(c) + session.flush() + +if __name__ == "__main__": + testbase.main() + diff --git a/test/pool.py b/test/pool.py index 2737a33b1..d8c984aa8 100644 --- a/test/pool.py +++ b/test/pool.py @@ -1,5 +1,5 @@ from testbase import PersistTest -import unittest, sys, os +import unittest, sys, os, time from pysqlite2 import dbapi2 as sqlite import sqlalchemy.pool as pool @@ -40,7 +40,14 @@ class PoolTest(PersistTest): self.assert_(connection.cursor() is not None) self.assert_(connection is not connection2) - def testqueuepool(self): + def testqueuepool_del(self): + self._do_testqueuepool(useclose=False) + + def testqueuepool_close(self): + self._do_testqueuepool(useclose=True) + + def _do_testqueuepool(self, useclose=False): + p = pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = False, echo = False) def status(pool): @@ -60,30 +67,73 @@ class PoolTest(PersistTest): self.assert_(status(p) == (3,0,2,5)) c6 = p.connect() self.assert_(status(p) == (3,0,3,6)) - c4 = c3 = c2 = None + if useclose: + c4.close() + c3.close() + c2.close() + else: + c4 = c3 = c2 = None self.assert_(status(p) == (3,3,3,3)) - c1 = c5 = c6 = None + if useclose: + c1.close() + c5.close() + c6.close() + else: + c1 = c5 = c6 = None self.assert_(status(p) == (3,3,0,0)) c1 = p.connect() c2 = p.connect() self.assert_(status(p) == (3, 1, 0, 2)) - c2 = None + if useclose: + c2.close() + else: + c2 = None self.assert_(status(p) == (3, 2, 0, 1)) - def testthreadlocal(self): + def test_timeout(self): + p = pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = 0, use_threadlocal = False, echo = False, timeout=2) + c1 = p.get() + c2 = p.get() + c3 = p.get() + now = time.time() + c4 = p.get() + assert int(time.time() - now) == 2 + + def testthreadlocal_del(self): + self._do_testthreadlocal(useclose=False) + + def testthreadlocal_close(self): + self._do_testthreadlocal(useclose=True) + + def _do_testthreadlocal(self, useclose=False): for p in ( pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True, echo = False), pool.SingletonThreadPool(creator = lambda: sqlite.connect('foo.db'), use_threadlocal = True) - ): + ): c1 = p.connect() c2 = p.connect() self.assert_(c1 is c2) c3 = p.unique_connection() self.assert_(c3 is not c1) - c2 = None + if useclose: + c2.close() + else: + c2 = None c2 = p.connect() self.assert_(c1 is c2) self.assert_(c3 is not c1) + if useclose: + c2.close() + else: + c2 = None + + if useclose: + c1 = p.connect() + c2 = p.connect() + c3 = p.connect() + c3.close() + c2.close() + self.assert_(c1.connection is not None) def tearDown(self): pool.clear_managers() diff --git a/test/proxy_engine.py b/test/proxy_engine.py index 2a2cebc5b..df0c64398 100644 --- a/test/proxy_engine.py +++ b/test/proxy_engine.py @@ -1,9 +1,10 @@ +import os + from sqlalchemy import * from sqlalchemy.ext.proxy import ProxyEngine from testbase import PersistTest import testbase -import os # # Define an engine, table and mapper at the module level, to show that the @@ -11,20 +12,31 @@ import os # -module_engine = ProxyEngine(echo=testbase.echo) -users = Table('users', module_engine, - Column('user_id', Integer, primary_key=True), - Column('user_name', String(16)), - Column('password', String(20)) - ) +class ProxyTestBase(PersistTest): + def setUpAll(self): + + global users, User, module_engine, module_metadata + + module_engine = ProxyEngine(echo=testbase.echo) + module_metadata = MetaData() -class User(object): - pass + users = Table('users', module_metadata, + Column('user_id', Integer, primary_key=True), + Column('user_name', String(16)), + Column('password', String(20)) + ) + class User(object): + pass -class ConstructTest(PersistTest): + User.mapper = mapper(User, users) + def tearDownAll(self): + clear_mappers() + +class ConstructTest(ProxyTestBase): """tests that we can build SQL constructs without engine-specific parameters, particulary oid_column, being needed, as the proxy engine is usually not connected yet.""" + def test_join(self): engine = ProxyEngine() t = Table('table1', engine, @@ -33,47 +45,46 @@ class ConstructTest(PersistTest): Column('col2', Integer, ForeignKey('table1.col1'))) j = join(t, t2) -class ProxyEngineTest1(PersistTest): - def setUp(self): - clear_mappers() - objectstore.clear() - +class ProxyEngineTest1(ProxyTestBase): + def test_engine_connect(self): # connect to a real engine module_engine.connect(testbase.db_uri) - users.create() - assign_mapper(User, users) + module_metadata.create_all(module_engine) + + session = create_session(bind_to=module_engine) try: - trans = objectstore.begin() user = User() user.user_name='fred' user.password='*' - trans.commit() + + session.save(user) + session.flush() + + query = session.query(User) # select - sqluser = User.select_by(user_name='fred')[0] + sqluser = query.select_by(user_name='fred')[0] assert sqluser.user_name == 'fred' # modify sqluser.user_name = 'fred jones' - # commit - saves everything that changed - objectstore.commit() + # flush - saves everything that changed + session.flush() - allusers = [ user.user_name for user in User.select() ] - assert allusers == [ 'fred jones' ] + allusers = [ user.user_name for user in query.select() ] + assert allusers == ['fred jones'] + finally: - users.drop() + module_metadata.drop_all(module_engine) + + +class ThreadProxyTest(ProxyTestBase): -class ThreadProxyTest(PersistTest): - def setUp(self): - assign_mapper(User, users) - def tearDown(self): - clear_mappers() def tearDownAll(self): - pass os.remove('threadtesta.db') os.remove('threadtestb.db') @@ -92,23 +103,26 @@ class ThreadProxyTest(PersistTest): try: module_engine.connect(db_uri) - users.create() + module_metadata.create_all(module_engine) try: - trans = objectstore.begin() + session = create_session(bind_to=module_engine) + + query = session.query(User) - all = User.select()[:] + all = list(query.select()) assert all == [] u = User() u.user_name = uname u.password = 'whatever' - trans.commit() - names = [ us.user_name for us in User.select() ] - assert names == [ uname ] + session.save(u) + session.flush() + + names = [u.user_name for u in query.select()] + assert names == [uname] finally: - users.drop() - module_engine.dispose() + module_metadata.drop_all(module_engine) except Exception, e: import traceback traceback.print_exc() @@ -119,8 +133,8 @@ class ThreadProxyTest(PersistTest): # NOTE: I'm not sure how to give the test runner the option to # override these uris, or how to safely clear them after test runs - a = Thread(target=run('sqlite://filename=threadtesta.db', 'jim', qa)) - b = Thread(target=run('sqlite://filename=threadtestb.db', 'joe', qb)) + a = Thread(target=run('sqlite:///threadtesta.db', 'jim', qa)) + b = Thread(target=run('sqlite:///threadtestb.db', 'joe', qb)) a.start() b.start() @@ -134,11 +148,8 @@ class ThreadProxyTest(PersistTest): if res != False: raise res -class ProxyEngineTest2(PersistTest): - def setUp(self): - clear_mappers() - objectstore.clear() +class ProxyEngineTest2(ProxyTestBase): def test_table_singleton_a(self): """set up for table singleton check @@ -153,8 +164,9 @@ class ProxyEngineTest2(PersistTest): Column('cat_name', String)) engine.connect(testbase.db_uri) - cats.create() - cats.drop() + + cats.create(engine) + cats.drop(engine) ProxyEngineTest2.cats_table_a = cats assert isinstance(cats, Table) @@ -179,141 +191,8 @@ class ProxyEngineTest2(PersistTest): # this will fail because the old reference's local storage will # not have the default attributes engine.connect(testbase.db_uri) - cats.create() - cats.drop() - - def test_type_engine_caching(self): - from sqlalchemy.engine import SQLEngine - import sqlalchemy.types as sqltypes - - class EngineA(SQLEngine): - def __init__(self): - pass - - def hash_key(self): - return 'a' - - def type_descriptor(self, typeobj): - if isinstance(typeobj, types.Integer): - return TypeEngineX2() - else: - return TypeEngineSTR() - - class EngineB(SQLEngine): - def __init__(self): - pass - - def hash_key(self): - return 'b' - - def type_descriptor(self, typeobj): - return TypeEngineMonkey() - - class TypeEngineX2(sqltypes.TypeEngine): - def convert_bind_param(self, value, engine): - return value * 2 - - class TypeEngineSTR(sqltypes.TypeEngine): - def convert_bind_param(self, value, engine): - return repr(str(value)) - - class TypeEngineMonkey(sqltypes.TypeEngine): - def convert_bind_param(self, value, engine): - return 'monkey' - - engine = ProxyEngine() - engine.storage.engine = EngineA() - - a = sqltypes.Integer().engine_impl(engine) - assert a.convert_bind_param(12, engine) == 24 - assert a.convert_bind_param([1,2,3], engine) == [1, 2, 3, 1, 2, 3] - - a2 = sqltypes.String().engine_impl(engine) - assert a2.convert_bind_param(12, engine) == "'12'" - assert a2.convert_bind_param([1,2,3], engine) == "'[1, 2, 3]'" - - engine.storage.engine = EngineB() - b = sqltypes.Integer().engine_impl(engine) - assert b.convert_bind_param(12, engine) == 'monkey' - assert b.convert_bind_param([1,2,3], engine) == 'monkey' - - - def test_type_engine_autoincrement(self): - engine = ProxyEngine() - dogs = Table('dogs', engine, - Column('dog_id', Integer, primary_key=True), - Column('breed', String), - Column('name', String)) - - class Dog(object): - pass - - assign_mapper(Dog, dogs) - - engine.connect(testbase.db_uri) - dogs.create() - try: - spot = Dog() - spot.breed = 'beagle' - spot.name = 'Spot' + cats.create(engine) + cats.drop(engine) - rover = Dog() - rover.breed = 'spaniel' - rover.name = 'Rover' - - objectstore.commit() - - assert spot.dog_id > 0, "Spot did not get an id" - assert rover.dog_id != spot.dog_id - finally: - dogs.drop() - - def test_type_proxy_schema_gen(self): - from sqlalchemy.databases.postgres import PGSchemaGenerator - - engine = ProxyEngine() - lizards = Table('lizards', engine, - Column('id', Integer, primary_key=True), - Column('name', String)) - - # this doesn't really CONNECT to pg, just establishes pg as the - # actual engine so that we can determine that it gets the right - # answer - engine.connect('postgres://database=test&port=5432&host=127.0.0.1&user=scott&password=tiger') - - sg = PGSchemaGenerator(engine) - id_spec = sg.get_column_specification(lizards.c.id) - assert id_spec == 'id SERIAL NOT NULL PRIMARY KEY' - - if __name__ == "__main__": testbase.main() - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/test/query.py b/test/query.py index 946180618..0c5164133 100644 --- a/test/query.py +++ b/test/query.py @@ -6,7 +6,6 @@ import sqlalchemy.databases.sqlite as sqllite import tables db = testbase.db -#db.echo='debug' from sqlalchemy import * from sqlalchemy.engine import ResultProxy, RowProxy @@ -103,8 +102,6 @@ class QueryTest(PersistTest): def testdelete(self): - c = db.connection() - self.users.insert().execute(user_id = 7, user_name = 'jack') self.users.insert().execute(user_id = 8, user_name = 'fred') print repr(self.users.select().execute().fetchall()) @@ -113,14 +110,6 @@ class QueryTest(PersistTest): print repr(self.users.select().execute().fetchall()) - def testtransaction(self): - def dostuff(): - self.users.insert().execute(user_id = 7, user_name = 'john') - self.users.insert().execute(user_id = 8, user_name = 'jack') - - db.transaction(dostuff) - print repr(self.users.select().execute().fetchall()) - def testselectlimit(self): self.users.insert().execute(user_id=1, user_name='john') self.users.insert().execute(user_id=2, user_name='jack') @@ -158,10 +147,13 @@ class QueryTest(PersistTest): self.users.insert().execute(user_id=1, user_name='foo') r = self.users.select().execute().fetchone() self.assertEqual(len(r), 2) + r.close() r = db.execute('select user_name, user_id from query_users', {}).fetchone() self.assertEqual(len(r), 2) + r.close() r = db.execute('select user_name from query_users', {}).fetchone() self.assertEqual(len(r), 1) + r.close() def test_column_order_with_simple_query(self): # should return values in column definition order @@ -180,11 +172,9 @@ class QueryTest(PersistTest): self.assertEqual(r[1], 1) self.assertEqual(r.keys(), ['user_name', 'user_id']) self.assertEqual(r.values(), ['foo', 1]) - + + @testbase.unsupported('oracle') def test_column_accessor_shadow(self): - if db.engine.__module__.endswith('oracle'): - return - shadowed = Table('test_shadowed', db, Column('shadow_id', INT, primary_key = True), Column('shadow_name', VARCHAR(20)), @@ -209,6 +199,7 @@ class QueryTest(PersistTest): self.fail('Should not allow access to private attributes') except AttributeError: pass # expected + r.close() finally: shadowed.drop() diff --git a/test/reflection.py b/test/reflection.py index 718957add..20a5fd90c 100644 --- a/test/reflection.py +++ b/test/reflection.py @@ -1,8 +1,6 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.databases.postgres as postgres -import sqlalchemy.databases.oracle as oracle -import sqlalchemy.databases.sqlite as sqllite from sqlalchemy import * @@ -14,7 +12,7 @@ class ReflectionTest(PersistTest): def testbasic(self): # really trip it up with a circular reference - use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle') + use_function_defaults = testbase.db.engine.name == 'postgres' or testbase.db.engine.name == 'oracle' use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite') @@ -123,9 +121,10 @@ class ReflectionTest(PersistTest): table.drop() def testtoengine(self): - db = ansisql.engine() + meta = MetaData('md1') + meta2 = MetaData('md2') - table = Table('mytable', db, + table = Table('mytable', meta, Column('myid', Integer, key = 'id'), Column('name', String, key = 'name', nullable=False), Column('description', String, key = 'description'), @@ -133,14 +132,14 @@ class ReflectionTest(PersistTest): print repr(table) - pgdb = postgres.engine({}) + table2 = table.tometadata(meta2) - pgtable = table.toengine(pgdb) + print repr(table2) - print repr(pgtable) - assert pgtable.c.id.nullable - assert not pgtable.c.name.nullable - assert pgtable.c.description.nullable + assert table is not table2 + assert table2.c.id.nullable + assert not table2.c.name.nullable + assert table2.c.description.nullable def testoverride(self): table = Table( @@ -165,6 +164,58 @@ class ReflectionTest(PersistTest): self.assert_(isinstance(table.c.col4.type, String)) finally: table.drop() + +class CreateDropTest(PersistTest): + def setUpAll(self): + global metadata + metadata = MetaData() + users = Table('users', metadata, + Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('user_name', String(40)), + ) + + addresses = Table('email_addresses', metadata, + Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), + Column('user_id', Integer, ForeignKey(users.c.user_id)), + Column('email_address', String(40)), + + ) + + orders = Table('orders', metadata, + Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True), + Column('user_id', Integer, ForeignKey(users.c.user_id)), + Column('description', String(50)), + Column('isopen', Integer), + + ) + + orderitems = Table('items', metadata, + Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True), + Column('order_id', INT, ForeignKey("orders")), + Column('item_name', VARCHAR(50)), + + ) + + def test_sorter( self ): + tables = metadata._sort_tables(metadata.tables.values()) + table_names = [t.name for t in tables] + self.assert_( table_names == ['users', 'orders', 'items', 'email_addresses'] or table_names == ['users', 'email_addresses', 'orders', 'items']) + + + def test_createdrop(self): + metadata.create_all(engine=testbase.db) + self.assertEqual( testbase.db.has_table('items'), True ) + self.assertEqual( testbase.db.has_table('email_addresses'), True ) + metadata.create_all(engine=testbase.db) + self.assertEqual( testbase.db.has_table('items'), True ) + + metadata.drop_all(engine=testbase.db) + self.assertEqual( testbase.db.has_table('items'), False ) + self.assertEqual( testbase.db.has_table('email_addresses'), False ) + metadata.drop_all(engine=testbase.db) + self.assertEqual( testbase.db.has_table('items'), False ) + + if __name__ == "__main__": testbase.main() diff --git a/test/relationships.py b/test/relationships.py index 84a45c2dd..d9a9d6e50 100644 --- a/test/relationships.py +++ b/test/relationships.py @@ -22,33 +22,34 @@ class RelationTest(testbase.PersistTest): global tbl_b global tbl_c global tbl_d - tbl_a = Table("tbl_a", db, + metadata = MetaData() + tbl_a = Table("tbl_a", metadata, Column("id", Integer, primary_key=True), Column("name", String), ) - tbl_b = Table("tbl_b", db, + tbl_b = Table("tbl_b", metadata, Column("id", Integer, primary_key=True), Column("name", String), ) - tbl_c = Table("tbl_c", db, + tbl_c = Table("tbl_c", metadata, Column("id", Integer, primary_key=True), Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False), Column("name", String), ) - tbl_d = Table("tbl_d", db, + tbl_d = Table("tbl_d", metadata, Column("id", Integer, primary_key=True), Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False), Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), Column("name", String), ) def setUp(self): - tbl_a.create() - tbl_b.create() - tbl_c.create() - tbl_d.create() - - objectstore.clear() - clear_mappers() + global session + session = create_session(bind_to=testbase.db) + conn = session.connect() + conn.create(tbl_a) + conn.create(tbl_b) + conn.create(tbl_c) + conn.create(tbl_d) class A(object): pass @@ -75,30 +76,31 @@ class RelationTest(testbase.PersistTest): b = B(); b.name = "b1" c = C(); c.name = "c1"; c.a_row = a # we must have more than one d row or it won't fail - d = D(); d.name = "d1"; d.b_row = b; d.c_row = c - d = D(); d.name = "d2"; d.b_row = b; d.c_row = c - d = D(); d.name = "d3"; d.b_row = b; d.c_row = c - + d1 = D(); d1.name = "d1"; d1.b_row = b; d1.c_row = c + d2 = D(); d2.name = "d2"; d2.b_row = b; d2.c_row = c + d3 = D(); d3.name = "d3"; d3.b_row = b; d3.c_row = c + session.save_or_update(a) + session.save_or_update(b) + def tearDown(self): - tbl_d.drop() - tbl_c.drop() - tbl_b.drop() - tbl_a.drop() + conn = session.connect() + conn.drop(tbl_d) + conn.drop(tbl_c) + conn.drop(tbl_b) + conn.drop(tbl_a) def tearDownAll(self): - testbase.db.tables.clear() + testbase.metadata.tables.clear() def testDeleteRootTable(self): - session = objectstore.get_session() - session.commit() + session.flush() session.delete(a) # works as expected - session.commit() - + session.flush() + def testDeleteMiddleTable(self): - session = objectstore.get_session() - session.commit() + session.flush() session.delete(c) # fails - session.commit() + session.flush() if __name__ == "__main__": diff --git a/test/select.py b/test/select.py index fb136cfec..0fc3ca60f 100644 --- a/test/select.py +++ b/test/select.py @@ -1,14 +1,6 @@ from sqlalchemy import * -import sqlalchemy.ansisql as ansisql -import sqlalchemy.databases.postgres as postgres -import sqlalchemy.databases.oracle as oracle -import sqlalchemy.databases.sqlite as sqlite -import sqlalchemy.databases.mysql as mysql - -db = ansisql.engine() -#db = create_engine('mssql') - +from sqlalchemy.databases import sqlite, postgres, mysql, oracle from testbase import PersistTest import unittest, re @@ -34,8 +26,9 @@ table3 = table( column('otherstuff'), ) +metadata = MetaData() table4 = Table( - 'remotetable', db, + 'remotetable', metadata, Column('rem_id', Integer, primary_key=True), Column('datatype_id', Integer), Column('value', String(20)), @@ -58,8 +51,8 @@ addresses = table('addresses', ) class SQLTest(PersistTest): - def runtest(self, clause, result, engine = None, params = None, checkparams = None): - c = clause.compile(parameters=params, engine=engine) + def runtest(self, clause, result, dialect = None, params = None, checkparams = None): + c = clause.compile(parameters=params, dialect=dialect) self.echo("\nSQL String:\n" + str(c) + repr(c.get_params())) cc = re.sub(r'\n', '', str(c)) self.assert_(cc == result, str(c) + "\n does not match \n" + result) @@ -80,7 +73,6 @@ myothertable.othername FROM mytable, myothertable") """tests placing select statements in the column clause of another select, for the purposes of selecting from the exported columns of that select.""" s = select([table1], table1.c.name == 'jack') - #print [key for key in s.c.keys()] self.runtest( select( [s], @@ -151,7 +143,6 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A self.runtest( select([users, s.c.street], from_obj=[s]), """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") - def testcolumnsubquery(self): s = select([table1.c.myid], scalar=True, correlate=False) @@ -213,7 +204,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) self.runtest( - literal("a") + literal("b") * literal("c"), ":literal + (:liter_1 * :liter_2)", db + literal("a") + literal("b") * literal("c"), ":literal + (:liter_1 * :liter_2)" ) def testmultiparam(self): @@ -234,9 +225,9 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) def testoraclelimit(self): - e = create_engine('oracle') - users = Table('users', e, Column('name', String(10), key='username')) - self.runtest(select([users.c.username], limit=5), "SELECT name FROM (SELECT users.name AS name, ROW_NUMBER() OVER (ORDER BY users.rowid ASC) AS ora_rn FROM users) WHERE ora_rn<=5", engine=e) + metadata = MetaData() + users = Table('users', metadata, Column('name', String(10), key='username')) + self.runtest(select([users.c.username], limit=5), "SELECT name FROM (SELECT users.name AS name, ROW_NUMBER() OVER (ORDER BY users.rowid) AS ora_rn FROM users) WHERE ora_rn<=5", dialect=oracle.dialect()) def testgroupby_and_orderby(self): self.runtest( @@ -276,15 +267,13 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = def testtext(self): self.runtest( text("select * from foo where lala = bar") , - "select * from foo where lala = bar", - engine = db + "select * from foo where lala = bar" ) self.runtest(select( ["foobar(a)", "pk_foo_bar(syslaal)"], "a = 12", - from_obj = ["foobar left outer join lala on foobar.foo = lala.foo"], - engine = db + from_obj = ["foobar left outer join lala on foobar.foo = lala.foo"] ), "SELECT foobar(a), pk_foo_bar(syslaal) FROM foobar left outer join lala on foobar.foo = lala.foo WHERE a = 12") @@ -296,33 +285,32 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = s.append_whereclause("column2=19") s.order_by("column1") s.append_from("table1") - self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1", db) + self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1") def testtextbinds(self): self.runtest( - db.text("select * from foo where lala=:bar and hoho=:whee"), + text("select * from foo where lala=:bar and hoho=:whee"), "select * from foo where lala=:bar and hoho=:whee", checkparams={'bar':4, 'whee': 7}, params={'bar':4, 'whee': 7, 'hoho':10}, - engine=db ) - engine = postgres.engine({}) + dialect = postgres.dialect() self.runtest( - engine.text("select * from foo where lala=:bar and hoho=:whee"), + text("select * from foo where lala=:bar and hoho=:whee"), "select * from foo where lala=%(bar)s and hoho=%(whee)s", checkparams={'bar':4, 'whee': 7}, params={'bar':4, 'whee': 7, 'hoho':10}, - engine=engine + dialect=dialect ) - engine = sqlite.engine({}) + dialect = sqlite.dialect() self.runtest( - engine.text("select * from foo where lala=:bar and hoho=:whee"), + text("select * from foo where lala=:bar and hoho=:whee"), "select * from foo where lala=? and hoho=?", checkparams=[4, 7], params={'bar':4, 'whee': 7, 'hoho':10}, - engine=engine + dialect=dialect ) def testtextmix(self): @@ -393,7 +381,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today "SELECT foo.bar.lala(:lala)") # test a dotted func off the engine itself - self.runtest(db.func.lala.hoho(7), "lala.hoho(:hoho)") + self.runtest(func.lala.hoho(7), "lala.hoho(:hoho)") def testjoin(self): self.runtest( @@ -461,6 +449,13 @@ FROM mytable WHERE mytable.myid = :mytable_my_1 ORDER BY mytable.myid") FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \ FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable") + u = union( + select([table1]), + select([table2]), + select([table3]) + ) + assert u.corresponding_column(table2.c.otherid) is u.c.otherid + def testouterjoin(self): # test an outer join. the oracle module should take the ON clause of the join and @@ -485,19 +480,19 @@ FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid WHERE mytable.name = %(mytable_name)s AND mytable.myid = %(mytable_myid)s AND \ myothertable.othername != %(myothertable_othername)s AND \ EXISTS (select yay from foo where boo = lar)", - engine = postgres.engine({})) - + dialect=postgres.dialect() + ) self.runtest(query, "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \ mytable.name = :mytable_name AND mytable.myid = :mytable_myid AND \ myothertable.othername != :myothertable_othername AND EXISTS (select yay from foo where boo = lar)", - engine = oracle.engine({}, use_ansi = False)) + dialect=oracle.OracleDialect(use_ansi = False)) query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid) self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON thirdtable.userid = myothertable.otherid") - self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", engine=oracle.engine({}, use_ansi=False)) + self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False)) def testbindparam(self): self.runtest(select( @@ -513,7 +508,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable # check that the bind params sent along with a compile() call # get preserved when the params are retreived later s = select([table1], table1.c.myid == bindparam('test')) - c = s.compile(parameters = {'test' : 7}, engine=db) + c = s.compile(parameters = {'test' : 7}) self.assert_(c.get_params() == {'test' : 7}) @@ -542,28 +537,26 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable Column('ts', TIMESTAMP), ) - def check_results(engine, expected_results, literal): + def check_results(dialect, expected_results, literal): self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results') - self.assertEqual(str(cast(tbl.c.v1, Numeric, engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) - self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9), engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) - self.assertEqual(str(cast(tbl.c.ts, Date, engine=engine)), 'CAST(casttest.ts AS %s)' %expected_results[2]) - self.assertEqual(str(cast(1234, TEXT, engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[3])) - self.assertEqual(str(cast('test', String(20), engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[4])) - - sel = select([tbl, cast(tbl.c.v1, Numeric)], engine=engine) - self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest") - + self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) + self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) + self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2]) + self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) + self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4])) + sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect) + self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest") # first test with Postgres engine - check_results(postgres.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s') + check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s') # then the Oracle engine - check_results(oracle.engine({}, use_ansi = False), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal') +# check_results(oracle.OracleDialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal') # then the sqlite engine - check_results(sqlite.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') + check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') # and the MySQL engine - check_results(mysql.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s') + check_results(mysql.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s') class CRUDTest(SQLTest): def testinsert(self): @@ -601,8 +594,7 @@ class CRUDTest(SQLTest): self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'}) self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'}) s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'}) - c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}, engine=db) - print str(c) + c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}) self.assert_(str(s) == str(c)) def testupdateexpression(self): @@ -623,7 +615,7 @@ class CRUDTest(SQLTest): s = select([table2], table2.c.otherid == table1.c.myid) u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s}) self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name") - + # test a correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) u = update(table1, table1.c.name==s) diff --git a/test/selectable.py b/test/selectable.py index c1a3f28e5..59f57331b 100755 --- a/test/selectable.py +++ b/test/selectable.py @@ -15,6 +15,7 @@ table = Table('table1', db, Column('col1', Integer, primary_key=True),
Column('col2', String(20)),
Column('col3', Integer),
+ Column('colx', Integer),
redefine=True
)
@@ -22,16 +23,10 @@ table2 = Table('table2', db, Column('col1', Integer, primary_key=True),
Column('col2', Integer, ForeignKey('table1.col1')),
Column('col3', String(20)),
+ Column('coly', Integer),
redefine=True
)
-table3 = Table('table3', db,
- Column('col1', Integer, ForeignKey('table1.col1'), primary_key=True),
- Column('col2', Integer),
- Column('col3', String(20)),
- redefine=True
- )
-
class SelectableTest(testbase.AssertMixin):
def testtablealias(self):
a = table.alias('a')
@@ -43,16 +38,51 @@ class SelectableTest(testbase.AssertMixin): print str(j)
self.assert_(criterion.compare(j.onclause))
- def testjoinpks(self):
- a = join(table, table3)
- b = join(table, table3, table.c.col1==table3.c.col2)
- c = join(table, table3, table.c.col2==table3.c.col2)
- d = join(table, table3, table.c.col2==table3.c.col1)
-
- self.assert_(a.primary_key==[table.c.col1])
- self.assert_(b.primary_key==[table.c.col1, table3.c.col1])
- self.assert_(c.primary_key==[table.c.col1, table3.c.col1])
- self.assert_(d.primary_key==[table.c.col1, table3.c.col1])
+ def testunion(self):
+ # tests that we can correspond a column in a Select statement with a certain Table, against
+ # a column in a Union where one of its underlying Selects matches to that same Table
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ )
+ s1 = table.select(use_labels=True)
+ s2 = table2.select(use_labels=True)
+ print ["%d %s" % (id(c),c.key) for c in u.c]
+ c = u.corresponding_column(s1.c.table1_col2)
+ print "%d %s" % (id(c), c.key)
+ assert u.corresponding_column(s1.c.table1_col2) is u.c.col2
+ assert u.corresponding_column(s2.c.table2_col2) is u.c.col2
+
+ def testaliasunion(self):
+ # same as testunion, except its an alias of the union
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ ).alias('analias')
+ s1 = table.select(use_labels=True)
+ s2 = table2.select(use_labels=True)
+ assert u.corresponding_column(s1.c.table1_col2) is u.c.col2
+ assert u.corresponding_column(s2.c.table2_col2) is u.c.col2
+ assert u.corresponding_column(s2.c.table2_coly) is u.c.coly
+ assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly
+
+ def testselectunion(self):
+ # like testaliasunion, but off a Select off the union.
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ ).alias('analias')
+ s = select([u])
+ s1 = table.select(use_labels=True)
+ s2 = table2.select(use_labels=True)
+ assert s.corresponding_column(s1.c.table1_col2) is s.c.col2
+ assert s.corresponding_column(s2.c.table2_col2) is s.c.col2
+
+ def testunionagainstjoin(self):
+ # same as testunion, except its an alias of the union
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ ).alias('analias')
+ j1 = table.join(table2)
+ assert u.corresponding_column(j1.c.table1_colx) is u.c.colx
+ assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx
def testjoin(self):
a = join(table, table2)
diff --git a/test/selectresults.py b/test/selectresults.py index d04918683..6bd619f3a 100644 --- a/test/selectresults.py +++ b/test/selectresults.py @@ -3,31 +3,40 @@ import testbase from sqlalchemy import * -from sqlalchemy.mods.selectresults import SelectResultsExt +from sqlalchemy.ext.selectresults import SelectResultsExt class Foo(object): pass class SelectResultsTest(PersistTest): def setUpAll(self): + self.install_threadlocal() global foo foo = Table('foo', testbase.db, Column('id', Integer, Sequence('foo_id_seq'), primary_key=True), - Column('bar', Integer)) + Column('bar', Integer), + Column('range', Integer)) assign_mapper(Foo, foo, extension=SelectResultsExt()) foo.create() for i in range(100): - Foo(bar=i) - objectstore.commit() + Foo(bar=i, range=i%10) + objectstore.flush() def setUp(self): - self.orig = Foo.mapper.select_whereclause() - self.res = Foo.select() + self.query = Foo.mapper.query() + self.orig = self.query.select_whereclause() + self.res = self.query.select() def tearDownAll(self): global foo foo.drop() + self.uninstall_threadlocal() + + def test_selectby(self): + res = self.query.select_by(range=5) + assert res.order_by([Foo.c.bar])[0].bar == 5 + assert res.order_by([desc(Foo.c.bar)])[0].bar == 95 def test_slice(self): assert self.res[1] == self.orig[1] diff --git a/test/session.py b/test/session.py new file mode 100644 index 000000000..9ed7f0f7a --- /dev/null +++ b/test/session.py @@ -0,0 +1,7 @@ + +# test merging a composed object. + +# test that when cascading an operation, like "merge", lazy-loaded scalar and list attributes that werent already loaded on the given object remain not loaded. + +# test putting an object in session A, "moving" it to session B, insure its in B and not in A + diff --git a/test/sessioncontext.py b/test/sessioncontext.py new file mode 100644 index 000000000..83bc2f2bf --- /dev/null +++ b/test/sessioncontext.py @@ -0,0 +1,47 @@ +from testbase import PersistTest, AssertMixin +import unittest, sys, os +from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy.orm.session import object_session, Session +from sqlalchemy import * +import testbase + +metadata = MetaData() +users = Table('users', metadata, + Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('user_name', String(40)), + mysql_engine='innodb' +) + +class SessionContextTest(AssertMixin): + def setUp(self): + clear_mappers() + + def do_test(self, class_, context): + """test session assignment on object creation""" + obj = class_() + assert context.current == object_session(obj) + + # keep a reference so the old session doesn't get gc'd + old_session = context.current + + context.current = Session() + assert context.current != object_session(obj) + assert old_session == object_session(obj) + + new_session = context.current + del context.current + assert context.current != new_session + assert old_session == object_session(obj) + + obj2 = class_() + assert context.current == object_session(obj2) + + def test_mapper_extension(self): + context = SessionContext(Session) + class User(object): pass + User.mapper = mapper(User, users, extension=context.mapper_extension) + self.do_test(User, context) + + +if __name__ == "__main__": + testbase.main() diff --git a/test/tables.py b/test/tables.py index f1e1a845b..2bfc75868 100644 --- a/test/tables.py +++ b/test/tables.py @@ -9,22 +9,22 @@ __all__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'item ECHO = testbase.echo db = testbase.db +metadata = BoundMetaData(db) - -users = Table('users', db, +users = Table('users', metadata, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), Column('user_name', String(40)), mysql_engine='innodb' ) -addresses = Table('email_addresses', db, +addresses = Table('email_addresses', metadata, Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), Column('user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(40)), ) -orders = Table('orders', db, +orders = Table('orders', metadata, Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True), Column('user_id', Integer, ForeignKey(users.c.user_id)), Column('description', String(50)), @@ -32,51 +32,32 @@ orders = Table('orders', db, ) -orderitems = Table('items', db, +orderitems = Table('items', metadata, Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True), Column('order_id', INT, ForeignKey("orders")), Column('item_name', VARCHAR(50)), ) -keywords = Table('keywords', db, +keywords = Table('keywords', metadata, Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True), Column('name', VARCHAR(50)), ) -itemkeywords = Table('itemkeywords', db, +itemkeywords = Table('itemkeywords', metadata, Column('item_id', INT, ForeignKey("items")), Column('keyword_id', INT, ForeignKey("keywords")), ) def create(): - users.create() - addresses.create() - orders.create() - orderitems.create() - keywords.create() - itemkeywords.create() - + metadata.create_all() def drop(): - itemkeywords.drop() - keywords.drop() - orderitems.drop() - orders.drop() - addresses.drop() - users.drop() - db.commit() - + metadata.drop_all() def delete(): - itemkeywords.delete().execute() - keywords.delete().execute() - orderitems.delete().execute() - orders.delete().execute() - addresses.delete().execute() - users.delete().execute() - db.commit() - + for t in metadata.table_iterator(reverse=True): + t.delete().execute() def user_data(): users.insert().execute( dict(user_id = 7, user_name = 'jack'), @@ -85,7 +66,6 @@ def user_data(): ) def delete_user_data(): users.delete().execute() - db.commit() def data(): delete() @@ -144,8 +124,6 @@ def data(): dict(keyword_id=7, item_id=2), dict(keyword_id=6, item_id=3) ) - - db.commit() class User(object): def __init__(self): diff --git a/test/testbase.py b/test/testbase.py index 8ef63ef27..04972779d 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -2,63 +2,96 @@ import unittest import StringIO import sqlalchemy.engine as engine import sqlalchemy.ext.proxy as proxy -import sqlalchemy.schema as schema +import sqlalchemy.pool as pool +#import sqlalchemy.schema as schema import re, sys +import sqlalchemy +import optparse + -echo = True -#echo = False -#echo = 'debug' db = None +metadata = None db_uri = None +echo = True + +# redefine sys.stdout so all those print statements go to the echo func +local_stdout = sys.stdout +class Logger(object): + def write(self, msg): + if echo: + local_stdout.write(msg) +sys.stdout = Logger() + +def echo_text(text): + print text def parse_argv(): # we are using the unittest main runner, so we are just popping out the # arguments we need instead of using our own getopt type of thing - global db, db_uri + global db, db_uri, metadata DBTYPE = 'sqlite' PROXY = False + + + parser = optparse.OptionParser(usage = "usage: %prog [options] files...") + parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)") + parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (sqlite, sqlite_file, postgres, mysql, oracle, oracle8, mssql)") + parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool") + parser.add_option("--verbose", action="store_true", dest="verbose", help="full debug echoing") + parser.add_option("--quiet", action="store_true", dest="quiet", help="be totally quiet") + parser.add_option("--nothreadlocal", action="store_true", dest="nothreadlocal", help="dont use thread-local mod") + parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to SA default)") + + (options, args) = parser.parse_args() + sys.argv[1:] = args - if len(sys.argv) >= 3: - if sys.argv[1] == '--dburi': - (param, db_uri) = (sys.argv.pop(1), sys.argv.pop(1)) - elif sys.argv[1] == '--db': - (param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1)) + if options.dburi: + db_uri = param = options.dburi + elif options.db: + DBTYPE = param = options.db + opts = {} if (None == db_uri): - p = DBTYPE.split('.') - if len(p) > 1: - arg = p[0] - DBTYPE = p[1] - if arg == 'proxy': - PROXY = True if DBTYPE == 'sqlite': - db_uri = 'sqlite://filename=:memory:' + db_uri = 'sqlite:///:memory:' elif DBTYPE == 'sqlite_file': - db_uri = 'sqlite://filename=querytest.db' + db_uri = 'sqlite:///querytest.db' elif DBTYPE == 'postgres': - db_uri = 'postgres://database=test&port=5432&host=127.0.0.1&user=scott&password=tiger' + db_uri = 'postgres://scott:tiger@127.0.0.1:5432/test' elif DBTYPE == 'mysql': - db_uri = 'mysql://database=test&host=127.0.0.1&user=scott&password=tiger' + db_uri = 'mysql://scott:tiger@127.0.0.1/test' elif DBTYPE == 'oracle': - db_uri = 'oracle://user=scott&password=tiger' + db_uri = 'oracle://scott:tiger@127.0.0.1:1521' elif DBTYPE == 'oracle8': - db_uri = 'oracle://user=scott&password=tiger' + db_uri = 'oracle://scott:tiger@127.0.0.1:1521' opts = {'use_ansi':False} elif DBTYPE == 'mssql': - db_uri = 'mssql://database=test&user=scott&password=tiger' + db_uri = 'mssql://scott:tiger@/test' if not db_uri: raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql> to test runner." - if PROXY: - db = proxy.ProxyEngine(echo=echo, default_ordering=True, **opts) - db.connect(db_uri) + if not options.nothreadlocal: + __import__('sqlalchemy.mods.threadlocal') + sqlalchemy.mods.threadlocal.uninstall_plugin() + + global echo + echo = options.verbose and not options.quiet + + global quiet + quiet = options.quiet + + if options.enginestrategy is not None: + opts['strategy'] = options.enginestrategy + if options.mockpool: + db = engine.create_engine(db_uri, echo=True, default_ordering=True, poolclass=MockPool, **opts) else: - db = engine.create_engine(db_uri, echo=echo, default_ordering=True, **opts) + db = engine.create_engine(db_uri, echo=True, default_ordering=True, **opts) db = EngineAssert(db) - + metadata = sqlalchemy.BoundMetaData(db) + def unsupported(*dbs): """a decorator that marks a test as unsupported by one or more database implementations""" def decorate(func): @@ -87,21 +120,49 @@ def supported(*dbs): return lala return decorate -def echo_text(text): - print text class PersistTest(unittest.TestCase): """persist base class, provides default setUpAll, tearDownAll and echo functionality""" def __init__(self, *args, **params): unittest.TestCase.__init__(self, *args, **params) def echo(self, text): - if echo: - echo_text(text) + echo_text(text) + def install_threadlocal(self): + sqlalchemy.mods.threadlocal.install_plugin() + def uninstall_threadlocal(self): + sqlalchemy.mods.threadlocal.uninstall_plugin() def setUpAll(self): pass def tearDownAll(self): pass + def shortDescription(self): + """overridden to not return docstrings""" + return None + +class MockPool(pool.Pool): + """this pool is hardcore about only one connection being used at a time.""" + def __init__(self, creator, **params): + pool.Pool.__init__(self, **params) + self.connection = creator() + self._conn = self.connection + + def status(self): + return "MockPool" + + def do_return_conn(self, conn): + assert conn is self._conn and self.connection is None + self.connection = conn + + def do_return_invalid(self): + raise "Invalid" + def do_get(self): + if getattr(self, 'breakpoint', False): + raise "breakpoint" + assert self.connection is not None + c = self.connection + self.connection = None + return c class AssertMixin(PersistTest): """given a list-based structure of keys/properties which represent information within an object structure, and @@ -145,8 +206,10 @@ class EngineAssert(proxy.BaseProxyEngine): """decorates a SQLEngine object to match the incoming queries against a set of assertions.""" def __init__(self, engine): self._engine = engine - self.realexec = engine.post_exec - self.realexec.im_self.post_exec = self.post_exec + + self.real_execution_context = engine.dialect.create_execution_context + engine.dialect.create_execution_context = self.execution_context + self.logger = engine.logger self.set_assert_list(None, None) self.sql_count = 0 @@ -154,8 +217,6 @@ class EngineAssert(proxy.BaseProxyEngine): return self._engine def set_engine(self, e): self._engine = e -# def __getattr__(self, key): - # return getattr(self.engine, key) def set_assert_list(self, unittest, list): self.unittest = unittest self.assert_list = list @@ -164,46 +225,55 @@ class EngineAssert(proxy.BaseProxyEngine): def _set_echo(self, echo): self.engine.echo = echo echo = property(lambda s: s.engine.echo, _set_echo) - def post_exec(self, proxy, compiled, parameters, **kwargs): - self.engine.logger = self.logger - statement = str(compiled) - statement = re.sub(r'\n', '', statement) - - if self.assert_list is not None: - item = self.assert_list[-1] - if not isinstance(item, dict): - item = self.assert_list.pop() - else: - # asserting a dictionary of statements->parameters - # this is to specify query assertions where the queries can be in - # multiple orderings - if not item.has_key('_converted'): - for key in item.keys(): - ckey = self.convert_statement(key) - item[ckey] = item[key] - if ckey != key: - del item[key] - item['_converted'] = True - try: - entry = item.pop(statement) - if len(item) == 1: - self.assert_list.pop() - item = (statement, entry) - except KeyError: - self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)) - - (query, params) = item - if callable(params): - params = params() - - query = self.convert_statement(query) - - self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - self.sql_count += 1 - return self.realexec(proxy, compiled, parameters, **kwargs) + + def execution_context(self): + def post_exec(engine, proxy, compiled, parameters, **kwargs): + ctx = e + self.engine.logger = self.logger + statement = str(compiled) + statement = re.sub(r'\n', '', statement) + if self.assert_list is not None: + item = self.assert_list[-1] + if not isinstance(item, dict): + item = self.assert_list.pop() + else: + # asserting a dictionary of statements->parameters + # this is to specify query assertions where the queries can be in + # multiple orderings + if not item.has_key('_converted'): + for key in item.keys(): + ckey = self.convert_statement(key) + item[ckey] = item[key] + if ckey != key: + del item[key] + item['_converted'] = True + try: + entry = item.pop(statement) + if len(item) == 1: + self.assert_list.pop() + item = (statement, entry) + except KeyError: + self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)) + + (query, params) = item + if callable(params): + params = params(ctx) + if params is not None and isinstance(params, list) and len(params) == 1: + params = params[0] + + query = self.convert_statement(query) + self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) + self.sql_count += 1 + return realexec(ctx, proxy, compiled, parameters, **kwargs) + + e = self.real_execution_context() + realexec = e.post_exec + realexec.im_self.post_exec = post_exec + return e + def convert_statement(self, query): - paramstyle = self.engine.paramstyle + paramstyle = self.engine.dialect.paramstyle if paramstyle == 'named': pass elif paramstyle =='pyformat': @@ -275,10 +345,11 @@ parse_argv() def runTests(suite): - runner = unittest.TextTestRunner(verbosity = 2, descriptions =1) + runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2) runner.run(suite) def main(): - unittest.main() + suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__')) + runTests(suite) diff --git a/test/testtypes.py b/test/testtypes.py index ae58f0f54..db1fbca20 100644 --- a/test/testtypes.py +++ b/test/testtypes.py @@ -2,7 +2,8 @@ from sqlalchemy import * import string,datetime, re, sys from testbase import PersistTest, AssertMixin import testbase - +import sqlalchemy.engine.url as url + db = testbase.db class MyType(types.TypeEngine): @@ -34,15 +35,15 @@ class MyUnicodeType(types.Unicode): class AdaptTest(PersistTest): def testadapt(self): - e1 = create_engine('postgres://') - e2 = create_engine('sqlite://') - e3 = create_engine('mysql://') + e1 = url.URL('postgres').get_module().dialect() + e2 = url.URL('mysql').get_module().dialect() + e3 = url.URL('sqlite').get_module().dialect() type = String(40) - t1 = type.engine_impl(e1) - t2 = type.engine_impl(e2) - t3 = type.engine_impl(e3) + t1 = type.dialect_impl(e1) + t2 = type.dialect_impl(e2) + t3 = type.dialect_impl(e3) assert t1 != t2 assert t2 != t3 assert t3 != t1 @@ -116,7 +117,7 @@ class ColumnsTest(AssertMixin): ) for aCol in testTable.c: - self.assertEquals(expectedResults[aCol.name], db.schemagenerator().get_column_specification(aCol)) + self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None).get_column_specification(aCol)) class UnicodeTest(AssertMixin): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" @@ -130,13 +131,6 @@ class UnicodeTest(AssertMixin): unicode_table.create() def tearDownAll(self): unicode_table.drop() - def testwhereclause(self): - l = unicode_table.select(unicode_table.c.unicode_data==u'this is also unicode').execute() - def testmapperwhere(self): - class Foo(object):pass - m = mapper(Foo, unicode_table) - l = m.get_by(unicode_data=unicode('this is also unicode')) - l = m.get_by(plain_data=unicode('this is also unicode')) def testbasic(self): rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') @@ -147,15 +141,15 @@ class UnicodeTest(AssertMixin): self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata) if isinstance(x['plain_data'], unicode): # SQLLite returns even non-unicode data as unicode - self.assert_(sys.modules[db.engine.__module__].descriptor()['name'] == 'sqlite') + self.assert_(db.name == 'sqlite') self.echo("its sqlite !") else: self.assert_(not isinstance(x['plain_data'], unicode) and x['plain_data'] == rawdata) def testengineparam(self): """tests engine-wide unicode conversion""" - prev_unicode = db.engine.convert_unicode + prev_unicode = db.engine.dialect.convert_unicode try: - db.engine.convert_unicode = True + db.engine.dialect.convert_unicode = True rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') unicode_table.insert().execute(unicode_data=unicodedata, plain_data=rawdata) @@ -165,8 +159,7 @@ class UnicodeTest(AssertMixin): self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata) self.assert_(isinstance(x['plain_data'], unicode) and x['plain_data'] == unicodedata) finally: - db.engine.convert_unicode = prev_unicode - + db.engine.dialect.convert_unicode = prev_unicode class Foo(object): def __init__(self, moredata): @@ -175,7 +168,7 @@ class Foo(object): self.moredata = moredata def __eq__(self, other): return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata - + class BinaryTest(AssertMixin): def setUpAll(self): global binary_table @@ -184,21 +177,20 @@ class BinaryTest(AssertMixin): Column('data', Binary), Column('data_slice', Binary(100)), Column('misc', String(30)), - Column('pickled', PickleType)) + Column('pickled', PickleType) + ) binary_table.create() def tearDownAll(self): binary_table.drop() def testbinary(self): testobj1 = Foo('im foo 1') testobj2 = Foo('im foo 2') - + stream1 =self.get_module_stream('sqlalchemy.sql') - stream2 =self.get_module_stream('sqlalchemy.engine') + stream2 =self.get_module_stream('sqlalchemy.schema') binary_table.insert().execute(primary_id=1, misc='sql.pyc', data=stream1, data_slice=stream1[0:100], pickled=testobj1) - binary_table.insert().execute(primary_id=2, misc='engine.pyc', data=stream2, data_slice=stream2[0:99], pickled=testobj2) + binary_table.insert().execute(primary_id=2, misc='schema.pyc', data=stream2, data_slice=stream2[0:99], pickled=testobj2) l = binary_table.select().execute().fetchall() - print type(l[0]['data']) - return print len(stream1), len(l[0]['data']), len(l[0]['data_slice']) self.assert_(list(stream1) == list(l[0]['data'])) self.assert_(list(stream1[0:100]) == list(l[0]['data_slice'])) @@ -231,7 +223,7 @@ class DateTest(AssertMixin): collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime), Column('user_date', Date), Column('user_time', Time)] - if db.engine.__module__.endswith('mysql') or db.engine.__module__.endswith('mssql'): + if db.engine.name == 'mysql' or db.engine.name == 'mssql': # strip microseconds -- not supported by this engine (should be an easier way to detect this) for d in insert_data: if d[2] is not None: diff --git a/test/transaction.py b/test/transaction.py new file mode 100644 index 000000000..d76d7e0f4 --- /dev/null +++ b/test/transaction.py @@ -0,0 +1,63 @@ + +import testbase +import unittest, sys, datetime +import tables +db = testbase.db +from sqlalchemy import * + +class TransactionTest(testbase.PersistTest): + def setUpAll(self): + global users, metadata + metadata = MetaData() + users = Table('query_users', metadata, + Column('user_id', INT, primary_key = True), + Column('user_name', VARCHAR(20)), + ) + users.create(testbase.db) + + def tearDown(self): + testbase.db.connect().execute(users.delete()) + def tearDownAll(self): + users.drop(testbase.db) + + @testbase.unsupported('mysql') + def testrollback(self): + """test a basic rollback""" + connection = testbase.db.connect() + transaction = connection.begin() + connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=3, user_name='user3') + transaction.rollback() + + result = connection.execute("select * from query_users") + assert len(result.fetchall()) == 0 + connection.close() + +class AutoRollbackTest(testbase.PersistTest): + def setUpAll(self): + global metadata + metadata = MetaData() + + def tearDownAll(self): + metadata.drop_all(testbase.db) + + def testrollback_deadlock(self): + """test that returning connections to the pool clears any object locks.""" + conn1 = testbase.db.connect() + conn2 = testbase.db.connect() + users = Table('deadlock_users', metadata, + Column('user_id', INT, primary_key = True), + Column('user_name', VARCHAR(20)), + ) + users.create(conn1) + conn1.execute("select * from deadlock_users") + conn1.close() + # without auto-rollback in the connection pool's return() logic, this deadlocks in Postgres, + # because conn1 is returned to the pool but still has a lock on "deadlock_users" + # comment out the rollback in pool/ConnectionFairy._close() to see ! + users.drop(conn2) + conn2.close() + +if __name__ == "__main__": + testbase.main() |
