summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-05-25 14:20:23 +0000
commitbb79e2e871d0a4585164c1a6ed626d96d0231975 (patch)
tree6d457ba6c36c408b45db24ec3c29e147fe7504ff /test
parent4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff)
downloadsqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'test')
-rw-r--r--test/activemapper.py157
-rw-r--r--test/alltests.py16
-rw-r--r--test/attributes.py3
-rw-r--r--test/cascade.py173
-rw-r--r--test/cycles.py187
-rw-r--r--test/defaults.py35
-rw-r--r--test/dependency.py66
-rw-r--r--test/eagertest1.py41
-rw-r--r--test/eagertest2.py111
-rw-r--r--test/engine.py64
-rw-r--r--test/entity.py68
-rw-r--r--test/indexes.py32
-rw-r--r--test/inheritance.py285
-rw-r--r--test/lazytest1.py88
-rw-r--r--test/legacy_objectstore.py113
-rw-r--r--test/manytomany.py153
-rw-r--r--test/mapper.py453
-rw-r--r--test/masscreate.py4
-rw-r--r--test/massload.py8
-rw-r--r--test/objectstore.py512
-rw-r--r--test/onetoone.py47
-rw-r--r--test/parseconnect.py29
-rw-r--r--test/polymorph.py169
-rw-r--r--test/pool.py66
-rw-r--r--test/proxy_engine.py243
-rw-r--r--test/query.py21
-rw-r--r--test/reflection.py73
-rw-r--r--test/relationships.py56
-rw-r--r--test/select.py100
-rwxr-xr-xtest/selectable.py64
-rw-r--r--test/selectresults.py21
-rw-r--r--test/session.py7
-rw-r--r--test/sessioncontext.py47
-rw-r--r--test/tables.py44
-rw-r--r--test/testbase.py223
-rw-r--r--test/testtypes.py48
-rw-r--r--test/transaction.py63
37 files changed, 2159 insertions, 1731 deletions
diff --git a/test/activemapper.py b/test/activemapper.py
index 7edb41488..6a8b0904e 100644
--- a/test/activemapper.py
+++ b/test/activemapper.py
@@ -1,80 +1,84 @@
-from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one
-from sqlalchemy.ext import activemapper
-from sqlalchemy import objectstore, global_connect
-from sqlalchemy import and_, or_
-from sqlalchemy import ForeignKey, String, Integer, DateTime
-from datetime import datetime
+from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, objectstore
+from sqlalchemy import and_, or_, clear_mappers
+from sqlalchemy import ForeignKey, String, Integer, DateTime
+from datetime import datetime
import unittest
+import sqlalchemy.ext.activemapper as activemapper
-#
-# application-level model objects
-#
+import testbase
-class Person(ActiveMapper):
- class mapping:
- id = column(Integer, primary_key=True)
- full_name = column(String)
- first_name = column(String)
- middle_name = column(String)
- last_name = column(String)
- birth_date = column(DateTime)
- ssn = column(String)
- gender = column(String)
- home_phone = column(String)
- cell_phone = column(String)
- work_phone = column(String)
- prefs_id = column(Integer, foreign_key=ForeignKey('preferences.id'))
- addresses = one_to_many('Address', colname='person_id', backref='person')
- preferences = one_to_one('Preferences', colname='pref_id', backref='person')
-
- def __str__(self):
- s = '%s\n' % self.full_name
- s += ' * birthdate: %s\n' % (self.birth_date or 'not provided')
- s += ' * fave color: %s\n' % (self.preferences.favorite_color or 'Unknown')
- s += ' * personality: %s\n' % (self.preferences.personality_type or 'Unknown')
-
- for address in self.addresses:
- s += ' * address: %s\n' % address.address_1
- s += ' %s, %s %s\n' % (address.city, address.state, address.postal_code)
-
- return s
+class testcase(testbase.PersistTest):
+ def setUpAll(self):
+ global Person, Preferences, Address
+
+ class Person(ActiveMapper):
+ class mapping:
+ id = column(Integer, primary_key=True)
+ full_name = column(String)
+ first_name = column(String)
+ middle_name = column(String)
+ last_name = column(String)
+ birth_date = column(DateTime)
+ ssn = column(String)
+ gender = column(String)
+ home_phone = column(String)
+ cell_phone = column(String)
+ work_phone = column(String)
+ prefs_id = column(Integer, foreign_key=ForeignKey('preferences.id'))
+ addresses = one_to_many('Address', colname='person_id', backref='person')
+ preferences = one_to_one('Preferences', colname='pref_id', backref='person')
+ def __str__(self):
+ s = '%s\n' % self.full_name
+ s += ' * birthdate: %s\n' % (self.birth_date or 'not provided')
+ s += ' * fave color: %s\n' % (self.preferences.favorite_color or 'Unknown')
+ s += ' * personality: %s\n' % (self.preferences.personality_type or 'Unknown')
-class Preferences(ActiveMapper):
- class mapping:
- __table__ = 'preferences'
- id = column(Integer, primary_key=True)
- favorite_color = column(String)
- personality_type = column(String)
+ for address in self.addresses:
+ s += ' * address: %s\n' % address.address_1
+ s += ' %s, %s %s\n' % (address.city, address.state, address.postal_code)
+ return s
-class Address(ActiveMapper):
- class mapping:
- id = column(Integer, primary_key=True)
- type = column(String)
- address_1 = column(String)
- city = column(String)
- state = column(String)
- postal_code = column(String)
- person_id = column(Integer, foreign_key=ForeignKey('person.id'))
+ class Preferences(ActiveMapper):
+ class mapping:
+ __table__ = 'preferences'
+ id = column(Integer, primary_key=True)
+ favorite_color = column(String)
+ personality_type = column(String)
+ class Address(ActiveMapper):
+ class mapping:
+ id = column(Integer, primary_key=True)
+ type = column(String)
+ address_1 = column(String)
+ city = column(String)
+ state = column(String)
+ postal_code = column(String)
+ person_id = column(Integer, foreign_key=ForeignKey('person.id'))
+ activemapper.metadata.connect(testbase.db)
+ activemapper.create_tables()
-class testcase(unittest.TestCase):
-
+ def tearDownAll(self):
+ clear_mappers()
+ activemapper.drop_tables()
+
def tearDown(self):
- people = Person.select()
- for person in people: person.delete()
+ for t in activemapper.metadata.table_iterator(reverse=True):
+ t.delete().execute()
+ #people = Person.select()
+ #for person in people: person.delete()
- addresses = Address.select()
- for address in addresses: address.delete()
+ #addresses = Address.select()
+ #for address in addresses: address.delete()
- preferences = Preferences.select()
- for preference in preferences: preference.delete()
+ #preferences = Preferences.select()
+ #for preference in preferences: preference.delete()
- objectstore.commit()
- objectstore.clear()
+ #objectstore.flush()
+ #objectstore.clear()
def create_person_one(self):
# create a person
@@ -130,12 +134,8 @@ class testcase(unittest.TestCase):
def test_create(self):
- global_connect('sqlite:///', echo=False)
- activemapper.create_tables()
-
p1 = self.create_person_one()
-
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
results = Person.select()
@@ -151,14 +151,14 @@ class testcase(unittest.TestCase):
def test_delete(self):
p1 = self.create_person_one()
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
results = Person.select()
self.assertEquals(len(results), 1)
results[0].delete()
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
results = Person.select()
@@ -169,7 +169,7 @@ class testcase(unittest.TestCase):
p1 = self.create_person_one()
p2 = self.create_person_two()
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
# select and make sure we get back two results
@@ -200,29 +200,38 @@ class testcase(unittest.TestCase):
# FIXME: I don't know why, but it seems that my backwards relationship
# on preferences still ends up being a list even though I pass
# in uselist=False...
+ # FIXED: the backref is a new PropertyLoader which needs its own "uselist".
+ # uses a function which I dont think existed when you first wrote ActiveMapper.
p1 = self.create_person_one()
self.assertEquals(p1.preferences.person, p1)
p1.delete()
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
def test_select_by(self):
# FIXME: either I don't understand select_by, or it doesn't work.
+ # FIXED (as good as we can for now): yup....everyone thinks it works that way....it only
+ # generates joins for keyword arguments, not ColumnClause args. would need a new layer of
+ # "MapperClause" objects to use properties in expressions. (MB)
p1 = self.create_person_one()
p2 = self.create_person_two()
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
- results = Person.select_by(
- Address.c.postal_code.like('30075')
+ results = Person.select(
+ Address.c.postal_code.like('30075') &
+ Person.join_to('addresses')
)
self.assertEquals(len(results), 1)
if __name__ == '__main__':
+ # go ahead and setup the database connection, and create the tables
+
+ # launch the unit tests
unittest.main() \ No newline at end of file
diff --git a/test/alltests.py b/test/alltests.py
index 3595edd7e..c1662bce7 100644
--- a/test/alltests.py
+++ b/test/alltests.py
@@ -1,20 +1,16 @@
import testbase
import unittest
-testbase.echo = False
-
-#test
-
def suite():
modules_to_test = (
# core utilities
- 'historyarray',
+ 'historyarray',
'attributes',
'dependency',
-
+
# connectivity, execution
'pool',
- 'engine',
+ 'transaction',
# schema/tables
'reflection',
@@ -40,7 +36,9 @@ def suite():
'eagertest2',
# ORM persistence
+ 'sessioncontext',
'objectstore',
+ 'cascade',
'relationships',
# cyclical ORM persistence
@@ -51,9 +49,11 @@ def suite():
'manytomany',
'onetoone',
'inheritance',
+ 'polymorph',
# extensions
'proxy_engine',
+ 'activemapper'
#'wsgi_test',
)
@@ -62,8 +62,6 @@ def suite():
alltests.addTest(unittest.findTestCases(module, suiteClass=None))
return alltests
-import sys
-sys.stdout = sys.stderr
if __name__ == '__main__':
testbase.runTests(suite())
diff --git a/test/attributes.py b/test/attributes.py
index 126f50456..bff864fa6 100644
--- a/test/attributes.py
+++ b/test/attributes.py
@@ -45,8 +45,7 @@ class AttributesTest(PersistTest):
manager.register_attribute(MyTest, 'email_address', uselist = False)
x = MyTest()
x.user_id=7
- s = pickle.dumps(x)
- y = pickle.loads(s)
+ pickle.dumps(x)
def testlist(self):
class User(object):pass
diff --git a/test/cascade.py b/test/cascade.py
new file mode 100644
index 000000000..4a997d67f
--- /dev/null
+++ b/test/cascade.py
@@ -0,0 +1,173 @@
+import testbase, tables
+import unittest, sys, datetime
+
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy import *
+
+class O2MCascadeTest(testbase.AssertMixin):
+ def tearDown(self):
+ ctx.current.clear()
+ tables.delete()
+
+ def tearDownAll(self):
+ clear_mappers()
+ tables.drop()
+
+ def setUpAll(self):
+ global ctx, data
+ ctx = SessionContext(lambda: create_session(echo_uow=True))
+ tables.create()
+ mapper(tables.User, tables.users, properties = dict(
+ address = relation(mapper(tables.Address, tables.addresses), lazy = False, uselist = False, private = True),
+ orders = relation(
+ mapper(tables.Order, tables.orders, properties = dict (
+ items = relation(mapper(tables.Item, tables.orderitems), lazy = False, uselist =True, private = True)
+ )),
+ lazy = True, uselist = True, private = True)
+ ))
+
+ def setUp(self):
+ global data
+ data = [tables.User,
+ {'user_name' : 'ed',
+ 'address' : (tables.Address, {'email_address' : 'foo@bar.com'}),
+ 'orders' : (tables.Order, [
+ {'description' : 'eds 1st order', 'items' : (tables.Item, [{'item_name' : 'eds o1 item'}, {'item_name' : 'eds other o1 item'}])},
+ {'description' : 'eds 2nd order', 'items' : (tables.Item, [{'item_name' : 'eds o2 item'}, {'item_name' : 'eds other o2 item'}])}
+ ])
+ },
+ {'user_name' : 'jack',
+ 'address' : (tables.Address, {'email_address' : 'jack@jack.com'}),
+ 'orders' : (tables.Order, [
+ {'description' : 'jacks 1st order', 'items' : (tables.Item, [{'item_name' : 'im a lumberjack'}, {'item_name' : 'and im ok'}])}
+ ])
+ },
+ {'user_name' : 'foo',
+ 'address' : (tables.Address, {'email_address': 'hi@lala.com'}),
+ 'orders' : (tables.Order, [
+ {'description' : 'foo order', 'items' : (tables.Item, [])},
+ {'description' : 'foo order 2', 'items' : (tables.Item, [{'item_name' : 'hi'}])},
+ {'description' : 'foo order three', 'items' : (tables.Item, [{'item_name' : 'there'}])}
+ ])
+ }
+ ]
+
+ for elem in data[1:]:
+ u = tables.User()
+ ctx.current.save(u)
+ u.user_name = elem['user_name']
+ u.address = tables.Address()
+ u.address.email_address = elem['address'][1]['email_address']
+ u.orders = []
+ for order in elem['orders'][1]:
+ o = tables.Order()
+ o.isopen = None
+ o.description = order['description']
+ u.orders.append(o)
+ o.items = []
+ for item in order['items'][1]:
+ i = tables.Item()
+ i.item_name = item['item_name']
+ o.items.append(i)
+
+ ctx.current.flush()
+ ctx.current.clear()
+
+
+ def testdelete(self):
+ l = ctx.current.query(tables.User).select()
+ for u in l:
+ self.echo( repr(u.orders))
+ self.assert_result(l, data[0], *data[1:])
+
+ self.echo("\n\n\n")
+ ids = (l[0].user_id, l[2].user_id)
+ ctx.current.delete(l[0])
+ ctx.current.delete(l[2])
+
+ ctx.current.flush()
+ self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 0)
+ self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0)
+ self.assert_(tables.addresses.count(tables.addresses.c.user_id.in_(*ids)).scalar() == 0)
+ self.assert_(tables.users.count(tables.users.c.user_id.in_(*ids)).scalar() == 0)
+
+
+ def testorphan(self):
+ l = ctx.current.query(tables.User).select()
+ jack = l[1]
+ jack.orders[:] = []
+
+ ids = [jack.user_id]
+ self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 1)
+ self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 2)
+
+ ctx.current.flush()
+
+ self.assert_(tables.orders.count(tables.orders.c.user_id.in_(*ids)).scalar() == 0)
+ self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0)
+
+
+class M2OCascadeTest(testbase.AssertMixin):
+ def tearDown(self):
+ ctx.current.clear()
+ for t in metadata.table_iterator(reverse=True):
+ t.delete().execute()
+
+ def tearDownAll(self):
+ clear_mappers()
+ metadata.drop_all()
+
+ def setUpAll(self):
+ global ctx, data, metadata, User, Pref
+ ctx = SessionContext(create_session)
+ metadata = BoundMetaData(testbase.db)
+ prefs = Table('prefs', metadata,
+ Column('prefs_id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True),
+ Column('prefs_data', String(40)))
+
+ users = Table('users', metadata,
+ Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+ Column('user_name', String(40)),
+ Column('pref_id', Integer, ForeignKey('prefs.prefs_id'))
+ )
+ class User(object):
+ pass
+ class Pref(object):
+ pass
+ metadata.create_all()
+ mapper(User, users, properties = dict(
+ pref = relation(mapper(Pref, prefs), lazy=False, cascade="all, delete-orphan")
+ ))
+
+ def setUp(self):
+ global data
+ data = [User,
+ {'user_name' : 'ed',
+ 'pref' : (Pref, {'prefs_data' : 'pref 1'}),
+ },
+ {'user_name' : 'jack',
+ 'pref' : (Pref, {'prefs_data' : 'pref 2'}),
+ },
+ {'user_name' : 'foo',
+ 'pref' : (Pref, {'prefs_data' : 'pref 3'}),
+ }
+ ]
+
+ for elem in data[1:]:
+ u = User()
+ ctx.current.save(u)
+ u.user_name = elem['user_name']
+ u.pref = Pref()
+ u.pref.prefs_data = elem['pref'][1]['prefs_data']
+
+ ctx.current.flush()
+ ctx.current.clear()
+
+ def testorphan(self):
+ l = ctx.current.query(User).select()
+ jack = l[1]
+ jack.pref = None
+ ctx.current.flush()
+
+if __name__ == "__main__":
+ testbase.main()
diff --git a/test/cycles.py b/test/cycles.py
index 02d34bbb0..fb051f47a 100644
--- a/test/cycles.py
+++ b/test/cycles.py
@@ -22,26 +22,22 @@ class Tester(object):
class SelfReferentialTest(AssertMixin):
"""tests a self-referential mapper, with an additional list of child objects."""
def setUpAll(self):
- testbase.db.tables.clear()
- global t1
- global t2
- t1 = Table('t1', testbase.db,
+ global t1, t2, metadata
+ metadata = BoundMetaData(testbase.db)
+ t1 = Table('t1', metadata,
Column('c1', Integer, primary_key=True),
Column('parent_c1', Integer, ForeignKey('t1.c1')),
Column('data', String(20))
)
- t2 = Table('t2', testbase.db,
+ t2 = Table('t2', metadata,
Column('c1', Integer, primary_key=True),
Column('c1id', Integer, ForeignKey('t1.c1')),
Column('data', String(20))
)
- t1.create()
- t2.create()
+ metadata.create_all()
def tearDownAll(self):
- t2.drop()
- t1.drop()
+ metadata.drop_all()
def setUp(self):
- objectstore.clear()
clear_mappers()
def testsingle(self):
@@ -53,9 +49,11 @@ class SelfReferentialTest(AssertMixin):
})
a = C1('head c1')
a.c1s.append(C1('another c1'))
- objectstore.commit()
- objectstore.delete(a)
- objectstore.commit()
+ sess = create_session(echo_uow=False)
+ sess.save(a)
+ sess.flush()
+ sess.delete(a)
+ sess.flush()
def testcycle(self):
class C1(Tester):
@@ -75,36 +73,34 @@ class SelfReferentialTest(AssertMixin):
a.c1s[0].c1s.append(C1('subchild2'))
a.c1s[1].c2s.append(C2('child2 data1'))
a.c1s[1].c2s.append(C2('child2 data2'))
- objectstore.commit()
+ sess = create_session(echo_uow=False)
+ sess.save(a)
+ sess.flush()
- objectstore.delete(a)
- objectstore.commit()
+ sess.delete(a)
+ sess.flush()
class BiDirectionalOneToManyTest(AssertMixin):
"""tests two mappers with a one-to-many relation to each other."""
def setUpAll(self):
- testbase.db.tables.clear()
- global t1
- global t2
- t1 = Table('t1', testbase.db,
+ global t1, t2, metadata
+ metadata = BoundMetaData(testbase.db)
+ t1 = Table('t1', metadata,
Column('c1', Integer, primary_key=True),
Column('c2', Integer, ForeignKey('t2.c1'))
)
- t2 = Table('t2', testbase.db,
+ t2 = Table('t2', metadata,
Column('c1', Integer, primary_key=True),
Column('c2', Integer)
)
- t2.create()
- t1.create()
+ metadata.create_all()
t2.c.c2.append_item(ForeignKey('t1.c1'))
def tearDownAll(self):
- t1.drop()
+ t1.drop()
t2.drop()
- def setUp(self):
- objectstore.clear()
- #objectstore.LOG = True
+ #metadata.drop_all()
+ def tearDown(self):
clear_mappers()
-
def testcycle(self):
class C1(object):pass
class C2(object):pass
@@ -123,35 +119,33 @@ class BiDirectionalOneToManyTest(AssertMixin):
a.c2s.append(b)
d.c1s.append(c)
b.c1s.append(c)
- objectstore.commit()
+ sess = create_session()
+ [sess.save(x) for x in [a,b,c,d,e,f]]
+ sess.flush()
class BiDirectionalOneToManyTest2(AssertMixin):
"""tests two mappers with a one-to-many relation to each other, with a second one-to-many on one of the mappers"""
def setUpAll(self):
- testbase.db.tables.clear()
- global t1
- global t2
- global t3
- t1 = Table('t1', testbase.db,
+ global t1, t2, t3, metadata
+ metadata = BoundMetaData(testbase.db)
+ t1 = Table('t1', metadata,
Column('c1', Integer, primary_key=True),
Column('c2', Integer, ForeignKey('t2.c1')),
)
- t2 = Table('t2', testbase.db,
+ t2 = Table('t2', metadata,
Column('c1', Integer, primary_key=True),
Column('c2', Integer),
)
t2.create()
t1.create()
t2.c.c2.append_item(ForeignKey('t1.c1'))
- t3 = Table('t1_data', testbase.db,
+ t3 = Table('t1_data', metadata,
Column('c1', Integer, primary_key=True),
Column('t1id', Integer, ForeignKey('t1.c1')),
Column('data', String(20)))
t3.create()
- def setUp(self):
- objectstore.clear()
- #objectstore.LOG = True
+ def tearDown(self):
clear_mappers()
def tearDownAll(self):
@@ -185,25 +179,28 @@ class BiDirectionalOneToManyTest2(AssertMixin):
a.data.append(C1Data('c1data1'))
a.data.append(C1Data('c1data2'))
c.data.append(C1Data('c1data3'))
- objectstore.commit()
+ sess = create_session()
+ [sess.save(x) for x in [a,b,c,d,e,f]]
+ sess.flush()
- objectstore.delete(d)
- objectstore.delete(c)
- objectstore.commit()
+ sess.delete(d)
+ sess.delete(c)
+ sess.flush()
class OneToManyManyToOneTest(AssertMixin):
"""tests two mappers, one has a one-to-many on the other mapper, the other has a separate many-to-one relationship to the first.
two tests will have a row for each item that is dependent on the other. without the "post_update" flag, such relationships
raise an exception when dependencies are sorted."""
def setUpAll(self):
- testbase.db.tables.clear()
+ global metadata
+ metadata = BoundMetaData(testbase.db)
global person
global ball
- ball = Table('ball', db,
+ ball = Table('ball', metadata,
Column('id', Integer, Sequence('ball_id_seq', optional=True), primary_key=True),
Column('person_id', Integer),
)
- person = Table('person', db,
+ person = Table('person', metadata,
Column('id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
Column('favoriteBall_id', Integer, ForeignKey('ball.id')),
# Column('favoriteBall_id', Integer),
@@ -223,9 +220,7 @@ class OneToManyManyToOneTest(AssertMixin):
person.drop()
ball.drop()
- def setUp(self):
- objectstore.clear()
- #objectstore.LOG = True
+ def tearDown(self):
clear_mappers()
def testcycle(self):
@@ -249,7 +244,10 @@ class OneToManyManyToOneTest(AssertMixin):
b = Ball()
p = Person()
p.balls.append(b)
- objectstore.commit()
+ sess = create_session()
+ sess.save(b)
+ sess.save(b)
+ sess.flush()
def testpostupdate_m2o(self):
"""tests a cycle between two rows, with a post_update on the many-to-one"""
@@ -275,76 +273,79 @@ class OneToManyManyToOneTest(AssertMixin):
p.balls.append(Ball())
p.balls.append(Ball())
p.favorateBall = b
-
- self.assert_sql(db, lambda: objectstore.uow().commit(), [
+ sess = create_session()
+ sess.save(b)
+ sess.save(p)
+
+ self.assert_sql(db, lambda: sess.flush(), [
(
"INSERT INTO person (favoriteBall_id) VALUES (:favoriteBall_id)",
{'favoriteBall_id': None}
),
(
"INSERT INTO ball (person_id) VALUES (:person_id)",
- lambda:{'person_id':p.id}
+ lambda ctx:{'person_id':p.id}
),
(
"INSERT INTO ball (person_id) VALUES (:person_id)",
- lambda:{'person_id':p.id}
+ lambda ctx:{'person_id':p.id}
),
(
"INSERT INTO ball (person_id) VALUES (:person_id)",
- lambda:{'person_id':p.id}
+ lambda ctx:{'person_id':p.id}
),
(
"INSERT INTO ball (person_id) VALUES (:person_id)",
- lambda:{'person_id':p.id}
+ lambda ctx:{'person_id':p.id}
),
(
"UPDATE person SET favoriteBall_id=:favoriteBall_id WHERE person.id = :person_id",
- lambda:[{'favoriteBall_id':p.favorateBall.id,'person_id':p.id}]
+ lambda ctx:{'favoriteBall_id':p.favorateBall.id,'person_id':p.id}
)
],
with_sequences= [
(
"INSERT INTO person (id, favoriteBall_id) VALUES (:id, :favoriteBall_id)",
- lambda:{'id':db.last_inserted_ids()[0], 'favoriteBall_id': None}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favoriteBall_id': None}
),
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
),
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
),
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
),
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0],'person_id':p.id}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
),
# heres the post update
(
"UPDATE person SET favoriteBall_id=:favoriteBall_id WHERE person.id = :person_id",
- lambda:[{'favoriteBall_id':p.favorateBall.id,'person_id':p.id}]
+ lambda ctx:{'favoriteBall_id':p.favorateBall.id,'person_id':p.id}
)
])
- objectstore.delete(p)
- self.assert_sql(db, lambda: objectstore.uow().commit(), [
+ sess.delete(p)
+ self.assert_sql(db, lambda: sess.flush(), [
# heres the post update (which is a pre-update with deletes)
(
"UPDATE person SET favoriteBall_id=:favoriteBall_id WHERE person.id = :person_id",
- lambda:[{'person_id': p.id, 'favoriteBall_id': None}]
+ lambda ctx:{'person_id': p.id, 'favoriteBall_id': None}
),
(
"DELETE FROM ball WHERE ball.id = :id",
None
# order cant be predicted, but something like:
- #lambda:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]
+ #lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]
),
(
"DELETE FROM person WHERE person.id = :id",
- lambda:[{'id': p.id}]
+ lambda ctx:[{'id': p.id}]
)
@@ -377,8 +378,10 @@ class OneToManyManyToOneTest(AssertMixin):
b4 = Ball()
p.balls.append(b4)
p.favorateBall = b
-# objectstore.commit()
- self.assert_sql(db, lambda: objectstore.uow().commit(), [
+ sess = create_session()
+ [sess.save(x) for x in [b,p,b2,b3,b4]]
+
+ self.assert_sql(db, lambda: sess.flush(), [
(
"INSERT INTO ball (person_id) VALUES (:person_id)",
{'person_id':None}
@@ -397,92 +400,92 @@ class OneToManyManyToOneTest(AssertMixin):
),
(
"INSERT INTO person (favoriteBall_id) VALUES (:favoriteBall_id)",
- lambda:{'favoriteBall_id':b.id}
+ lambda ctx:{'favoriteBall_id':b.id}
),
# heres the post update on each one-to-many item
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b2.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b2.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b3.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b3.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b4.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b4.id}
),
],
with_sequences=[
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0], 'person_id':None}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
),
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0], 'person_id':None}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
),
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0], 'person_id':None}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
),
(
"INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
- lambda:{'id':db.last_inserted_ids()[0], 'person_id':None}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
),
(
"INSERT INTO person (id, favoriteBall_id) VALUES (:id, :favoriteBall_id)",
- lambda:{'id':db.last_inserted_ids()[0], 'favoriteBall_id':b.id}
+ lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favoriteBall_id':b.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b2.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b2.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b3.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b3.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id':p.id,'ball_id':b4.id}]
+ lambda ctx:{'person_id':p.id,'ball_id':b4.id}
),
])
- objectstore.delete(p)
- self.assert_sql(db, lambda: objectstore.uow().commit(), [
+ sess.delete(p)
+ self.assert_sql(db, lambda: sess.flush(), [
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id': None, 'ball_id': b.id}]
+ lambda ctx:{'person_id': None, 'ball_id': b.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id': None, 'ball_id': b2.id}]
+ lambda ctx:{'person_id': None, 'ball_id': b2.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id': None, 'ball_id': b3.id}]
+ lambda ctx:{'person_id': None, 'ball_id': b3.id}
),
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
- lambda:[{'person_id': None, 'ball_id': b4.id}]
+ lambda ctx:{'person_id': None, 'ball_id': b4.id}
),
(
"DELETE FROM person WHERE person.id = :id",
- lambda:[{'id':p.id}]
+ lambda ctx:[{'id':p.id}]
),
(
"DELETE FROM ball WHERE ball.id = :id",
None
# the order of deletion is not predictable, but its roughly:
- # lambda:[{'id': b.id}, {'id': b2.id}, {'id': b3.id}, {'id': b4.id}]
+ # lambda ctx:[{'id': b.id}, {'id': b2.id}, {'id': b3.id}, {'id': b4.id}]
)
])
diff --git a/test/defaults.py b/test/defaults.py
index 8d848f4c5..a271cbcb5 100644
--- a/test/defaults.py
+++ b/test/defaults.py
@@ -7,7 +7,7 @@ from sqlalchemy import *
import sqlalchemy
db = testbase.db
-testbase.echo=False
+
class DefaultTest(PersistTest):
def setUpAll(self):
@@ -20,15 +20,17 @@ class DefaultTest(PersistTest):
use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
is_oracle = db.engine.name == 'oracle'
- # select "count(1)" from the DB which returns different results
- # on different DBs
- currenttime = db.func.current_date(type=Date);
+ # select "count(1)" returns different results on different DBs
+ # also correct for "current_date" compatible as column default, value differences
+ currenttime = func.current_date(type=Date, engine=db);
if is_oracle:
- ts = db.func.sysdate().scalar()
+ ts = db.func.trunc(func.sysdate(), column("'DAY'")).scalar()
f = select([func.count(1) + 5], engine=db).scalar()
f2 = select([func.count(1) + 14], engine=db).scalar()
+ # TODO: engine propigation across nested functions not working
+ currenttime = func.trunc(currenttime, column("'DAY'"), engine=db)
def1 = currenttime
- def2 = text("sysdate")
+ def2 = func.trunc(text("sysdate"), column("'DAY'"))
deftype = Date
elif use_function_defaults:
f = select([func.count(1) + 5], engine=db).scalar()
@@ -72,9 +74,10 @@ class DefaultTest(PersistTest):
t.delete().execute()
def teststandalone(self):
- x = t.c.col1.default.execute()
+ c = db.engine.contextual_connect()
+ x = c.execute(t.c.col1.default)
y = t.c.col2.default.execute()
- z = t.c.col3.default.execute()
+ z = c.execute(t.c.col3.default)
self.assert_(50 <= x <= 57)
self.assert_(y == 'imthedefault')
self.assert_(z == f)
@@ -82,8 +85,8 @@ class DefaultTest(PersistTest):
self.assert_(5 <= z <= 6)
def testinsert(self):
- t.insert().execute()
- self.assert_(t.engine.lastrow_has_defaults())
+ r = t.insert().execute()
+ self.assert_(r.lastrow_has_defaults())
t.insert().execute()
t.insert().execute()
@@ -99,8 +102,8 @@ class DefaultTest(PersistTest):
def testupdate(self):
- t.insert().execute()
- pk = t.engine.last_inserted_ids()[0]
+ r = t.insert().execute()
+ pk = r.last_inserted_ids()[0]
t.update(t.c.col1==pk).execute(col4=None, col5=None)
ctexec = currenttime.scalar()
self.echo("Currenttime "+ repr(ctexec))
@@ -111,8 +114,8 @@ class DefaultTest(PersistTest):
self.assert_(14 <= f2 <= 15)
def testupdatevalues(self):
- t.insert().execute()
- pk = t.engine.last_inserted_ids()[0]
+ r = t.insert().execute()
+ pk = r.last_inserted_ids()[0]
t.update(t.c.col1==pk, values={'col3': 55}).execute()
l = t.select(t.c.col1==pk).execute()
l = l.fetchone()
@@ -143,10 +146,10 @@ class SequenceTest(PersistTest):
@testbase.supported('postgres', 'oracle')
def teststandalone(self):
- s = Sequence("my_sequence", engine=db)
+ s = Sequence("my_sequence", metadata=testbase.db)
s.create()
try:
- x =s.execute()
+ x = s.execute()
self.assert_(x == 1)
finally:
s.drop()
diff --git a/test/dependency.py b/test/dependency.py
index 81165dc6d..d2b5bd698 100644
--- a/test/dependency.py
+++ b/test/dependency.py
@@ -1,5 +1,5 @@
from testbase import PersistTest
-import sqlalchemy.mapping.topological as topological
+import sqlalchemy.orm.topological as topological
import unittest, sys, os
@@ -17,26 +17,6 @@ class thingy(object):
return repr(self)
class DependencySortTest(PersistTest):
-
- def _assert_sort(self, tuples, allnodes, **kwargs):
-
- head = DependencySorter(tuples, allnodes).sort(**kwargs)
-
- print "\n" + str(head)
- def findnode(t, n, parent=False):
- if n.item is t[0] or (n.cycles is not None and t[0] in [c.item for c in n.cycles]):
- parent=True
- elif n.item is t[1]:
- if not parent and (n.cycles is None or t[0] not in [c.item for c in n.cycles]):
- self.assert_(False, "Node " + str(t[1]) + " not a child of " +str(t[0]))
- else:
- return
- for c in n.children:
- findnode(t, c, parent)
-
- for t in tuples:
- findnode(t, head)
-
def testsort(self):
rootnode = thingy('root')
node2 = thingy('node2')
@@ -47,7 +27,6 @@ class DependencySortTest(PersistTest):
subnode3 = thingy('subnode3')
subnode4 = thingy('subnode4')
subsubnode1 = thingy('subsubnode1')
- allnodes = [rootnode, node2,node3,node4,subnode1,subnode2,subnode3,subnode4,subsubnode1]
tuples = [
(subnode3, subsubnode1),
(node2, subnode1),
@@ -58,8 +37,8 @@ class DependencySortTest(PersistTest):
(node4, subnode3),
(node4, subnode4)
]
-
- self._assert_sort(tuples, allnodes)
+ head = DependencySorter(tuples, []).sort()
+ print "\n" + str(head)
def testsort2(self):
node1 = thingy('node1')
@@ -76,7 +55,8 @@ class DependencySortTest(PersistTest):
(node5, node6),
(node6, node2)
]
- self._assert_sort(tuples, [node1,node2,node3,node4,node5,node6,node7])
+ head = DependencySorter(tuples, [node7]).sort()
+ print "\n" + str(head)
def testsort3(self):
['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords']
@@ -88,10 +68,15 @@ class DependencySortTest(PersistTest):
(node3, node2),
(node1,node3)
]
- self._assert_sort(tuples, [node1, node2, node3])
- self._assert_sort(tuples, [node3, node1, node2])
- self._assert_sort(tuples, [node3, node2, node1])
+ head1 = DependencySorter(tuples, [node1, node2, node3]).sort()
+ head2 = DependencySorter(tuples, [node3, node1, node2]).sort()
+ head3 = DependencySorter(tuples, [node3, node2, node1]).sort()
+ # TODO: figure out a "node == node2" function
+ #self.assert_(str(head1) == str(head2) == str(head3))
+ print "\n" + str(head1)
+ print "\n" + str(head2)
+ print "\n" + str(head3)
def testsort4(self):
node1 = thingy('keywords')
@@ -104,7 +89,8 @@ class DependencySortTest(PersistTest):
(node1, node3),
(node3, node2)
]
- self._assert_sort(tuples, [node1,node2,node3,node4])
+ head = DependencySorter(tuples, []).sort()
+ print "\n" + str(head)
def testsort5(self):
# this one, depenending on the weather,
@@ -131,24 +117,10 @@ class DependencySortTest(PersistTest):
node3,
node4
]
- self._assert_sort(tuples, allitems)
-
- def testsort6(self):
- #('tbl_c', 'tbl_d'), ('tbl_a', 'tbl_c'), ('tbl_b', 'tbl_d')
- nodea = thingy('tbl_a')
- nodeb = thingy('tbl_b')
- nodec = thingy('tbl_c')
- noded = thingy('tbl_d')
- tuples = [
- (nodec, noded),
- (nodea, nodec),
- (nodeb, noded)
- ]
- allitems = [nodea,nodeb,nodec,noded]
- self._assert_sort(tuples, allitems)
+ head = DependencySorter(tuples, allitems).sort()
+ print "\n" + str(head)
def testcircular(self):
- #print "TESTCIRCULAR"
node1 = thingy('node1')
node2 = thingy('node2')
node3 = thingy('node3')
@@ -162,8 +134,8 @@ class DependencySortTest(PersistTest):
(node3, node1),
(node4, node1)
]
- self._assert_sort(tuples, [node1,node2,node3,node4,node5], allow_all_cycles=True)
- #print "TESTCIRCULAR DONE"
+ head = DependencySorter(tuples, []).sort(allow_all_cycles=True)
+ print "\n" + str(head)
if __name__ == "__main__":
diff --git a/test/eagertest1.py b/test/eagertest1.py
index 5897e4016..9765379f4 100644
--- a/test/eagertest1.py
+++ b/test/eagertest1.py
@@ -7,65 +7,60 @@ import datetime
class EagerTest(AssertMixin):
def setUpAll(self):
global designType, design, part, inheritedPart
-
- designType = Table('design_types', testbase.db,
+ designType = Table('design_types', testbase.metadata,
Column('design_type_id', Integer, primary_key=True),
)
- design =Table('design', testbase.db,
+ design =Table('design', testbase.metadata,
Column('design_id', Integer, primary_key=True),
Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
- part = Table('parts', testbase.db,
+ part = Table('parts', testbase.metadata,
Column('part_id', Integer, primary_key=True),
Column('design_id', Integer, ForeignKey('design.design_id')),
Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
- inheritedPart = Table('inherited_part', testbase.db,
+ inheritedPart = Table('inherited_part', testbase.metadata,
Column('ip_id', Integer, primary_key=True),
Column('part_id', Integer, ForeignKey('parts.part_id')),
Column('design_id', Integer, ForeignKey('design.design_id')),
)
- designType.create()
- design.create()
- part.create()
- inheritedPart.create()
+ testbase.metadata.create_all()
def tearDownAll(self):
- inheritedPart.drop()
- part.drop()
- design.drop()
- designType.drop()
-
+ testbase.metadata.drop_all()
+ testbase.metadata.clear()
def testone(self):
class Part(object):pass
class Design(object):pass
class DesignType(object):pass
class InheritedPart(object):pass
- assign_mapper(Part, part)
+ mapper(Part, part)
- assign_mapper(InheritedPart, inheritedPart, properties=dict(
+ mapper(InheritedPart, inheritedPart, properties=dict(
part=relation(Part, lazy=False)
))
- assign_mapper(Design, design, properties=dict(
+ mapper(Design, design, properties=dict(
parts=relation(Part, private=True, backref="design"),
inheritedParts=relation(InheritedPart, private=True, backref="design"),
))
- assign_mapper(DesignType, designType, properties=dict(
+ mapper(DesignType, designType, properties=dict(
# designs=relation(Design, private=True, backref="type"),
))
- Design.mapper.add_property("type", relation(DesignType, lazy=False, backref="designs"))
- Part.mapper.add_property("design", relation(Design, lazy=False, backref="parts"))
+ class_mapper(Design).add_property("type", relation(DesignType, lazy=False, backref="designs"))
+ class_mapper(Part).add_property("design", relation(Design, lazy=False, backref="parts"))
#Part.mapper.add_property("designType", relation(DesignType))
d = Design()
- objectstore.commit()
- objectstore.clear()
- x = Design.get(1)
+ sess = create_session()
+ sess.save(d)
+ sess.flush()
+ sess.clear()
+ x = sess.query(Design).get(1)
x.inheritedParts
if __name__ == "__main__":
diff --git a/test/eagertest2.py b/test/eagertest2.py
index 430e12ba7..ef385df16 100644
--- a/test/eagertest2.py
+++ b/test/eagertest2.py
@@ -3,70 +3,56 @@ import testbase
import unittest, sys, os
from sqlalchemy import *
import datetime
-
-db = testbase.db
+from sqlalchemy.ext.sessioncontext import SessionContext
class EagerTest(AssertMixin):
def setUpAll(self):
- objectstore.clear()
- clear_mappers()
- testbase.db.tables.clear()
-
- global companies_table, addresses_table, invoice_table, phones_table, items_table
+ global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx, metadata
- companies_table = Table('companies', db,
+ metadata = BoundMetaData(testbase.db)
+ ctx = SessionContext(create_session)
+
+ companies_table = Table('companies', metadata,
Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True),
Column('company_name', String(40)),
)
- addresses_table = Table('addresses', db,
+ addresses_table = Table('addresses', metadata,
Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
Column('company_id', Integer, ForeignKey("companies.company_id")),
Column('address', String(40)),
)
- phones_table = Table('phone_numbers', db,
+ phones_table = Table('phone_numbers', metadata,
Column('phone_id', Integer, Sequence('phone_id_seq', optional=True), primary_key = True),
Column('address_id', Integer, ForeignKey('addresses.address_id')),
Column('type', String(20)),
Column('number', String(10)),
)
- invoice_table = Table('invoices', db,
+ invoice_table = Table('invoices', metadata,
Column('invoice_id', Integer, Sequence('invoice_id_seq', optional=True), primary_key = True),
Column('company_id', Integer, ForeignKey("companies.company_id")),
Column('date', DateTime),
)
- items_table = Table('items', db,
+ items_table = Table('items', metadata,
Column('item_id', Integer, Sequence('item_id_seq', optional=True), primary_key = True),
Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')),
Column('code', String(20)),
Column('qty', Integer),
)
- companies_table.create()
- addresses_table.create()
- phones_table.create()
- invoice_table.create()
- items_table.create()
+ metadata.create_all()
def tearDownAll(self):
- items_table.drop()
- invoice_table.drop()
- phones_table.drop()
- addresses_table.drop()
- companies_table.drop()
+ metadata.drop_all()
def tearDown(self):
- objectstore.clear()
clear_mappers()
- items_table.delete().execute()
- invoice_table.delete().execute()
- phones_table.delete().execute()
- addresses_table.delete().execute()
- companies_table.delete().execute()
+ for t in metadata.table_iterator(reverse=True):
+ t.delete().execute()
def testone(self):
"""tests eager load of a many-to-one attached to a one-to-many. this testcase illustrated
@@ -88,14 +74,14 @@ class EagerTest(AssertMixin):
def __repr__(self):
return "Invoice:" + repr(getattr(self, 'invoice_id', None)) + " " + repr(getattr(self, 'date', None)) + " " + repr(self.company)
- Address.mapper = mapper(Address, addresses_table, properties={
- })
- Company.mapper = mapper(Company, companies_table, properties={
- 'addresses' : relation(Address.mapper, lazy=False),
- })
- Invoice.mapper = mapper(Invoice, invoice_table, properties={
- 'company': relation(Company.mapper, lazy=False, )
- })
+ mapper(Address, addresses_table, properties={
+ }, extension=ctx.mapper_extension)
+ mapper(Company, companies_table, properties={
+ 'addresses' : relation(Address, lazy=False),
+ }, extension=ctx.mapper_extension)
+ mapper(Invoice, invoice_table, properties={
+ 'company': relation(Company, lazy=False, )
+ }, extension=ctx.mapper_extension)
c1 = Company()
c1.company_name = 'company 1'
@@ -109,19 +95,18 @@ class EagerTest(AssertMixin):
i1.date = datetime.datetime.now()
i1.company = c1
-
- objectstore.commit()
+ ctx.current.flush()
company_id = c1.company_id
invoice_id = i1.invoice_id
- objectstore.clear()
+ ctx.current.clear()
- c = Company.mapper.get(company_id)
+ c = ctx.current.query(Company).get(company_id)
- objectstore.clear()
+ ctx.current.clear()
- i = Invoice.mapper.get(invoice_id)
+ i = ctx.current.query(Invoice).get(invoice_id)
self.echo(repr(c))
self.echo(repr(i.company))
@@ -153,24 +138,24 @@ class EagerTest(AssertMixin):
def __repr__(self):
return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty)
- Phone.mapper = mapper(Phone, phones_table, is_primary=True)
+ mapper(Phone, phones_table, extension=ctx.mapper_extension)
- Address.mapper = mapper(Address, addresses_table, properties={
- 'phones': relation(Phone.mapper, lazy=False, backref='address')
- })
+ mapper(Address, addresses_table, properties={
+ 'phones': relation(Phone, lazy=False, backref='address')
+ }, extension=ctx.mapper_extension)
- Company.mapper = mapper(Company, companies_table, properties={
- 'addresses' : relation(Address.mapper, lazy=False, backref='company'),
- })
+ mapper(Company, companies_table, properties={
+ 'addresses' : relation(Address, lazy=False, backref='company'),
+ }, extension=ctx.mapper_extension)
- Item.mapper = mapper(Item, items_table, is_primary=True)
+ mapper(Item, items_table, extension=ctx.mapper_extension)
- Invoice.mapper = mapper(Invoice, invoice_table, properties={
- 'items': relation(Item.mapper, lazy=False, backref='invoice'),
- 'company': relation(Company.mapper, lazy=False, backref='invoices')
- })
+ mapper(Invoice, invoice_table, properties={
+ 'items': relation(Item, lazy=False, backref='invoice'),
+ 'company': relation(Company, lazy=False, backref='invoices')
+ }, extension=ctx.mapper_extension)
- objectstore.clear()
+ ctx.current.clear()
c1 = Company()
c1.company_name = 'company 1'
@@ -205,13 +190,13 @@ class EagerTest(AssertMixin):
c1.addresses.append(a2)
- objectstore.commit()
+ ctx.current.flush()
company_id = c1.company_id
- objectstore.clear()
+ ctx.current.clear()
- a = Company.mapper.get(company_id)
+ a = ctx.current.query(Company).get(company_id)
self.echo(repr(a))
# set up an invoice
@@ -234,18 +219,18 @@ class EagerTest(AssertMixin):
item3.qty = 3
item3.invoice = i1
- objectstore.commit()
+ ctx.current.flush()
invoice_id = i1.invoice_id
- objectstore.clear()
+ ctx.current.clear()
- c = Company.mapper.get(company_id)
+ c = ctx.current.query(Company).get(company_id)
self.echo(repr(c))
- objectstore.clear()
+ ctx.current.clear()
- i = Invoice.mapper.get(invoice_id)
+ i = ctx.current.query(Invoice).get(invoice_id)
self.echo(repr(i))
self.assert_(repr(i.company) == repr(c))
diff --git a/test/engine.py b/test/engine.py
deleted file mode 100644
index 193201fa1..000000000
--- a/test/engine.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from sqlalchemy import *
-
-from testbase import PersistTest
-import testbase
-import unittest, re
-import tables
-
-class TransactionTest(PersistTest):
- def setUpAll(self):
- tables.create()
- def tearDownAll(self):
- tables.drop()
- def tearDown(self):
- tables.delete()
-
- def testbasic(self):
- testbase.db.begin()
- tables.users.insert().execute(user_name='jack')
- tables.users.insert().execute(user_name='fred')
- testbase.db.commit()
- l = tables.users.select().execute().fetchall()
- print l
- self.assert_(len(l) == 2)
-
- def testrollback(self):
- testbase.db.begin()
- tables.users.insert().execute(user_name='jack')
- tables.users.insert().execute(user_name='fred')
- testbase.db.rollback()
- l = tables.users.select().execute().fetchall()
- print l
- self.assert_(len(l) == 0)
-
- @testbase.unsupported('sqlite')
- def testnested(self):
- """tests nested sessions. SQLite should raise an error."""
- testbase.db.begin()
- tables.users.insert().execute(user_name='jack')
- tables.users.insert().execute(user_name='fred')
- testbase.db.push_session()
- tables.users.insert().execute(user_name='ed')
- tables.users.insert().execute(user_name='wendy')
- testbase.db.pop_session()
- testbase.db.rollback()
- l = tables.users.select().execute().fetchall()
- print l
- self.assert_(len(l) == 2)
-
- def testtwo(self):
- testbase.db.begin()
- tables.users.insert().execute(user_name='jack')
- tables.users.insert().execute(user_name='fred')
- testbase.db.commit()
- testbase.db.begin()
- tables.users.insert().execute(user_name='ed')
- tables.users.insert().execute(user_name='wendy')
- testbase.db.commit()
- testbase.db.rollback()
- l = tables.users.select().execute().fetchall()
- print l
- self.assert_(len(l) == 4)
-
-if __name__ == "__main__":
- testbase.main()
diff --git a/test/entity.py b/test/entity.py
index 591cff7ea..22e74f171 100644
--- a/test/entity.py
+++ b/test/entity.py
@@ -2,6 +2,7 @@ from testbase import PersistTest, AssertMixin
import unittest
from sqlalchemy import *
import testbase
+from sqlalchemy.ext.sessioncontext import SessionContext
from tables import *
import tables
@@ -10,38 +11,35 @@ class EntityTest(AssertMixin):
"""tests mappers that are constructed based on "entity names", which allows the same class
to have multiple primary mappers """
def setUpAll(self):
- global user1, user2, address1, address2
- db = testbase.db
- user1 = Table('user1', db,
+ global user1, user2, address1, address2, metadata, ctx
+ metadata = BoundMetaData(testbase.db)
+ ctx = SessionContext(create_session)
+
+ user1 = Table('user1', metadata,
Column('user_id', Integer, Sequence('user1_id_seq'), primary_key=True),
Column('name', String(60), nullable=False)
- ).create()
- user2 = Table('user2', db,
+ )
+ user2 = Table('user2', metadata,
Column('user_id', Integer, Sequence('user2_id_seq'), primary_key=True),
Column('name', String(60), nullable=False)
- ).create()
- address1 = Table('address1', db,
+ )
+ address1 = Table('address1', metadata,
Column('address_id', Integer, Sequence('address1_id_seq'), primary_key=True),
Column('user_id', Integer, ForeignKey(user1.c.user_id), nullable=False),
Column('email', String(100), nullable=False)
- ).create()
- address2 = Table('address2', db,
+ )
+ address2 = Table('address2', metadata,
Column('address_id', Integer, Sequence('address2_id_seq'), primary_key=True),
Column('user_id', Integer, ForeignKey(user2.c.user_id), nullable=False),
Column('email', String(100), nullable=False)
- ).create()
+ )
+ metadata.create_all()
def tearDownAll(self):
- address1.drop()
- address2.drop()
- user1.drop()
- user2.drop()
+ metadata.drop_all()
def tearDown(self):
- address1.delete().execute()
- address2.delete().execute()
- user1.delete().execute()
- user2.delete().execute()
- objectstore.clear()
clear_mappers()
+ for t in metadata.table_iterator(reverse=True):
+ t.delete().execute()
def testbasic(self):
"""tests a pair of one-to-many mapper structures, establishing that both
@@ -50,14 +48,14 @@ class EntityTest(AssertMixin):
class User(object):pass
class Address(object):pass
- a1mapper = mapper(Address, address1, entity_name='address1')
- a2mapper = mapper(Address, address2, entity_name='address2')
+ a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension)
+ a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- })
+ }, extension=ctx.mapper_extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- })
+ }, extension=ctx.mapper_extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
@@ -71,15 +69,15 @@ class EntityTest(AssertMixin):
a2.email='a2@foo.com'
u2.addresses.append(a2)
- objectstore.commit()
+ ctx.current.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')]
- objectstore.clear()
- u1list = u1mapper.select()
- u2list = u2mapper.select()
+ ctx.current.clear()
+ u1list = ctx.current.query(User, entity_name='user1').select()
+ u2list = ctx.current.query(User, entity_name='user2').select()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
@@ -90,14 +88,14 @@ class EntityTest(AssertMixin):
class Address1(object):pass
class Address2(object):pass
- a1mapper = mapper(Address1, address1)
- a2mapper = mapper(Address2, address2)
+ a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension)
+ a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension)
u1mapper = mapper(User, user1, entity_name='user1', properties ={
'addresses':relation(a1mapper)
- })
+ }, extension=ctx.mapper_extension)
u2mapper =mapper(User, user2, entity_name='user2', properties={
'addresses':relation(a2mapper)
- })
+ }, extension=ctx.mapper_extension)
u1 = User(_sa_entity_name='user1')
u1.name = 'this is user 1'
@@ -111,15 +109,15 @@ class EntityTest(AssertMixin):
a2.email='a2@foo.com'
u2.addresses.append(a2)
- objectstore.commit()
+ ctx.current.flush()
assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')]
assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')]
- objectstore.clear()
- u1list = u1mapper.select()
- u2list = u2mapper.select()
+ ctx.current.clear()
+ u1list = ctx.current.query(User, entity_name='user1').select()
+ u2list = ctx.current.query(User, entity_name='user2').select()
assert len(u1list) == len(u2list) == 1
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
diff --git a/test/indexes.py b/test/indexes.py
index fbf1a2c81..a111b34e5 100644
--- a/test/indexes.py
+++ b/test/indexes.py
@@ -5,59 +5,51 @@ import testbase
class IndexTest(testbase.AssertMixin):
def setUp(self):
- self.created = []
+ global metadata
+ metadata = BoundMetaData(testbase.db)
self.echo = testbase.db.echo
self.logger = testbase.db.logger
def tearDown(self):
testbase.db.echo = self.echo
testbase.db.logger = testbase.db.engine.logger = self.logger
- if self.created:
- self.created.reverse()
- for entity in self.created:
- entity.drop()
+ metadata.drop_all()
def test_index_create(self):
- employees = Table('employees', testbase.db,
+ employees = Table('employees', metadata,
Column('id', Integer, primary_key=True),
Column('first_name', String(30)),
Column('last_name', String(30)),
Column('email_address', String(30)))
employees.create()
- self.created.append(employees)
i = Index('employee_name_index',
employees.c.last_name, employees.c.first_name)
i.create()
- self.created.append(i)
assert employees.indexes['employee_name_index'] is i
i2 = Index('employee_email_index',
employees.c.email_address, unique=True)
i2.create()
- self.created.append(i2)
assert employees.indexes['employee_email_index'] is i2
def test_index_create_camelcase(self):
"""test that mixed-case index identifiers are legal"""
- employees = Table('companyEmployees', testbase.db,
+ employees = Table('companyEmployees', metadata,
Column('id', Integer, primary_key=True),
Column('firstName', String(30)),
Column('lastName', String(30)),
Column('emailAddress', String(30)))
employees.create()
- self.created.append(employees)
i = Index('employeeNameIndex',
employees.c.lastName, employees.c.firstName)
i.create()
- self.created.append(i)
i = Index('employeeEmailIndex',
employees.c.emailAddress, unique=True)
i.create()
- self.created.append(i)
# Check that the table is useable. This is mostly for pg,
# which can be somewhat sticky with mixed-case identifiers
@@ -75,8 +67,7 @@ class IndexTest(testbase.AssertMixin):
stream = dummy()
stream.write = capt.append
testbase.db.logger = testbase.db.engine.logger = stream
-
- events = Table('events', testbase.db,
+ events = Table('events', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(30), unique=True),
Column('location', String(30), index=True),
@@ -94,22 +85,21 @@ class IndexTest(testbase.AssertMixin):
assert len(index_names) == 4
events.create()
- self.created.append(events)
# verify that the table is functional
events.insert().execute(id=1, name='hockey finals', location='rink',
sport='hockey', announcer='some canadian',
winner='sweden')
ss = events.select().execute().fetchall()
-
+
assert capt[0].strip().startswith('CREATE TABLE events')
- assert capt[2].strip() == \
+ assert capt[3].strip() == \
'CREATE UNIQUE INDEX ux_events_name ON events (name)'
- assert capt[4].strip() == \
- 'CREATE INDEX ix_events_location ON events (location)'
assert capt[6].strip() == \
+ 'CREATE INDEX ix_events_location ON events (location)'
+ assert capt[9].strip() == \
'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'
- assert capt[8].strip() == \
+ assert capt[12].strip() == \
'CREATE INDEX idx_winners ON events (winner)'
if __name__ == "__main__":
diff --git a/test/inheritance.py b/test/inheritance.py
index 8b683b8ff..265134173 100644
--- a/test/inheritance.py
+++ b/test/inheritance.py
@@ -1,11 +1,13 @@
-from sqlalchemy import *
import testbase
+from sqlalchemy import *
import string
import sqlalchemy.attributes as attr
import sys
class Principal( object ):
- pass
+ def __init__(self, **kwargs):
+ for key, value in kwargs.iteritems():
+ setattr(self, key, value)
class User( Principal ):
pass
@@ -20,16 +22,18 @@ class InheritTest(testbase.AssertMixin):
global users
global groups
global user_group_map
+ global metadata
+ metadata = BoundMetaData(testbase.db)
principals = Table(
'principals',
- testbase.db,
+ metadata,
Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True),
Column('name', String(50), nullable=False),
)
users = Table(
'prin_users',
- testbase.db,
+ metadata,
Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
Column('password', String(50), nullable=False),
Column('email', String(50), nullable=False),
@@ -39,14 +43,14 @@ class InheritTest(testbase.AssertMixin):
groups = Table(
'prin_groups',
- testbase.db,
+ metadata,
Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
)
user_group_map = Table(
'prin_user_group_map',
- testbase.db,
+ metadata,
Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ),
Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ),
#Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), ),
@@ -54,65 +58,57 @@ class InheritTest(testbase.AssertMixin):
)
- principals.create()
- users.create()
- groups.create()
- user_group_map.create()
+ metadata.create_all()
+
def tearDownAll(self):
- user_group_map.drop()
- groups.drop()
- users.drop()
- principals.drop()
- testbase.db.tables.clear()
+ metadata.drop_all()
+
def setUp(self):
- objectstore.clear()
clear_mappers()
def testbasic(self):
- assign_mapper( Principal, principals )
- assign_mapper(
+ mapper( Principal, principals )
+ mapper(
User,
users,
- inherits=Principal.mapper
+ inherits=Principal
)
- assign_mapper(
+ mapper(
Group,
groups,
- inherits=Principal.mapper,
- properties=dict( users = relation(User.mapper, user_group_map, lazy=True, backref="groups") )
+ inherits=Principal,
+ properties=dict( users = relation(User, secondary=user_group_map, lazy=True, backref="groups") )
)
g = Group(name="group1")
g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1"))
-
- objectstore.commit()
+ sess = create_session()
+ sess.save(g)
+ sess.flush()
# TODO: put an assertion
class InheritTest2(testbase.AssertMixin):
"""deals with inheritance and many-to-many relationships"""
def setUpAll(self):
- engine = testbase.db
- global foo, bar, foo_bar
- foo = Table('foo', engine,
+ global foo, bar, foo_bar, metadata
+ metadata = BoundMetaData(testbase.db)
+ foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
Column('data', String(20)),
).create()
- bar = Table('bar', engine,
+ bar = Table('bar', metadata,
Column('bid', Integer, ForeignKey('foo.id'), primary_key=True),
#Column('fid', Integer, ForeignKey('foo.id'), )
).create()
- foo_bar = Table('foo_bar', engine,
+ foo_bar = Table('foo_bar', metadata,
Column('foo_id', Integer, ForeignKey('foo.id')),
Column('bar_id', Integer, ForeignKey('bar.bid'))).create()
-
+ metadata.create_all()
def tearDownAll(self):
- foo_bar.drop()
- bar.drop()
- foo.drop()
- testbase.db.tables.clear()
+ metadata.drop_all()
def testbasic(self):
class Foo(object):
@@ -123,33 +119,28 @@ class InheritTest2(testbase.AssertMixin):
def __repr__(self):
return str(self)
- Foo.mapper = mapper(Foo, foo)
+ mapper(Foo, foo)
class Bar(Foo):
def __str__(self):
return "Bar(%s)" % self.data
- Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties = {
- # the old way, you needed to explicitly set up a compound
- # column like this. but now the mapper uses SyncRules to match up
- # the parent/child inherited columns
- #'id':[bar.c.bid, foo.c.id]
- })
-
- #Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, primaryjoin=bar.c.bid==foo_bar.c.bar_id, secondaryjoin=foo_bar.c.foo_id==foo.c.id, lazy=False))
- Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, lazy=False))
-
- b = Bar('barfoo')
- objectstore.commit()
+ mapper(Bar, bar, inherits=Foo, properties={
+ 'foos': relation(Foo, secondary=foo_bar, lazy=False)
+ })
+
+ sess = create_session()
+ b = Bar('barfoo', _sa_session=sess)
+ sess.flush()
f1 = Foo('subfoo1')
f2 = Foo('subfoo2')
b.foos.append(f1)
b.foos.append(f2)
- objectstore.commit()
- objectstore.clear()
+ sess.flush()
+ sess.clear()
- l =b.mapper.select()
+ l = sess.query(Bar).select()
print l[0]
print l[0].foos
self.assert_result(l, Bar,
@@ -160,44 +151,38 @@ class InheritTest2(testbase.AssertMixin):
class InheritTest3(testbase.AssertMixin):
"""deals with inheritance and many-to-many relationships"""
def setUpAll(self):
- engine = testbase.db
- global foo, bar, blub, bar_foo, blub_bar, blub_foo,tables
- engine.engine.echo = 'debug'
+ global foo, bar, blub, bar_foo, blub_bar, blub_foo,metadata
+ metadata = BoundMetaData(testbase.db)
# the 'data' columns are to appease SQLite which cant handle a blank INSERT
- foo = Table('foo', engine,
+ foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq'), primary_key=True),
Column('data', String(20)))
- bar = Table('bar', engine,
+ bar = Table('bar', metadata,
Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
Column('data', String(20)))
- blub = Table('blub', engine,
+ blub = Table('blub', metadata,
Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
Column('data', String(20)))
- bar_foo = Table('bar_foo', engine,
+ bar_foo = Table('bar_foo', metadata,
Column('bar_id', Integer, ForeignKey('bar.id')),
Column('foo_id', Integer, ForeignKey('foo.id')))
- blub_bar = Table('bar_blub', engine,
+ blub_bar = Table('bar_blub', metadata,
Column('blub_id', Integer, ForeignKey('blub.id')),
Column('bar_id', Integer, ForeignKey('bar.id')))
- blub_foo = Table('blub_foo', engine,
+ blub_foo = Table('blub_foo', metadata,
Column('blub_id', Integer, ForeignKey('blub.id')),
Column('foo_id', Integer, ForeignKey('foo.id')))
-
- tables = [foo, bar, blub, bar_foo, blub_bar, blub_foo]
- for table in tables:
- table.create()
+ metadata.create_all()
def tearDownAll(self):
- for table in reversed(tables):
- table.drop()
- testbase.db.tables.clear()
+ metadata.drop_all()
def tearDown(self):
- for table in reversed(tables):
+ for table in metadata.table_iterator():
table.delete().execute()
def testbasic(self):
@@ -206,24 +191,24 @@ class InheritTest3(testbase.AssertMixin):
self.data = data
def __repr__(self):
return "Foo id %d, data %s" % (self.id, self.data)
- Foo.mapper = mapper(Foo, foo)
+ mapper(Foo, foo)
class Bar(Foo):
def __repr__(self):
return "Bar id %d, data %s" % (self.id, self.data)
- Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties={
- #'foos' :relation(Foo.mapper, bar_foo, primaryjoin=bar.c.id==bar_foo.c.bar_id, lazy=False)
- 'foos' :relation(Foo.mapper, bar_foo, lazy=True)
+ mapper(Bar, bar, inherits=Foo, properties={
+ 'foos' :relation(Foo, secondary=bar_foo, lazy=True)
})
- b = Bar('bar #1')
+ sess = create_session()
+ b = Bar('bar #1', _sa_session=sess)
b.foos.append(Foo("foo #1"))
b.foos.append(Foo("foo #2"))
- objectstore.commit()
+ sess.flush()
compare = repr(b) + repr(b.foos)
- objectstore.clear()
- l = Bar.mapper.select()
+ sess.clear()
+ l = sess.query(Bar).select()
self.echo(repr(l[0]) + repr(l[0].foos))
self.assert_(repr(l[0]) + repr(l[0].foos) == compare)
@@ -233,83 +218,66 @@ class InheritTest3(testbase.AssertMixin):
self.data = data
def __repr__(self):
return "Foo id %d, data %s" % (self.id, self.data)
- Foo.mapper = mapper(Foo, foo)
+ mapper(Foo, foo)
class Bar(Foo):
def __repr__(self):
return "Bar id %d, data %s" % (self.id, self.data)
- Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper)
+ mapper(Bar, bar, inherits=Foo)
class Blub(Bar):
def __repr__(self):
return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos]))
- Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={
-# 'bars':relation(Bar.mapper, blub_bar, primaryjoin=blub.c.id==blub_bar.c.blub_id, lazy=False),
-# 'foos':relation(Foo.mapper, blub_foo, primaryjoin=blub.c.id==blub_foo.c.blub_id, lazy=False),
- 'bars':relation(Bar.mapper, blub_bar, lazy=False),
- 'foos':relation(Foo.mapper, blub_foo, lazy=False),
+ mapper(Blub, blub, inherits=Bar, properties={
+ 'bars':relation(Bar, secondary=blub_bar, lazy=False),
+ 'foos':relation(Foo, secondary=blub_foo, lazy=False),
})
- useobjects = True
- if (useobjects):
- f1 = Foo("foo #1")
- b1 = Bar("bar #1")
- b2 = Bar("bar #2")
- bl1 = Blub("blub #1")
- bl1.foos.append(f1)
- bl1.bars.append(b2)
- objectstore.commit()
- compare = repr(bl1)
- blubid = bl1.id
- objectstore.clear()
- else:
- foo.insert().execute(data='foo #1')
- foo.insert().execute(data='foo #2')
- bar.insert().execute(id=1, data="bar #1")
- bar.insert().execute(id=2, data="bar #2")
- blub.insert().execute(id=1, data="blub #1")
- blub_bar.insert().execute(blub_id=1, bar_id=2)
- blub_foo.insert().execute(blub_id=1, foo_id=2)
-
- l = Blub.mapper.select()
+ sess = create_session()
+ f1 = Foo("foo #1", _sa_session=sess)
+ b1 = Bar("bar #1", _sa_session=sess)
+ b2 = Bar("bar #2", _sa_session=sess)
+ bl1 = Blub("blub #1", _sa_session=sess)
+ bl1.foos.append(f1)
+ bl1.bars.append(b2)
+ sess.flush()
+ compare = repr(bl1)
+ blubid = bl1.id
+ sess.clear()
+
+ l = sess.query(Blub).select()
self.echo(l)
self.assert_(repr(l[0]) == compare)
- objectstore.clear()
- x = Blub.mapper.get_by(id=blubid) #traceback 2
+ sess.clear()
+ x = sess.query(Blub).get_by(id=blubid)
self.echo(x)
self.assert_(repr(x) == compare)
class InheritTest4(testbase.AssertMixin):
"""deals with inheritance and one-to-many relationships"""
def setUpAll(self):
- engine = testbase.db
- global foo, bar, blub, tables
- engine.engine.echo = 'debug'
+ global foo, bar, blub, metadata
+ metadata = BoundMetaData(testbase.db)
# the 'data' columns are to appease SQLite which cant handle a blank INSERT
- foo = Table('foo', engine,
+ foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq'), primary_key=True),
Column('data', String(20)))
- bar = Table('bar', engine,
+ bar = Table('bar', metadata,
Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
Column('data', String(20)))
- blub = Table('blub', engine,
+ blub = Table('blub', metadata,
Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
Column('data', String(20)))
-
- tables = [foo, bar, blub]
- for table in tables:
- table.create()
+ metadata.create_all()
def tearDownAll(self):
- for table in reversed(tables):
- table.drop()
- testbase.db.tables.clear()
+ metadata.drop_all()
def tearDown(self):
- for table in reversed(tables):
+ for table in metadata.table_iterator():
table.delete().execute()
def testbasic(self):
@@ -318,56 +286,55 @@ class InheritTest4(testbase.AssertMixin):
self.data = data
def __repr__(self):
return "Foo id %d, data %s" % (self.id, self.data)
- Foo.mapper = mapper(Foo, foo)
+ mapper(Foo, foo)
class Bar(Foo):
def __repr__(self):
return "Bar id %d, data %s" % (self.id, self.data)
- Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper)
+ mapper(Bar, bar, inherits=Foo)
class Blub(Bar):
def __repr__(self):
return "Blub id %d, data %s" % (self.id, self.data)
- Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={
- # bug was raised specifically based on the order of cols in the join....
-# 'parent_foo':relation(Foo.mapper, primaryjoin=blub.c.foo_id==foo.c.id)
-# 'parent_foo':relation(Foo.mapper, primaryjoin=foo.c.id==blub.c.foo_id)
- 'parent_foo':relation(Foo.mapper)
+ mapper(Blub, blub, inherits=Bar, properties={
+ 'parent_foo':relation(Foo)
})
- b1 = Blub("blub #1")
- b2 = Blub("blub #2")
- f = Foo("foo #1")
+ sess = create_session()
+ b1 = Blub("blub #1", _sa_session=sess)
+ b2 = Blub("blub #2", _sa_session=sess)
+ f = Foo("foo #1", _sa_session=sess)
b1.parent_foo = f
b2.parent_foo = f
- objectstore.commit()
+ sess.flush()
compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo)
- objectstore.clear()
- l = Blub.mapper.select()
+ sess.clear()
+ l = sess.query(Blub).select()
result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
self.echo(result)
self.assert_(compare == result)
self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
-class InheritTest5(testbase.AssertMixin):
+class InheritTest5(testbase.AssertMixin):
+ """testing that construction of inheriting mappers works regardless of when extra properties
+ are added to the superclass mapper"""
def setUpAll(self):
- engine = testbase.db
- global content_type, content, product
- content_type = Table('content_type', engine,
+ global content_type, content, product, metadata
+ metadata = BoundMetaData(testbase.db)
+ content_type = Table('content_type', metadata,
Column('id', Integer, primary_key=True)
)
- content = Table('content', engine,
+ content = Table('content', metadata,
Column('id', Integer, primary_key=True),
Column('content_type_id', Integer, ForeignKey('content_type.id'))
)
- product = Table('product', engine,
+ product = Table('product', metadata,
Column('id', Integer, ForeignKey('content.id'), primary_key=True)
)
def tearDownAll(self):
- testbase.db.tables.clear()
-
+ pass
def tearDown(self):
pass
@@ -384,14 +351,12 @@ class InheritTest5(testbase.AssertMixin):
# shouldnt throw exception
products = mapper(Product, product, inherits=contents)
-
def testbackref(self):
- """this test is currently known to fail in the 0.1 series of SQLAlchemy, pending the resolution of [ticket:154]"""
+ """tests adding a property to the superclass mapper"""
class ContentType(object): pass
class Content(object): pass
class Product(Content): pass
- # this test fails currently
contents = mapper(Content, content)
products = mapper(Product, product, inherits=contents)
content_types = mapper(ContentType, content_type, properties={
@@ -400,25 +365,24 @@ class InheritTest5(testbase.AssertMixin):
p = Product()
p.contenttype = ContentType()
-
class InheritTest6(testbase.AssertMixin):
"""tests eager load/lazy load of child items off inheritance mappers, tests that
LazyLoader constructs the right query condition."""
def setUpAll(self):
- global foo, bar, bar_foo
- foo = Table('foo', testbase.db, Column('id', Integer, Sequence('foo_seq'), primary_key=True),
- Column('data', String(30))).create()
- bar = Table('bar', testbase.db, Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
- Column('data', String(30))).create()
-
- bar_foo = Table('bar_foo', testbase.db,
+ global foo, bar, bar_foo, metadata
+ metadata=BoundMetaData(testbase.db)
+ foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+ Column('data', String(30)))
+ bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+ Column('data', String(30)))
+
+ bar_foo = Table('bar_foo', metadata,
Column('bar_id', Integer, ForeignKey('bar.id')),
Column('foo_id', Integer, ForeignKey('foo.id'))
- ).create()
+ )
+ metadata.create_all()
def tearDownAll(self):
- bar_foo.drop()
- bar.drop()
- foo.drop()
+ metadata.drop_all()
def testbasic(self):
class Foo(object): pass
@@ -427,6 +391,7 @@ class InheritTest6(testbase.AssertMixin):
foos = mapper(Foo, foo)
bars = mapper(Bar, bar, inherits=foos)
bars.add_property('lazy', relation(foos, bar_foo, lazy=True))
+ print bars.props['lazy'].primaryjoin, bars.props['lazy'].secondaryjoin
bars.add_property('eager', relation(foos, bar_foo, lazy=False))
foo.insert().execute(data='foo1')
@@ -440,9 +405,11 @@ class InheritTest6(testbase.AssertMixin):
bar_foo.insert().execute(bar_id=1, foo_id=3)
bar_foo.insert().execute(bar_id=2, foo_id=4)
-
- self.assert_(len(bars.selectfirst().lazy) == 1)
- self.assert_(len(bars.selectfirst().eager) == 1)
+
+ sess = create_session()
+ q = sess.query(Bar)
+ self.assert_(len(q.selectfirst().lazy) == 1)
+ self.assert_(len(q.selectfirst().eager) == 1)
if __name__ == "__main__":
testbase.main()
diff --git a/test/lazytest1.py b/test/lazytest1.py
index 986f067aa..eb1310d66 100644
--- a/test/lazytest1.py
+++ b/test/lazytest1.py
@@ -6,28 +6,25 @@ import datetime
class LazyTest(AssertMixin):
def setUpAll(self):
- global info_table, data_table, rel_table
- engine = testbase.db
- info_table = Table('infos', engine,
+ global info_table, data_table, rel_table, metadata
+ metadata = BoundMetaData(testbase.db)
+ info_table = Table('infos', metadata,
Column('pk', Integer, primary_key=True),
Column('info', String))
- data_table = Table('data', engine,
+ data_table = Table('data', metadata,
Column('data_pk', Integer, primary_key=True),
Column('info_pk', Integer, ForeignKey(info_table.c.pk)),
Column('timeval', Integer),
Column('data_val', String))
- rel_table = Table('rels', engine,
+ rel_table = Table('rels', metadata,
Column('rel_pk', Integer, primary_key=True),
Column('info_pk', Integer, ForeignKey(info_table.c.pk)),
Column('start', Integer),
Column('finish', Integer))
-
- info_table.create()
- rel_table.create()
- data_table.create()
+ metadata.create_all()
info_table.insert().execute(
{'pk':1, 'info':'pk_1_info'},
{'pk':2, 'info':'pk_2_info'},
@@ -52,13 +49,8 @@ class LazyTest(AssertMixin):
def tearDownAll(self):
- data_table.drop()
- rel_table.drop()
- info_table.drop()
+ metadata.drop_all()
- def setUp(self):
- clear_mappers()
-
def testone(self):
"""tests a lazy load which has multiple join conditions, including two that are against
the same column in the child table"""
@@ -71,58 +63,28 @@ class LazyTest(AssertMixin):
class Data(object):
pass
- # Create the basic mappers, with no frills or modifications
- Information.mapper = mapper(Information, info_table)
- Data.mapper = mapper(Data, data_table)
- Relation.mapper = mapper(Relation, rel_table)
-
- Relation.mapper.add_property('datas', relation(Data.mapper,
- primaryjoin=and_(Relation.c.info_pk==Data.c.info_pk,
- Data.c.timeval >= Relation.c.start,
- Data.c.timeval <= Relation.c.finish
- ),
- foreignkey=Data.c.info_pk))
-
- Information.mapper.add_property('rels', relation(Relation.mapper))
-
- info = Information.mapper.get(1)
- assert info
- assert len(info.rels) == 2
- assert len(info.rels[0].datas) == 3
-
- def testtwo(self):
- """same thing, but reversing the order of the cols in the join"""
- class Information(object):
- pass
-
- class Relation(object):
- pass
-
- class Data(object):
- pass
-
- # Create the basic mappers, with no frills or modifications
- Information.mapper = mapper(Information, info_table)
- Data.mapper = mapper(Data, data_table)
- Relation.mapper = mapper(Relation, rel_table)
-
- Relation.mapper.add_property('datas', relation(Data.mapper,
- primaryjoin=and_(Relation.c.info_pk==Data.c.info_pk,
- Relation.c.start <= Data.c.timeval,
- Relation.c.finish >= Data.c.timeval,
- # Data.c.timeval >= Relation.c.start,
- # Data.c.timeval <= Relation.c.finish
- ),
- foreignkey=Data.c.info_pk))
-
- Information.mapper.add_property('rels', relation(Relation.mapper))
-
- info = Information.mapper.get(1)
+ session = create_session()
+
+ mapper(Data, data_table)
+ mapper(Relation, rel_table, properties={
+
+ 'datas': relation(Data,
+ primaryjoin=and_(rel_table.c.info_pk==Data.c.info_pk,
+ Data.c.timeval >= rel_table.c.start,
+ Data.c.timeval <= rel_table.c.finish),
+ foreignkey=Data.c.info_pk)
+ }
+
+ )
+ mapper(Information, info_table, properties={
+ 'rels': relation(Relation)
+ })
+
+ info = session.query(Information).get(1)
assert info
assert len(info.rels) == 2
assert len(info.rels[0].datas) == 3
-
if __name__ == "__main__":
testbase.main()
diff --git a/test/legacy_objectstore.py b/test/legacy_objectstore.py
new file mode 100644
index 000000000..3aa99a1ae
--- /dev/null
+++ b/test/legacy_objectstore.py
@@ -0,0 +1,113 @@
+from testbase import PersistTest, AssertMixin
+import unittest, sys, os
+from sqlalchemy import *
+import StringIO
+import testbase
+
+from tables import *
+import tables
+
+install_mods('legacy_session')
+
+
+class LegacySessionTest(AssertMixin):
+ def setUpAll(self):
+ db.echo = False
+ users.create()
+ db.echo = testbase.echo
+ def tearDownAll(self):
+ db.echo = False
+ users.drop()
+ db.echo = testbase.echo
+ def setUp(self):
+ objectstore.get_session().clear()
+ clear_mappers()
+ tables.user_data()
+ #db.echo = "debug"
+ def tearDown(self):
+ tables.delete_user_data()
+
+ def test_nested_begin_commit(self):
+ """tests that nesting objectstore transactions with multiple commits
+ affects only the outermost transaction"""
+ class User(object):pass
+ m = mapper(User, users)
+ def name_of(id):
+ return users.select(users.c.user_id == id).execute().fetchone().user_name
+ name1 = "Oliver Twist"
+ name2 = 'Mr. Bumble'
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ s = objectstore.get_session()
+ trans = s.begin()
+ trans2 = s.begin()
+ m.get(7).user_name = name1
+ trans3 = s.begin()
+ m.get(8).user_name = name2
+ trans3.commit()
+ s.commit() # should do nothing
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ trans2.commit()
+ s.commit() # should do nothing
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ trans.commit()
+ self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1)
+ self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2)
+
+ def test_nested_rollback(self):
+ """tests that nesting objectstore transactions with a rollback inside
+ affects only the outermost transaction"""
+ class User(object):pass
+ m = mapper(User, users)
+ def name_of(id):
+ return users.select(users.c.user_id == id).execute().fetchone().user_name
+ name1 = "Oliver Twist"
+ name2 = 'Mr. Bumble'
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ s = objectstore.get_session()
+ trans = s.begin()
+ trans2 = s.begin()
+ m.get(7).user_name = name1
+ trans3 = s.begin()
+ m.get(8).user_name = name2
+ trans3.rollback()
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ trans2.commit()
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+ trans.commit()
+ self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+ self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+
+ def test_true_nested(self):
+ """tests creating a new Session inside a database transaction, in
+ conjunction with an engine-level nested transaction, which uses
+ a second connection in order to achieve a nested transaction that commits, inside
+ of another engine session that rolls back."""
+# testbase.db.echo='debug'
+ class User(object):
+ pass
+ testbase.db.begin()
+ try:
+ m = mapper(User, users)
+ name1 = "Oliver Twist"
+ name2 = 'Mr. Bumble'
+ m.get(7).user_name = name1
+ s = objectstore.Session(nest_on=testbase.db)
+ m.using(s).get(8).user_name = name2
+ s.commit()
+ objectstore.commit()
+ testbase.db.rollback()
+ except:
+ testbase.db.rollback()
+ raise
+ objectstore.clear()
+ self.assert_(m.get(8).user_name == name2)
+ self.assert_(m.get(7).user_name != name1)
+
+if __name__ == "__main__":
+ testbase.main()
diff --git a/test/manytomany.py b/test/manytomany.py
index 12f62dbdb..92b7efb26 100644
--- a/test/manytomany.py
+++ b/test/manytomany.py
@@ -1,5 +1,5 @@
-from sqlalchemy import *
import testbase
+from sqlalchemy import *
import string
import sqlalchemy.attributes as attr
@@ -28,21 +28,22 @@ class Transition(object):
class M2MTest(testbase.AssertMixin):
def setUpAll(self):
- db = testbase.db
+ self.install_threadlocal()
+ metadata = testbase.metadata
global place
- place = Table('place', db,
+ place = Table('place', metadata,
Column('place_id', Integer, Sequence('pid_seq', optional=True), primary_key=True),
Column('name', String(30), nullable=False),
)
global transition
- transition = Table('transition', db,
+ transition = Table('transition', metadata,
Column('transition_id', Integer, Sequence('tid_seq', optional=True), primary_key=True),
Column('name', String(30), nullable=False),
)
global place_thingy
- place_thingy = Table('place_thingy', db,
+ place_thingy = Table('place_thingy', metadata,
Column('thingy_id', Integer, Sequence('thid_seq', optional=True), primary_key=True),
Column('place_id', Integer, ForeignKey('place.place_id'), nullable=False),
Column('name', String(30), nullable=False)
@@ -50,20 +51,20 @@ class M2MTest(testbase.AssertMixin):
# association table #1
global place_input
- place_input = Table('place_input', db,
+ place_input = Table('place_input', metadata,
Column('place_id', Integer, ForeignKey('place.place_id')),
Column('transition_id', Integer, ForeignKey('transition.transition_id')),
)
# association table #2
global place_output
- place_output = Table('place_output', db,
+ place_output = Table('place_output', metadata,
Column('place_id', Integer, ForeignKey('place.place_id')),
Column('transition_id', Integer, ForeignKey('transition.transition_id')),
)
global place_place
- place_place = Table('place_place', db,
+ place_place = Table('place_place', metadata,
Column('pl1_id', Integer, ForeignKey('place.place_id')),
Column('pl2_id', Integer, ForeignKey('place.place_id')),
)
@@ -83,7 +84,8 @@ class M2MTest(testbase.AssertMixin):
place.drop()
transition.drop()
#testbase.db.tables.clear()
-
+ self.uninstall_threadlocal()
+
def setUp(self):
objectstore.clear()
clear_mappers()
@@ -140,7 +142,7 @@ class M2MTest(testbase.AssertMixin):
pp = p.places
self.echo("Place " + str(p) +" places " + repr(pp))
- objectstore.delete(p1,p2,p3,p4,p5,p6,p7)
+ [objectstore.delete(p) for p in p1,p2,p3,p4,p5,p6,p7]
objectstore.flush()
def testdouble(self):
@@ -152,8 +154,8 @@ class M2MTest(testbase.AssertMixin):
})
Transition.mapper = mapper(Transition, transition, properties = dict(
- inputs = relation(Place.mapper, place_output, lazy=False, selectalias='op_alias'),
- outputs = relation(Place.mapper, place_input, lazy=False, selectalias='ip_alias'),
+ inputs = relation(Place.mapper, place_output, lazy=False),
+ outputs = relation(Place.mapper, place_input, lazy=False),
)
)
@@ -161,7 +163,7 @@ class M2MTest(testbase.AssertMixin):
tran.inputs.append(Place('place1'))
tran.outputs.append(Place('place2'))
tran.outputs.append(Place('place3'))
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
r = Transition.mapper.select()
@@ -201,20 +203,21 @@ class M2MTest(testbase.AssertMixin):
p3.inputs.append(t2)
p1.outputs.append(t1)
- objectstore.commit()
+ objectstore.flush()
self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])})
self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])})
class M2MTest2(testbase.AssertMixin):
def setUpAll(self):
- db = testbase.db
+ self.install_threadlocal()
+ metadata = testbase.metadata
global studentTbl
- studentTbl = Table('student', db, Column('name', String(20), primary_key=True))
+ studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True))
global courseTbl
- courseTbl = Table('course', db, Column('name', String(20), primary_key=True))
+ courseTbl = Table('course', metadata, Column('name', String(20), primary_key=True))
global enrolTbl
- enrolTbl = Table('enrol', db,
+ enrolTbl = Table('enrol', metadata,
Column('student_id', String(20), ForeignKey('student.name'),primary_key=True),
Column('course_id', String(20), ForeignKey('course.name'), primary_key=True))
@@ -227,7 +230,8 @@ class M2MTest2(testbase.AssertMixin):
studentTbl.drop()
courseTbl.drop()
#testbase.db.tables.clear()
-
+ self.uninstall_threadlocal()
+
def setUp(self):
objectstore.clear()
clear_mappers()
@@ -258,7 +262,7 @@ class M2MTest2(testbase.AssertMixin):
c3.students.append(s1)
self.assert_(len(s1.courses) == 3)
self.assert_(len(c1.students) == 1)
- objectstore.commit()
+ objectstore.flush()
objectstore.clear()
s = Student.mapper.get_by(name='Student1')
c = Course.mapper.get_by(name='Course3')
@@ -267,65 +271,66 @@ class M2MTest2(testbase.AssertMixin):
self.assert_(len(s.courses) == 2)
class M2MTest3(testbase.AssertMixin):
- def setUpAll(self):
- e = testbase.db
- global c, c2a1, c2a2, b, a
- c = Table('c', e,
- Column('c1', Integer, primary_key = True),
- Column('c2', String(20)),
- ).create()
-
- a = Table('a', e,
- Column('a1', Integer, primary_key=True),
- Column('a2', String(20)),
- Column('c1', Integer, ForeignKey('c.c1'))
- ).create()
-
- c2a1 = Table('ctoaone', e,
- Column('c1', Integer, ForeignKey('c.c1')),
- Column('a1', Integer, ForeignKey('a.a1'))
- ).create()
- c2a2 = Table('ctoatwo', e,
- Column('c1', Integer, ForeignKey('c.c1')),
- Column('a1', Integer, ForeignKey('a.a1'))
- ).create()
-
- b = Table('b', e,
- Column('b1', Integer, primary_key=True),
- Column('a1', Integer, ForeignKey('a.a1')),
- Column('b2', Boolean)
- ).create()
-
- def tearDownAll(self):
- b.drop()
- c2a2.drop()
- c2a1.drop()
- a.drop()
- c.drop()
- #testbase.db.tables.clear()
-
- def testbasic(self):
- class C(object):pass
- class A(object):pass
- class B(object):pass
+ def setUpAll(self):
+ self.install_threadlocal()
+ metadata = testbase.metadata
+ global c, c2a1, c2a2, b, a
+ c = Table('c', metadata,
+ Column('c1', Integer, primary_key = True),
+ Column('c2', String(20)),
+ ).create()
+
+ a = Table('a', metadata,
+ Column('a1', Integer, primary_key=True),
+ Column('a2', String(20)),
+ Column('c1', Integer, ForeignKey('c.c1'))
+ ).create()
+
+ c2a1 = Table('ctoaone', metadata,
+ Column('c1', Integer, ForeignKey('c.c1')),
+ Column('a1', Integer, ForeignKey('a.a1'))
+ ).create()
+ c2a2 = Table('ctoatwo', metadata,
+ Column('c1', Integer, ForeignKey('c.c1')),
+ Column('a1', Integer, ForeignKey('a.a1'))
+ ).create()
+
+ b = Table('b', metadata,
+ Column('b1', Integer, primary_key=True),
+ Column('a1', Integer, ForeignKey('a.a1')),
+ Column('b2', Boolean)
+ ).create()
- assign_mapper(B, b)
+ def tearDownAll(self):
+ b.drop()
+ c2a2.drop()
+ c2a1.drop()
+ a.drop()
+ c.drop()
+ #testbase.db.tables.clear()
+ self.uninstall_threadlocal()
+
+ def testbasic(self):
+ class C(object):pass
+ class A(object):pass
+ class B(object):pass
- assign_mapper(A, a,
- properties = {
- 'tbs' : relation(B, primaryjoin=and_(b.c.a1==a.c.a1, b.c.b2 == True), lazy=False),
- }
- )
+ assign_mapper(B, b)
- assign_mapper(C, c,
- properties = {
- 'a1s' : relation(A, secondary=c2a1, lazy=False),
- 'a2s' : relation(A, secondary=c2a2, lazy=False)
- }
- )
+ assign_mapper(A, a,
+ properties = {
+ 'tbs' : relation(B, primaryjoin=and_(b.c.a1==a.c.a1, b.c.b2 == True), lazy=False),
+ }
+ )
- o1 = C.get(1)
+ assign_mapper(C, c,
+ properties = {
+ 'a1s' : relation(A, secondary=c2a1, lazy=False),
+ 'a2s' : relation(A, secondary=c2a2, lazy=False)
+ }
+ )
+ o1 = C.get(1)
if __name__ == "__main__":
diff --git a/test/mapper.py b/test/mapper.py
index 28982f756..546ed1a87 100644
--- a/test/mapper.py
+++ b/test/mapper.py
@@ -2,7 +2,7 @@ from testbase import PersistTest, AssertMixin
import testbase
import unittest, sys, os
from sqlalchemy import *
-
+import sqlalchemy.exceptions as exceptions
from tables import *
import tables
@@ -68,35 +68,38 @@ class MapperSuperTest(AssertMixin):
tables.drop()
db.echo = testbase.echo
def tearDown(self):
- objectstore.clear()
clear_mappers()
def setUp(self):
pass
class MapperTest(MapperSuperTest):
def testget(self):
- m = mapper(User, users)
- self.assert_(m.get(19) is None)
- u = m.get(7)
- u2 = m.get(7)
+ s = create_session()
+ mapper(User, users)
+ self.assert_(s.get(User, 19) is None)
+ u = s.get(User, 7)
+ u2 = s.get(User, 7)
self.assert_(u is u2)
- objectstore.clear()
- u2 = m.get(7)
+ s.clear()
+ u2 = s.get(User, 7)
self.assert_(u is not u2)
def testrefresh(self):
- m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
- u = m.get(7)
+ mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
+ s = create_session()
+ u = s.get(User, 7)
u.user_name = 'foo'
a = Address()
+ import sqlalchemy.orm.session
+ assert sqlalchemy.orm.session.object_session(a) is None
u.addresses.append(a)
self.assert_(a in u.addresses)
- objectstore.refresh(u)
+ s.refresh(u)
# its refreshed, so not dirty
- self.assert_(u not in objectstore.get_session().uow.dirty)
+ self.assert_(u not in s.dirty)
# username is back to the DB
self.assert_(u.user_name == 'jack')
@@ -106,10 +109,10 @@ class MapperTest(MapperSuperTest):
u.user_name = 'foo'
u.addresses.append(a)
# now its dirty
- self.assert_(u in objectstore.get_session().uow.dirty)
+ self.assert_(u in s.dirty)
self.assert_(u.user_name == 'foo')
self.assert_(a in u.addresses)
- objectstore.expire(u)
+ s.expire(u)
# get the attribute, it refreshes
self.assert_(u.user_name == 'jack')
@@ -117,41 +120,22 @@ class MapperTest(MapperSuperTest):
def testrefresh_lazy(self):
"""tests that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems"""
- m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
- m2 = m.options(lazyload('addresses'))
- u = m2.selectfirst(users.c.user_id==8)
+ s = create_session()
+ mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
+ q2 = s.query(User).options(lazyload('addresses'))
+ u = q2.selectfirst(users.c.user_id==8)
def go():
- objectstore.refresh(u)
+ s.refresh(u)
self.assert_sql_count(db, go, 1)
- def testexpire_eager(self):
- """tests that an eager load will populate expire()'d objects"""
- m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
- [u1, u2, u3] = m.select(users.c.user_id.in_(7, 8, 9))
- self.echo([repr(x.addresses) for x in [u1, u2, u3]])
- [objectstore.expire(u) for u in [u1, u2, u3]]
- m2 = m.options(eagerload('addresses'))
- l = m2.select(users.c.user_id.in_(7,8,9))
- def go():
- u1.addresses
- u2.addresses
- u3.addresses
- self.assert_sql_count(db, go, 0)
-
- def testsessionpropigation(self):
- sess = objectstore.Session()
- m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=True)})
- u = m.using(sess).get(7)
- assert objectstore.get_session(u) is sess
- assert objectstore.get_session(u.addresses[0]) is sess
-
def testexpire(self):
- m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
- u = m.get(7)
+ s = create_session()
+ mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
+ u = s.get(User, 7)
assert(len(u.addresses) == 1)
u.user_name = 'foo'
del u.addresses[0]
- objectstore.expire(u)
+ s.expire(u)
# test plain expire
self.assert_(u.user_name =='jack')
self.assert_(len(u.addresses) == 1)
@@ -159,14 +143,14 @@ class MapperTest(MapperSuperTest):
# we're changing the database here, so if this test fails in the middle,
# it'll screw up the other tests which are hardcoded to 7/'jack'
u.user_name = 'foo'
- objectstore.commit()
+ s.flush()
# change the value in the DB
users.update(users.c.user_id==7, values=dict(user_name='jack')).execute()
- objectstore.expire(u)
+ s.expire(u)
# object isnt refreshed yet, using dict to bypass trigger
self.assert_(u.__dict__['user_name'] != 'jack')
# do a select
- m.select()
+ s.query(User).select()
# test that it refreshed
self.assert_(u.__dict__['user_name'] == 'jack')
@@ -175,41 +159,43 @@ class MapperTest(MapperSuperTest):
self.assert_(u.user_name =='jack')
def testrefresh2(self):
- assign_mapper(Address, addresses)
+ s = create_session()
+ mapper(Address, addresses)
- assign_mapper(User, users, properties = dict(addresses=relation(Address.mapper,private=True,lazy=False)) )
+ mapper(User, users, properties = dict(addresses=relation(Address,private=True,lazy=False)) )
u=User()
u.user_name='Justin'
a = Address()
a.address_id=17 # to work around the hardcoded IDs in this test suite....
u.addresses.append(a)
- objectstore.commit()
- objectstore.clear()
- u = User.mapper.selectfirst()
+ s.flush()
+ s.clear()
+ u = s.query(User).selectfirst()
print u.user_name
#ok so far
- u.expire() #hangs when
+ s.expire(u) #hangs when
print u.user_name #this line runs
- u.refresh() #hangs
+ s.refresh(u) #hangs
def testmagic(self):
- m = mapper(User, users, properties = {
+ mapper(User, users, properties = {
'addresses' : relation(mapper(Address, addresses))
})
- l = m.select_by(user_name='fred')
+ sess = create_session()
+ l = sess.query(User).select_by(user_name='fred')
self.assert_result(l, User, *[{'user_id':9}])
u = l[0]
- u2 = m.get_by_user_name('fred')
+ u2 = sess.query(User).get_by_user_name('fred')
self.assert_(u is u2)
- l = m.select_by(email_address='ed@bettyboop.com')
+ l = sess.query(User).select_by(email_address='ed@bettyboop.com')
self.assert_result(l, User, *[{'user_id':8}])
- l = m.select_by(User.c.user_name=='fred', addresses.c.email_address!='ed@bettyboop.com', user_id=9)
+ l = sess.query(User).select_by(User.c.user_name=='fred', addresses.c.email_address!='ed@bettyboop.com', user_id=9)
def testprops(self):
"""tests the various attributes of the properties attached to classes"""
@@ -220,46 +206,69 @@ class MapperTest(MapperSuperTest):
def testload(self):
"""tests loading rows with a mapper and producing object instances"""
- m = mapper(User, users)
- l = m.select()
+ mapper(User, users)
+ l = create_session().query(User).select()
self.assert_result(l, User, *user_result)
- l = m.select(users.c.user_name.endswith('ed'))
+ l = create_session().query(User).select(users.c.user_name.endswith('ed'))
self.assert_result(l, User, *user_result[1:3])
+ def testjoinvia(self):
+ m = mapper(User, users, properties={
+ 'orders':relation(mapper(Order, orders, properties={
+ 'items':relation(mapper(Item, orderitems))
+ }))
+ })
+
+ q = create_session().query(m)
+
+ l = q.select((orderitems.c.item_name=='item 4') & q.join_via(['orders', 'items']))
+ self.assert_result(l, User, user_result[0])
+
+ l = q.select_by(item_name='item 4')
+ self.assert_result(l, User, user_result[0])
+
+ l = q.select((orderitems.c.item_name=='item 4') & q.join_to('item_name'))
+ self.assert_result(l, User, user_result[0])
+
+ l = q.select((orderitems.c.item_name=='item 4') & q.join_to('items'))
+ self.assert_result(l, User, user_result[0])
+
def testorderby(self):
# TODO: make a unit test out of these various combinations
# m = mapper(User, users, order_by=desc(users.c.user_name))
- m = mapper(User, users, order_by=None)
-# m = mapper(User, users)
+ mapper(User, users, order_by=None)
+# mapper(User, users)
-# l = m.select(order_by=[desc(users.c.user_name), asc(users.c.user_id)])
- l = m.select()
-# l = m.select(order_by=[])
-# l = m.select(order_by=None)
+# l = create_session().query(User).select(order_by=[desc(users.c.user_name), asc(users.c.user_id)])
+ l = create_session().query(User).select()
+# l = create_session().query(User).select(order_by=[])
+# l = create_session().query(User).select(order_by=None)
def testfunction(self):
"""tests mapping to a SELECT statement that has functions in it."""
s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')],
users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c]).alias('myselect')
- m = mapper(User, s, primarytable=users)
- print [c.key for c in m.c]
- l = m.select()
+ mapper(User, s)
+ sess = create_session()
+ l = sess.query(User).select()
for u in l:
print "User", u.user_id, u.user_name, u.concat, u.count
- #l[1].user_name='asdf'
- #objectstore.commit()
-
+ assert l[0].concat == l[0].user_id * 2 == 14
+ assert l[1].concat == l[1].user_id * 2 == 16
+
def testcount(self):
- m = mapper(User, users)
- self.assert_(m.count()==3)
- self.assert_(m.count(users.c.user_id.in_(8,9))==2)
- self.assert_(m.count_by(user_name='fred')==1)
+ mapper(User, users)
+ q = create_session().query(User)
+ self.assert_(q.count()==3)
+ self.assert_(q.count(users.c.user_id.in_(8,9))==2)
+ self.assert_(q.count_by(user_name='fred')==1)
def testmultitable(self):
usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
- m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id])
- l = m.select()
+ m = mapper(User, usersaddresses, primary_key=[users.c.user_id])
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User, *user_result[0:2])
def testoverride(self):
@@ -269,14 +278,16 @@ class MapperTest(MapperSuperTest):
'user_name' : relation(mapper(Address, addresses)),
})
self.assert_(False, "should have raised ArgumentError")
- except ArgumentError, e:
+ except exceptions.ArgumentError, e:
self.assert_(True)
-
+
+ clear_mappers()
# assert that allow_column_override cancels the error
m = mapper(User, users, properties = {
'user_name' : relation(mapper(Address, addresses))
}, allow_column_override=True)
+ clear_mappers()
# assert that the column being named else where also cancels the error
m = mapper(User, users, properties = {
'user_name' : relation(mapper(Address, addresses)),
@@ -285,27 +296,50 @@ class MapperTest(MapperSuperTest):
def testeageroptions(self):
"""tests that a lazy relation can be upgraded to an eager relation via the options method"""
- m = mapper(User, users, properties = dict(
+ sess = create_session()
+ mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = True)
))
- l = m.options(eagerload('addresses')).select()
+ l = sess.query(User).options(eagerload('addresses')).select()
def go():
self.assert_result(l, User, *user_address_result)
self.assert_sql_count(db, go, 0)
+ def testeagerdegrade(self):
+ """tests that an eager relation automatically degrades to a lazy relation if eager columns are not available"""
+ sess = create_session()
+ usermapper = mapper(User, users, properties = dict(
+ addresses = relation(mapper(Address, addresses), lazy = False)
+ ))
+
+ # first test straight eager load, 1 statement
+ def go():
+ l = usermapper.query(sess).select()
+ self.assert_result(l, User, *user_address_result)
+ self.assert_sql_count(db, go, 1)
+
+ # then select just from users. run it into instances.
+ # then assert the data, which will launch 3 more lazy loads
+ def go():
+ r = users.select().execute()
+ l = usermapper.instances(r, sess)
+ self.assert_result(l, User, *user_address_result)
+ self.assert_sql_count(db, go, 4)
+
def testlazyoptions(self):
"""tests that an eager relation can be upgraded to a lazy relation via the options method"""
- m = mapper(User, users, properties = dict(
+ sess = create_session()
+ mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = False)
))
- l = m.options(lazyload('addresses')).select()
+ l = sess.query(User).options(lazyload('addresses')).select()
def go():
self.assert_result(l, User, *user_address_result)
self.assert_sql_count(db, go, 3)
def testdeepoptions(self):
- m = mapper(User, users,
+ mapper(User, users,
properties = {
'orders': relation(mapper(Order, orders, properties = {
'items' : relation(mapper(Item, orderitems, properties = {
@@ -314,55 +348,43 @@ class MapperTest(MapperSuperTest):
}))
})
- m2 = m.options(eagerload('orders.items.keywords'))
- u = m.select()
+ sess = create_session()
+ q2 = sess.query(User).options(eagerload('orders.items.keywords'))
+ u = sess.query(User).select()
def go():
print u[0].orders[1].items[0].keywords[1]
self.assert_sql_count(db, go, 3)
- objectstore.clear()
- u = m2.select()
+ sess.clear()
+ u = q2.select()
self.assert_sql_count(db, go, 2)
-class PropertyTest(MapperSuperTest):
- def testbasic(self):
- """tests that you can create mappers inline with class definitions"""
- class _Address(object):
- pass
- assign_mapper(_Address, addresses)
-
- class _User(object):
- pass
- assign_mapper(_User, users, properties = dict(
- addresses = relation(_Address.mapper, lazy = False)
- ), is_primary = True)
-
- l = _User.mapper.select(_User.c.user_name == 'fred')
- self.echo(repr(l))
-
+class InheritanceTest(MapperSuperTest):
def testinherits(self):
class _Order(object):
pass
- assign_mapper(_Order, orders)
+ ordermapper = mapper(_Order, orders)
class _User(object):
pass
- assign_mapper(_User, users, properties = dict(
- orders = relation(_Order.mapper, lazy = False)
+ usermapper = mapper(_User, users, properties = dict(
+ orders = relation(ordermapper, lazy = False)
))
class AddressUser(_User):
pass
- assign_mapper(AddressUser, addresses, inherits = _User.mapper)
-
- l = AddressUser.mapper.select()
+ mapper(AddressUser, addresses, inherits = usermapper)
+
+ sess = create_session()
+ q = sess.query(AddressUser)
+ l = q.select()
jack = l[0]
self.assert_(jack.user_name=='jack')
jack.email_address = 'jack@gmail.com'
- objectstore.commit()
- objectstore.clear()
- au = AddressUser.mapper.get_by(user_name='jack')
+ sess.flush()
+ sess.clear()
+ au = q.get_by(user_name='jack')
self.assert_(au.email_address == 'jack@gmail.com')
def testinherits2(self):
@@ -372,19 +394,20 @@ class PropertyTest(MapperSuperTest):
pass
class AddressUser(_Address):
pass
- assign_mapper(_Order, orders)
- assign_mapper(_Address, addresses)
- assign_mapper(AddressUser, users, inherits = _Address.mapper,
+ ordermapper = mapper(_Order, orders)
+ addressmapper = mapper(_Address, addresses)
+ usermapper = mapper(AddressUser, users, inherits = addressmapper,
properties = {
- 'orders' : relation(_Order.mapper, lazy=False)
+ 'orders' : relation(ordermapper, lazy=False)
})
- l = AddressUser.mapper.select()
+ sess = create_session()
+ l = sess.query(usermapper).select()
jack = l[0]
self.assert_(jack.user_name=='jack')
jack.email_address = 'jack@gmail.com'
- objectstore.commit()
- objectstore.clear()
- au = AddressUser.mapper.get_by(user_name='jack')
+ sess.flush()
+ sess.clear()
+ au = sess.query(usermapper).get_by(user_name='jack')
self.assert_(au.email_address == 'jack@gmail.com')
@@ -400,13 +423,15 @@ class DeferredTest(MapperSuperTest):
o = Order()
self.assert_(o.description is None)
+ q = create_session().query(m)
def go():
- l = m.select()
+ l = q.select()
o2 = l[2]
print o2.description
+ orderby = str(orders.default_order_by()[0].compile(engine=db))
self.assert_sql(db, go, [
- ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
+ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
@@ -415,10 +440,12 @@ class DeferredTest(MapperSuperTest):
'description':deferred(orders.c.description)
})
- l = m.select()
+ sess = create_session()
+ q = sess.query(m)
+ l = q.select()
o2 = l[2]
o2.isopen = 1
- objectstore.commit()
+ sess.flush()
def testgroup(self):
"""tests deferred load with a group"""
@@ -428,36 +455,43 @@ class DeferredTest(MapperSuperTest):
'description':deferred(orders.c.description, group='primary'),
'opened':deferred(orders.c.isopen, group='primary')
})
-
+ q = create_session().query(m)
def go():
- l = m.select()
+ l = q.select()
o2 = l[2]
print o2.opened, o2.description, o2.userident
+
+ orderby = str(orders.default_order_by()[0].compile(db))
self.assert_sql(db, go, [
- ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
+ ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
def testoptions(self):
"""tests using options on a mapper to create deferred and undeferred columns"""
m = mapper(Order, orders)
- m2 = m.options(defer('user_id'))
+ sess = create_session()
+ q = sess.query(m)
+ q2 = q.options(defer('user_id'))
def go():
- l = m2.select()
+ l = q2.select()
print l[2].user_id
+
+ orderby = str(orders.default_order_by()[0].compile(db))
self.assert_sql(db, go, [
- ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
+ ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
- objectstore.clear()
- m3 = m2.options(undefer('user_id'))
+ sess.clear()
+ q3 = q2.options(undefer('user_id'))
def go():
- l = m3.select()
+ l = q3.select()
print l[3].user_id
self.assert_sql(db, go, [
- ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
+ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
])
+
def testdeepoptions(self):
m = mapper(User, users, properties={
'orders':relation(mapper(Order, orders, properties={
@@ -466,15 +500,17 @@ class DeferredTest(MapperSuperTest):
}))
}))
})
- l = m.select()
+ sess = create_session()
+ q = sess.query(m)
+ l = q.select()
item = l[0].orders[1].items[1]
def go():
print item.item_name
self.assert_sql_count(db, go, 1)
self.assert_(item.item_name == 'item 4')
- objectstore.clear()
- m2 = m.options(undefer('orders.items.item_name'))
- l = m2.select()
+ sess.clear()
+ q2 = q.options(undefer('orders.items.item_name'))
+ l = q2.select()
item = l[0].orders[1].items[1]
def go():
print item.item_name
@@ -489,7 +525,8 @@ class LazyTest(MapperSuperTest):
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = True)
))
- l = m.select(users.c.user_id == 7)
+ q = create_session().query(m)
+ l = q.select(users.c.user_id == 7)
self.assert_result(l, User,
{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
)
@@ -500,7 +537,8 @@ class LazyTest(MapperSuperTest):
m = mapper(User, users, properties = dict(
addresses = relation(m, lazy = True, order_by=addresses.c.email_address),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User,
{'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
@@ -508,13 +546,29 @@ class LazyTest(MapperSuperTest):
{'user_id' : 9, 'addresses' : (Address, [])}
)
+ def testorderby_select(self):
+ """tests that a regular mapper select on a single table can order by a relation to a second table"""
+ m = mapper(Address, addresses)
+
+ m = mapper(User, users, properties = dict(
+ addresses = relation(m, lazy = True),
+ ))
+ q = create_session().query(m)
+ l = q.select(users.c.user_id==addresses.c.user_id, order_by=addresses.c.email_address)
+
+ self.assert_result(l, User,
+ {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, ])},
+ {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
+ )
+
def testorderby_desc(self):
m = mapper(Address, addresses)
m = mapper(User, users, properties = dict(
addresses = relation(m, lazy = True, order_by=[desc(addresses.c.email_address)]),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User,
{'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
@@ -531,35 +585,39 @@ class LazyTest(MapperSuperTest):
addresses = relation(mapper(Address, addresses), lazy = True),
orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = True),
))
- l = m.select(limit=2, offset=1)
+ sess= create_session()
+ q = sess.query(m)
+ l = q.select(limit=2, offset=1)
self.assert_result(l, User, *user_all_result[1:3])
# use a union all to get a lot of rows to join against
u2 = users.alias('u2')
s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
print [key for key in s.c.keys()]
- l = m.select(s.c.u2_user_id==User.c.user_id, distinct=True)
+ l = q.select(s.c.u2_user_id==User.c.user_id, distinct=True)
self.assert_result(l, User, *user_all_result)
- objectstore.clear()
+ sess.clear()
m = mapper(Item, orderitems, is_primary=True, properties = dict(
keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True),
))
- l = m.select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2)
+
+ l = sess.query(m).select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2)
self.assert_result(l, Item, *[item_keyword_result[1], item_keyword_result[2]])
def testonetoone(self):
m = mapper(User, users, properties = dict(
address = relation(mapper(Address, addresses), lazy = True, uselist = False)
))
- l = m.select(users.c.user_id == 7)
- self.echo(repr(l))
- self.echo(repr(l[0].address))
+ q = create_session().query(m)
+ l = q.select(users.c.user_id == 7)
+ self.assert_result(l, User, {'user_id':7, 'address' : (Address, {'address_id':1})})
def testbackwardsonetoone(self):
m = mapper(Address, addresses, properties = dict(
- user = relation(mapper(User, users, properties = {'id':users.c.user_id}), lazy = True)
+ user = relation(mapper(User, users), lazy = True)
))
- l = m.select(addresses.c.address_id == 1)
+ q = create_session().query(m)
+ l = q.select(addresses.c.address_id == 1)
self.echo(repr(l))
print repr(l[0].__dict__)
self.echo(repr(l[0].user))
@@ -572,10 +630,11 @@ class LazyTest(MapperSuperTest):
closedorders = alias(orders, 'closedorders')
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = False),
- open_orders = relation(mapper(Order, openorders), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = True),
- closed_orders = relation(mapper(Order, closedorders), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = True)
+ open_orders = relation(mapper(Order, openorders, entity_name='open'), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = True),
+ closed_orders = relation(mapper(Order, closedorders,entity_name='closed'), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = True)
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User,
{'user_id' : 7,
'addresses' : (Address, [{'address_id' : 1}]),
@@ -596,10 +655,11 @@ class LazyTest(MapperSuperTest):
def testmanytomany(self):
"""tests a many-to-many lazy load"""
- assign_mapper(Item, orderitems, properties = dict(
+ mapper(Item, orderitems, properties = dict(
keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True),
))
- l = Item.mapper.select()
+ q = create_session().query(Item)
+ l = q.select()
self.assert_result(l, Item,
{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
{'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
@@ -607,7 +667,7 @@ class LazyTest(MapperSuperTest):
{'item_id' : 4, 'keywords' : (Keyword, [])},
{'item_id' : 5, 'keywords' : (Keyword, [])}
)
- l = Item.mapper.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, Item.c.item_id==itemkeywords.c.item_id))
+ l = q.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, Item.c.item_id==itemkeywords.c.item_id))
self.assert_result(l, Item,
{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
{'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
@@ -616,13 +676,13 @@ class LazyTest(MapperSuperTest):
class EagerTest(MapperSuperTest):
def testbasic(self):
"""tests a basic one-to-many eager load"""
-
m = mapper(Address, addresses)
m = mapper(User, users, properties = dict(
addresses = relation(m, lazy = False),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User, *user_address_result)
def testorderby(self):
@@ -631,7 +691,8 @@ class EagerTest(MapperSuperTest):
m = mapper(User, users, properties = dict(
addresses = relation(m, lazy = False, order_by=addresses.c.email_address),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User,
{'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
{'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])},
@@ -644,7 +705,8 @@ class EagerTest(MapperSuperTest):
m = mapper(User, users, properties = dict(
addresses = relation(m, lazy = False, order_by=[desc(addresses.c.email_address)]),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User,
{'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
@@ -661,20 +723,24 @@ class EagerTest(MapperSuperTest):
addresses = relation(mapper(Address, addresses), lazy = False),
orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = False),
))
- l = m.select(limit=2, offset=1)
+ sess = create_session()
+ q = sess.query(m)
+
+ l = q.select(limit=2, offset=1)
self.assert_result(l, User, *user_all_result[1:3])
# this is an involved 3x union of the users table to get a lot of rows.
# then see if the "distinct" works its way out. you actually get the same
# result with or without the distinct, just via less or more rows.
u2 = users.alias('u2')
s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
- l = m.select(s.c.u2_user_id==User.c.user_id, distinct=True)
+ l = q.select(s.c.u2_user_id==User.c.user_id, distinct=True)
self.assert_result(l, User, *user_all_result)
- objectstore.clear()
+ sess.clear()
m = mapper(Item, orderitems, is_primary=True, properties = dict(
keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = False, order_by=[keywords.c.keyword_id]),
))
- l = m.select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2)
+ q = sess.query(m)
+ l = q.select((Item.c.item_name=='item 2') | (Item.c.item_name=='item 5') | (Item.c.item_name=='item 3'), order_by=[Item.c.item_id], limit=2)
self.assert_result(l, Item, *[item_keyword_result[1], item_keyword_result[2]])
@@ -683,7 +749,8 @@ class EagerTest(MapperSuperTest):
m = mapper(User, users, properties = dict(
address = relation(mapper(Address, addresses), lazy = False, uselist = False)
))
- l = m.select(users.c.user_id == 7)
+ q = create_session().query(m)
+ l = q.select(users.c.user_id == 7)
self.assert_result(l, User,
{'user_id' : 7, 'address' : (Address, {'address_id' : 1, 'email_address': 'jack@bean.com'})},
)
@@ -693,7 +760,8 @@ class EagerTest(MapperSuperTest):
user = relation(mapper(User, users), lazy = False)
))
self.echo(repr(m.props['user'].uselist))
- l = m.select(addresses.c.address_id == 1)
+ q = create_session().query(m)
+ l = q.select(addresses.c.address_id == 1)
self.assert_result(l, Address,
{'address_id' : 1, 'email_address' : 'jack@bean.com',
'user' : (User, {'user_id' : 7, 'user_name' : 'jack'})
@@ -707,7 +775,8 @@ class EagerTest(MapperSuperTest):
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), primaryjoin = users.c.user_id==addresses.c.user_id, lazy = False)
))
- l = m.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id))
+ q = create_session().query(m)
+ l = q.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id))
self.assert_result(l, User,
{'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2, 'email_address':'ed@wood.com'}, {'address_id':3, 'email_address':'ed@bettyboop.com'}, {'address_id':4, 'email_address':'ed@lala.com'}])},
)
@@ -715,6 +784,7 @@ class EagerTest(MapperSuperTest):
def testcompile(self):
"""tests deferred operation of a pre-compiled mapper statement"""
+ session = create_session()
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = False)
))
@@ -722,7 +792,7 @@ class EagerTest(MapperSuperTest):
c = s.compile()
self.echo("\n" + str(c) + repr(c.get_params()))
- l = m.instances(s.execute(emailad = 'jack@bean.com'))
+ l = m.instances(s.execute(emailad = 'jack@bean.com'), session)
self.echo(repr(l))
def testmulti(self):
@@ -731,7 +801,8 @@ class EagerTest(MapperSuperTest):
addresses = relation(mapper(Address, addresses), primaryjoin = users.c.user_id==addresses.c.user_id, lazy = False),
orders = relation(mapper(Order, orders), lazy = False),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User,
{'user_id' : 7,
'addresses' : (Address, [{'address_id' : 1}]),
@@ -754,10 +825,11 @@ class EagerTest(MapperSuperTest):
ordermapper = mapper(Order, orders)
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = False),
- open_orders = relation(mapper(Order, openorders), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = False),
- closed_orders = relation(mapper(Order, closedorders), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = False)
+ open_orders = relation(mapper(Order, openorders, non_primary=True), primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = False),
+ closed_orders = relation(mapper(Order, closedorders, non_primary=True), primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = False)
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User,
{'user_id' : 7,
'addresses' : (Address, [{'address_id' : 1}]),
@@ -787,7 +859,8 @@ class EagerTest(MapperSuperTest):
addresses = relation(mapper(Address, addresses), lazy = False),
orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = False),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, User, *user_all_result)
def testmanytomany(self):
@@ -796,16 +869,37 @@ class EagerTest(MapperSuperTest):
m = mapper(Item, items, properties = dict(
keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy=False, order_by=[keywords.c.keyword_id]),
))
- l = m.select()
+ q = create_session().query(m)
+ l = q.select()
self.assert_result(l, Item, *item_keyword_result)
-# l = m.select()
- l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
+ l = q.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
self.assert_result(l, Item,
{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
{'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
)
+ def testmanytomanyoptions(self):
+ items = orderitems
+
+ m = mapper(Item, items, properties = dict(
+ keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy=True, order_by=[keywords.c.keyword_id]),
+ ))
+ m2 = m.options(eagerload('keywords'))
+ q = create_session().query(m2)
+ def go():
+ l = q.select()
+ self.assert_result(l, Item, *item_keyword_result)
+ self.assert_sql_count(db, go, 1)
+
+ def go():
+ l = q.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
+ self.assert_result(l, Item,
+ {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
+ {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
+ )
+ self.assert_sql_count(db, go, 1)
+
def testoneandmany(self):
"""tests eager load for a parent object with a child object that
contains a many-to-many relationship to a third object."""
@@ -819,7 +913,8 @@ class EagerTest(MapperSuperTest):
m = mapper(Order, orders, properties = dict(
items = relation(m, lazy = False)
))
- l = m.select("orders.order_id in (1,2,3)")
+ q = create_session().query(m)
+ l = q.select("orders.order_id in (1,2,3)")
self.assert_result(l, Order,
{'order_id' : 1, 'items': (Item, [])},
{'order_id' : 2, 'items': (Item, [
diff --git a/test/masscreate.py b/test/masscreate.py
index 885c1f653..5f99bb6c1 100644
--- a/test/masscreate.py
+++ b/test/masscreate.py
@@ -16,7 +16,7 @@ attr_manager = AttributeManager()
if manage_attributes:
attr_manager.register_attribute(User, 'id', uselist=False)
attr_manager.register_attribute(User, 'name', uselist=False)
- attr_manager.register_attribute(User, 'addresses', uselist=True)
+ attr_manager.register_attribute(User, 'addresses', uselist=True, trackparent=True)
attr_manager.register_attribute(Address, 'email', uselist=False)
now = time.time()
@@ -35,7 +35,7 @@ for i in range(0,130):
a.email = 'foo@bar.com'
u.addresses.append(a)
# gc.collect()
- print len(managed_attributes)
+# print len(managed_attributes)
# managed_attributes.clear()
total = time.time() - now
print "Total time", total
diff --git a/test/massload.py b/test/massload.py
index ab47415e8..d36746968 100644
--- a/test/massload.py
+++ b/test/massload.py
@@ -35,7 +35,7 @@ class LoadTest(AssertMixin):
clear_mappers()
for x in range(1,NUM/500+1):
l = []
- for y in range(x*500-499, x*500 + 1):
+ for y in range(x*500-500, x*500):
l.append({'item_id':y, 'value':'this is item #%d' % y})
items.insert().execute(*l)
@@ -47,7 +47,7 @@ class LoadTest(AssertMixin):
for x in range (1,NUM/100):
# this is not needed with cpython which clears non-circular refs immediately
#gc.collect()
- l = m.select(items.c.item_id.between(x*100 - 99, x*100 ))
+ l = m.select(items.c.item_id.between(x*100 - 100, x*100 - 1))
assert len(l) == 100
print "loaded ", len(l), " items "
# modifying each object will insure that the objects get placed in the "dirty" list
@@ -56,10 +56,10 @@ class LoadTest(AssertMixin):
a.value = 'changed...'
assert len(objectstore.get_session().dirty) == len(l)
assert len(objectstore.get_session().identity_map) == len(l)
- #assert len(attributes.managed_attributes) == len(l)
+ assert len(attributes.managed_attributes) == len(l)
print len(objectstore.get_session().dirty)
print len(objectstore.get_session().identity_map)
- objectstore.expunge(*l)
+ #objectstore.expunge(*l)
if __name__ == "__main__":
testbase.main()
diff --git a/test/objectstore.py b/test/objectstore.py
index 9ec6ca7e2..404f9ff94 100644
--- a/test/objectstore.py
+++ b/test/objectstore.py
@@ -3,12 +3,28 @@ import unittest, sys, os
from sqlalchemy import *
import StringIO
import testbase
-
+from sqlalchemy.orm.mapper import global_extensions
+from sqlalchemy.ext.sessioncontext import SessionContext
+import sqlalchemy.ext.assignmapper as assignmapper
from tables import *
import tables
-class HistoryTest(AssertMixin):
+class SessionTest(AssertMixin):
def setUpAll(self):
+ global ctx, assign_mapper
+ ctx = SessionContext(create_session)
+ def assign_mapper(*args, **kwargs):
+ return assignmapper.assign_mapper(ctx, *args, **kwargs)
+ global_extensions.append(ctx.mapper_extension)
+ def tearDownAll(self):
+ global_extensions.remove(ctx.mapper_extension)
+ def tearDown(self):
+ ctx.current.clear()
+ clear_mappers()
+
+class HistoryTest(SessionTest):
+ def setUpAll(self):
+ SessionTest.setUpAll(self)
db.echo = False
users.create()
addresses.create()
@@ -18,9 +34,7 @@ class HistoryTest(AssertMixin):
addresses.drop()
users.drop()
db.echo = testbase.echo
- def setUp(self):
- objectstore.clear()
- clear_mappers()
+ SessionTest.tearDownAll(self)
def testattr(self):
"""tests the rolling back of scalar and list attributes. this kind of thing
@@ -43,7 +57,7 @@ class HistoryTest(AssertMixin):
self.assert_result([u], data[0], *data[1:])
self.echo(repr(u.addresses))
- objectstore.uow().rollback_object(u)
+ ctx.current.uow.rollback_object(u)
data = [User,
{'user_name' : None,
'addresses' : (Address, [])
@@ -52,6 +66,7 @@ class HistoryTest(AssertMixin):
self.assert_result([u], data[0], *data[1:])
def testbackref(self):
+ s = create_session()
class User(object):pass
class Address(object):pass
am = mapper(Address, addresses)
@@ -59,177 +74,82 @@ class HistoryTest(AssertMixin):
addresses = relation(am, backref='user', lazy=False))
)
- u = User()
- a = Address()
+ u = User(_sa_session=s)
+ a = Address(_sa_session=s)
a.user = u
#print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items())
#print repr(u.addresses.added_items())
self.assert_(u.addresses == [a])
- objectstore.commit()
+ s.flush()
- objectstore.clear()
- u = m.select()[0]
+ s.clear()
+ u = s.query(m).select()[0]
print u.addresses[0].user
-class SessionTest(AssertMixin):
- def setUpAll(self):
- db.echo = False
- users.create()
- db.echo = testbase.echo
- def tearDownAll(self):
- db.echo = False
- users.drop()
- db.echo = testbase.echo
- def setUp(self):
- objectstore.get_session().clear()
- clear_mappers()
- tables.user_data()
- #db.echo = "debug"
- def tearDown(self):
- tables.delete_user_data()
-
- def test_nested_begin_commit(self):
- """tests that nesting objectstore transactions with multiple commits
- affects only the outermost transaction"""
- class User(object):pass
- m = mapper(User, users)
- def name_of(id):
- return users.select(users.c.user_id == id).execute().fetchone().user_name
- name1 = "Oliver Twist"
- name2 = 'Mr. Bumble'
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- s = objectstore.get_session()
- trans = s.begin()
- trans2 = s.begin()
- m.get(7).user_name = name1
- trans3 = s.begin()
- m.get(8).user_name = name2
- trans3.commit()
- s.commit() # should do nothing
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans2.commit()
- s.commit() # should do nothing
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans.commit()
- self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1)
- self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2)
-
- def test_nested_rollback(self):
- """tests that nesting objectstore transactions with a rollback inside
- affects only the outermost transaction"""
- class User(object):pass
- m = mapper(User, users)
- def name_of(id):
- return users.select(users.c.user_id == id).execute().fetchone().user_name
- name1 = "Oliver Twist"
- name2 = 'Mr. Bumble'
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- s = objectstore.get_session()
- trans = s.begin()
- trans2 = s.begin()
- m.get(7).user_name = name1
- trans3 = s.begin()
- m.get(8).user_name = name2
- trans3.rollback()
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans2.commit()
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
- trans.commit()
- self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
- self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
-
- @testbase.unsupported('sqlite')
- def test_true_nested(self):
- """tests creating a new Session inside a database transaction, in
- conjunction with an engine-level nested transaction, which uses
- a second connection in order to achieve a nested transaction that commits, inside
- of another engine session that rolls back."""
-# testbase.db.echo='debug'
- class User(object):
- pass
- testbase.db.begin()
- try:
- m = mapper(User, users)
- name1 = "Oliver Twist"
- name2 = 'Mr. Bumble'
- m.get(7).user_name = name1
- s = objectstore.Session(nest_on=testbase.db)
- m.using(s).get(8).user_name = name2
- s.commit()
- objectstore.commit()
- testbase.db.rollback()
- except:
- testbase.db.rollback()
- raise
- objectstore.clear()
- self.assert_(m.get(8).user_name == name2)
- self.assert_(m.get(7).user_name != name1)
-class VersioningTest(AssertMixin):
+class VersioningTest(SessionTest):
def setUpAll(self):
- objectstore.clear()
+ SessionTest.setUpAll(self)
+ ctx.current.clear()
global version_table
version_table = Table('version_test', db,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, Sequence('version_test_seq'), primary_key=True ),
Column('version_id', Integer, nullable=False),
Column('value', String(40), nullable=False)
).create()
def tearDownAll(self):
version_table.drop()
+ SessionTest.tearDownAll(self)
def tearDown(self):
version_table.delete().execute()
- objectstore.clear()
- clear_mappers()
+ SessionTest.tearDown(self)
@testbase.unsupported('mysql')
def testbasic(self):
+ s = create_session()
class Foo(object):pass
assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id)
- f1 =Foo(value='f1')
- f2 = Foo(value='f2')
- objectstore.commit()
+ f1 =Foo(value='f1', _sa_session=s)
+ f2 = Foo(value='f2', _sa_session=s)
+ s.flush()
f1.value='f1rev2'
- objectstore.commit()
- s = objectstore.Session()
- f1_s = Foo.mapper.using(s).get(f1.id)
+ s.flush()
+ s2 = create_session()
+ f1_s = Foo.mapper.using(s2).get(f1.id)
f1_s.value='f1rev3'
- s.commit()
+ s2.flush()
f1.value='f1rev3mine'
success = False
try:
# a concurrent session has modified this, should throw
# an exception
- objectstore.commit()
- except SQLAlchemyError:
+ s.flush()
+ except exceptions.SQLAlchemyError, e:
+ #print e
success = True
assert success
- objectstore.clear()
- f1 = Foo.mapper.get(f1.id)
- f2 = Foo.mapper.get(f2.id)
+ s.clear()
+ f1 = s.query(Foo).get(f1.id)
+ f2 = s.query(Foo).get(f2.id)
f1_s.value='f1rev4'
- s.commit()
+ s2.flush()
- objectstore.delete(f1, f2)
+ s.delete(f1, f2)
success = False
try:
- objectstore.commit()
- except SQLAlchemyError:
+ s.flush()
+ except exceptions.SQLAlchemyError, e:
+ #print e
success = True
assert success
-class UnicodeTest(AssertMixin):
+class UnicodeTest(SessionTest):
def setUpAll(self):
- objectstore.clear()
+ SessionTest.setUpAll(self)
global uni_table
uni_table = Table('uni_test', db,
Column('id', Integer, primary_key=True),
@@ -237,22 +157,25 @@ class UnicodeTest(AssertMixin):
def tearDownAll(self):
uni_table.drop()
- uni_table.deregister()
+ SessionTest.tearDownAll(self)
def testbasic(self):
class Test(object):
- pass
- assign_mapper(Test, uni_table)
+ def __init__(self, id, txt):
+ self.id = id
+ self.txt = txt
+ mapper(Test, uni_table)
txt = u"\u0160\u0110\u0106\u010c\u017d"
t1 = Test(id=1, txt = txt)
self.assert_(t1.txt == txt)
- objectstore.commit()
+ ctx.current.flush()
self.assert_(t1.txt == txt)
-class PKTest(AssertMixin):
+class PKTest(SessionTest):
def setUpAll(self):
+ SessionTest.setUpAll(self)
db.echo = False
global table
global table2
@@ -286,9 +209,7 @@ class PKTest(AssertMixin):
table2.drop()
table3.drop()
db.echo = testbase.echo
- def setUp(self):
- objectstore.clear()
- clear_mappers()
+ SessionTest.tearDownAll(self)
@testbase.unsupported('sqlite')
def testprimarykey(self):
@@ -299,9 +220,9 @@ class PKTest(AssertMixin):
e.name = 'entry1'
e.value = 'this is entry 1'
e.multi_rev = 2
- objectstore.commit()
- objectstore.clear()
- e2 = Entry.mapper.get(e.multi_id, 2)
+ ctx.current.flush()
+ ctx.current.clear()
+ e2 = Entry.mapper.get((e.multi_id, 2))
self.assert_(e is not e2 and e._instance_key == e2._instance_key)
def testmanualpk(self):
class Entry(object):
@@ -311,7 +232,7 @@ class PKTest(AssertMixin):
e.pk_col_1 = 'pk1'
e.pk_col_2 = 'pk1_related'
e.data = 'im the data'
- objectstore.commit()
+ ctx.current.flush()
def testkeypks(self):
import datetime
class Entity(object):
@@ -322,11 +243,12 @@ class PKTest(AssertMixin):
e.secondary = 'pk2'
e.assigned = datetime.date.today()
e.data = 'some more data'
- objectstore.commit()
+ ctx.current.flush()
-class PrivateAttrTest(AssertMixin):
+class PrivateAttrTest(SessionTest):
"""tests various things to do with private=True mappers"""
def setUpAll(self):
+ SessionTest.setUpAll(self)
global a_table, b_table
a_table = Table('a',testbase.db,
Column('a_id', Integer, Sequence('next_a_id'), primary_key=True),
@@ -340,9 +262,7 @@ class PrivateAttrTest(AssertMixin):
def tearDownAll(self):
b_table.drop()
a_table.drop()
- def setUp(self):
- objectstore.clear()
- clear_mappers()
+ SessionTest.tearDownAll(self)
def testsinglecommit(self):
"""tests that a commit of a single object deletes private relationships"""
@@ -350,8 +270,7 @@ class PrivateAttrTest(AssertMixin):
class B(object):pass
assign_mapper(B,b_table)
- assign_mapper(A,a_table,properties= {'bs' : relation
- (B.mapper,private=True)})
+ assign_mapper(A,a_table,properties= {'bs' : relation(B.mapper,private=True)})
# create some objects
a = A(data='a1')
@@ -366,10 +285,10 @@ class PrivateAttrTest(AssertMixin):
a.bs.append(b2)
# inserts both A and Bs
- objectstore.commit(a)
+ ctx.current.flush([a])
- objectstore.delete(a)
- objectstore.commit(a)
+ ctx.current.delete(a)
+ ctx.current.flush([a])
assert b_table.count().scalar() == 0
@@ -386,23 +305,24 @@ class PrivateAttrTest(AssertMixin):
a2 = A(data='testa2')
b = B(data='testb')
b.a = a1
- objectstore.commit()
- objectstore.clear()
- sess = objectstore.get_session()
+ ctx.current.flush()
+ ctx.current.clear()
+ sess = ctx.current
a1 = A.mapper.get(a1.a_id)
a2 = A.mapper.get(a2.a_id)
assert a1.bs[0].a is a1
b = a1.bs[0]
b.a = a2
assert b not in sess.deleted
- objectstore.commit()
+ ctx.current.flush()
assert b in sess.identity_map.values()
-class DefaultTest(AssertMixin):
+class DefaultTest(SessionTest):
"""tests that when saving objects whose table contains DefaultGenerators, either python-side, preexec or database-side,
the newly saved instances receive all the default values either through a post-fetch or getting the pre-exec'ed
defaults back from the engine."""
def setUpAll(self):
+ SessionTest.setUpAll(self)
#db.echo = 'debug'
use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
@@ -423,6 +343,7 @@ class DefaultTest(AssertMixin):
self.table.create()
def tearDownAll(self):
self.table.drop()
+ SessionTest.tearDownAll(self)
def setUp(self):
self.table = Table('default_test', db)
def testinsert(self):
@@ -433,7 +354,7 @@ class DefaultTest(AssertMixin):
h3 = Hoho(hoho=self.althohoval, counter=12)
h4 = Hoho()
h5 = Hoho(foober='im the new foober')
- objectstore.commit()
+ ctx.current.flush()
self.assert_(h1.hoho==self.althohoval)
self.assert_(h3.hoho==self.althohoval)
self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval)
@@ -441,7 +362,7 @@ class DefaultTest(AssertMixin):
self.assert_(h1.counter == h4.counter==h5.counter==7)
self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
self.assert_(h5.foober=='im the new foober')
- objectstore.clear()
+ ctx.current.clear()
l = Hoho.mapper.select()
(h1, h2, h3, h4, h5) = l
self.assert_(h1.hoho==self.althohoval)
@@ -457,7 +378,7 @@ class DefaultTest(AssertMixin):
class Hoho(object):pass
assign_mapper(Hoho, self.table)
h1 = Hoho(hoho="15", counter="15")
- objectstore.commit()
+ ctx.current.flush()
self.assert_(h1.hoho=="15")
self.assert_(h1.counter=="15")
self.assert_(h1.foober=="im foober")
@@ -466,30 +387,27 @@ class DefaultTest(AssertMixin):
class Hoho(object):pass
assign_mapper(Hoho, self.table)
h1 = Hoho()
- objectstore.commit()
+ ctx.current.flush()
self.assert_(h1.foober == 'im foober')
h1.counter = 19
- objectstore.commit()
+ ctx.current.flush()
self.assert_(h1.foober == 'im the update')
-class SaveTest(AssertMixin):
+class SaveTest(SessionTest):
def setUpAll(self):
+ SessionTest.setUpAll(self)
db.echo = False
tables.create()
db.echo = testbase.echo
def tearDownAll(self):
db.echo = False
- db.commit()
tables.drop()
db.echo = testbase.echo
+ SessionTest.tearDownAll(self)
def setUp(self):
db.echo = False
- # remove all history/identity maps etc.
- objectstore.clear()
- # remove all mapperes
- clear_mappers()
keywords.insert().execute(
dict(name='blue'),
dict(name='red'),
@@ -499,32 +417,29 @@ class SaveTest(AssertMixin):
dict(name='round'),
dict(name='square')
)
- db.commit()
db.echo = testbase.echo
def tearDown(self):
db.echo = False
- db.commit()
tables.delete()
db.echo = testbase.echo
- self.assert_(len(objectstore.uow().new) == 0)
- self.assert_(len(objectstore.uow().dirty) == 0)
- self.assert_(len(objectstore.uow().modified_lists) == 0)
-
+ #self.assert_(len(ctx.current.new) == 0)
+ #self.assert_(len(ctx.current.dirty) == 0)
+ SessionTest.tearDown(self)
+
def testbasic(self):
# save two users
u = User()
u.user_name = 'savetester'
-
m = mapper(User, users)
u2 = User()
u2.user_name = 'savetester2'
- objectstore.uow().register_new(u)
+ ctx.current.save(u)
- objectstore.uow().commit(u)
- objectstore.uow().commit()
+ ctx.current.flush([u])
+ ctx.current.flush()
# assert the first one retreives the same from the identity map
nu = m.get(u.user_id)
@@ -532,18 +447,20 @@ class SaveTest(AssertMixin):
self.assert_(u is nu)
# clear out the identity map, so next get forces a SELECT
- objectstore.clear()
+ ctx.current.clear()
# check it again, identity should be different but ids the same
nu = m.get(u.user_id)
self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester')
# change first users name and save
+ ctx.current.update(u)
u.user_name = 'modifiedname'
- objectstore.uow().commit()
+ assert u in ctx.current.dirty
+ ctx.current.flush()
# select both
- #objectstore.clear()
+ #ctx.current.clear()
userlist = m.select(users.c.user_id.in_(u.user_id, u2.user_id), order_by=[users.c.user_name])
print repr(u.user_id), repr(userlist[0].user_id), repr(userlist[0].user_name)
self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname')
@@ -561,12 +478,12 @@ class SaveTest(AssertMixin):
u.addresses.append(Address())
u.addresses.append(Address())
u.addresses.append(Address())
- objectstore.commit()
- objectstore.clear()
+ ctx.current.flush()
+ ctx.current.clear()
ulist = m1.select()
u1 = ulist[0]
u1.user_name = 'newname'
- objectstore.commit()
+ ctx.current.flush()
self.assert_(len(u1.addresses) == 4)
def testinherits(self):
@@ -583,8 +500,8 @@ class SaveTest(AssertMixin):
)
au = AddressUser()
- objectstore.commit()
- objectstore.clear()
+ ctx.current.flush()
+ ctx.current.clear()
l = AddressUser.mapper.selectone()
self.assert_(l.user_id == au.user_id and l.address_id == au.address_id)
@@ -592,9 +509,9 @@ class SaveTest(AssertMixin):
"""tests a save of an object where each instance spans two tables. also tests
redefinition of the keynames for the column properties."""
usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
- print usersaddresses._get_col_by_original(users.c.user_id)
+ print usersaddresses.corresponding_column(users.c.user_id)
print repr(usersaddresses._orig_cols)
- m = mapper(User, usersaddresses, primarytable = users,
+ m = mapper(User, usersaddresses,
properties = dict(
email = addresses.c.email_address,
foo_id = [users.c.user_id, addresses.c.user_id],
@@ -605,7 +522,7 @@ class SaveTest(AssertMixin):
u.user_name = 'multitester'
u.email = 'multi@test.org'
- objectstore.uow().commit()
+ ctx.current.flush()
usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall()
self.assertEqual(usertable[0].values(), [u.foo_id, 'multitester'])
@@ -614,7 +531,7 @@ class SaveTest(AssertMixin):
u.email = 'lala@hey.com'
u.user_name = 'imnew'
- objectstore.uow().commit()
+ ctx.current.flush()
usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall()
self.assertEqual(usertable[0].values(), [u.foo_id, 'imnew'])
@@ -632,12 +549,34 @@ class SaveTest(AssertMixin):
u.user_name = 'one2onetester'
u.address = Address()
u.address.email_address = 'myonlyaddress@foo.com'
- objectstore.uow().commit()
+ ctx.current.flush()
u.user_name = 'imnew'
- objectstore.uow().commit()
+ ctx.current.flush()
u.address.email_address = 'imnew@foo.com'
- objectstore.uow().commit()
+ ctx.current.flush()
+ def testchildmove(self):
+ """tests moving a child from one parent to the other, then deleting the first parent, properly
+ updates the child with the new parent. this tests the 'trackparent' option in the attributes module."""
+ m = mapper(User, users, properties = dict(
+ addresses = relation(mapper(Address, addresses), lazy = True, private = False)
+ ))
+ u1 = User()
+ u1.user_name = 'user1'
+ u2 = User()
+ u2.user_name = 'user2'
+ a = Address()
+ a.email_address = 'address1'
+ u1.addresses.append(a)
+ ctx.current.flush()
+ del u1.addresses[0]
+ u2.addresses.append(a)
+ ctx.current.delete(u1)
+ ctx.current.flush()
+ ctx.current.clear()
+ u2 = m.get(u2.user_id)
+ assert len(u2.addresses) == 1
+
def testdelete(self):
m = mapper(User, users, properties = dict(
address = relation(mapper(Address, addresses), lazy = True, uselist = False, private = False)
@@ -647,89 +586,11 @@ class SaveTest(AssertMixin):
u.user_name = 'one2onetester'
u.address = a
u.address.email_address = 'myonlyaddress@foo.com'
- objectstore.uow().commit()
+ ctx.current.flush()
self.echo("\n\n\n")
- objectstore.uow().register_deleted(u)
- objectstore.uow().commit()
- self.assert_(a.address_id is not None and a.user_id is None and not objectstore.uow().identity_map.has_key(u._instance_key) and objectstore.uow().identity_map.has_key(a._instance_key))
-
- def testcascadingdelete(self):
- m = mapper(User, users, properties = dict(
- address = relation(mapper(Address, addresses), lazy = False, uselist = False, private = True),
- orders = relation(
- mapper(Order, orders, properties = dict (
- items = relation(mapper(Item, orderitems), lazy = False, uselist =True, private = True)
- )),
- lazy = True, uselist = True, private = True)
- ))
-
- data = [User,
- {'user_name' : 'ed',
- 'address' : (Address, {'email_address' : 'foo@bar.com'}),
- 'orders' : (Order, [
- {'description' : 'eds 1st order', 'items' : (Item, [{'item_name' : 'eds o1 item'}, {'item_name' : 'eds other o1 item'}])},
- {'description' : 'eds 2nd order', 'items' : (Item, [{'item_name' : 'eds o2 item'}, {'item_name' : 'eds other o2 item'}])}
- ])
- },
- {'user_name' : 'jack',
- 'address' : (Address, {'email_address' : 'jack@jack.com'}),
- 'orders' : (Order, [
- {'description' : 'jacks 1st order', 'items' : (Item, [{'item_name' : 'im a lumberjack'}, {'item_name' : 'and im ok'}])}
- ])
- },
- {'user_name' : 'foo',
- 'address' : (Address, {'email_address': 'hi@lala.com'}),
- 'orders' : (Order, [
- {'description' : 'foo order', 'items' : (Item, [])},
- {'description' : 'foo order 2', 'items' : (Item, [{'item_name' : 'hi'}])},
- {'description' : 'foo order three', 'items' : (Item, [{'item_name' : 'there'}])}
- ])
- }
- ]
-
- for elem in data[1:]:
- u = User()
- u.user_name = elem['user_name']
- u.address = Address()
- u.address.email_address = elem['address'][1]['email_address']
- u.orders = []
- for order in elem['orders'][1]:
- o = Order()
- o.isopen = None
- o.description = order['description']
- u.orders.append(o)
- o.items = []
- for item in order['items'][1]:
- i = Item()
- i.item_name = item['item_name']
- o.items.append(i)
-
- objectstore.uow().commit()
- objectstore.clear()
-
- l = m.select()
- for u in l:
- self.echo( repr(u.orders))
- self.assert_result(l, data[0], *data[1:])
-
- self.echo("\n\n\n")
- objectstore.uow().register_deleted(l[0])
- objectstore.uow().register_deleted(l[2])
- objectstore.commit()
- return
- res = self.capture_exec(db, lambda: objectstore.uow().commit())
- state = None
-
- for line in res.split('\n'):
- if line == "DELETE FROM items WHERE items.item_id = :item_id":
- self.assert_(state is None or state == 'addresses')
- elif line == "DELETE FROM orders WHERE orders.order_id = :order_id":
- state = 'orders'
- elif line == "DELETE FROM email_addresses WHERE email_addresses.address_id = :address_id":
- if state is None:
- state = 'addresses'
- elif line == "DELETE FROM users WHERE users.user_id = :user_id":
- self.assert_(state is not None)
+ ctx.current.delete(u)
+ ctx.current.flush()
+ self.assert_(a.address_id is not None and a.user_id is None and not ctx.current.identity_map.has_key(u._instance_key) and ctx.current.identity_map.has_key(a._instance_key))
def testbackwardsonetoone(self):
# test 'backwards'
@@ -755,37 +616,37 @@ class SaveTest(AssertMixin):
a.user.user_name = elem['user_name']
objects.append(a)
- objectstore.uow().commit()
+ ctx.current.flush()
objects[2].email_address = 'imnew@foo.bar'
objects[3].user = User()
objects[3].user.user_name = 'imnewlyadded'
- self.assert_sql(db, lambda: objectstore.uow().commit(), [
+ self.assert_sql(db, lambda: ctx.current.flush(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'imnewlyadded'}
),
{
"UPDATE email_addresses SET email_address=:email_address WHERE email_addresses.address_id = :email_addresses_address_id":
- lambda: [{'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id}]
+ lambda ctx: {'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id}
,
"UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id":
- lambda: [{'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}]
+ lambda ctx: {'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}
},
],
with_sequences=[
(
"INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)",
- lambda:{'user_name': 'imnewlyadded', 'user_id':db.last_inserted_ids()[0]}
+ lambda ctx:{'user_name': 'imnewlyadded', 'user_id':ctx.last_inserted_ids()[0]}
),
{
"UPDATE email_addresses SET email_address=:email_address WHERE email_addresses.address_id = :email_addresses_address_id":
- lambda: [{'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id}]
+ lambda ctx: {'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id}
,
"UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id":
- lambda: [{'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}]
+ lambda ctx: {'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}
},
])
@@ -810,7 +671,7 @@ class SaveTest(AssertMixin):
u.addresses.append(a2)
self.echo( repr(u.addresses))
self.echo( repr(u.addresses.added_items()))
- objectstore.uow().commit()
+ ctx.current.flush()
usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall()
self.assertEqual(usertable[0].values(), [u.user_id, 'one2manytester'])
@@ -823,7 +684,7 @@ class SaveTest(AssertMixin):
a2.email_address = 'somethingnew@foo.com'
- objectstore.uow().commit()
+ ctx.current.flush()
addresstable = addresses.select(addresses.c.address_id == addressid).execute().fetchall()
@@ -837,7 +698,6 @@ class SaveTest(AssertMixin):
dict(user_id = 8, user_name = 'ed'),
dict(user_id = 9, user_name = 'fred')
)
- db.commit()
# mapper with just users table
assign_mapper(User, users)
@@ -853,7 +713,7 @@ class SaveTest(AssertMixin):
u[0].addresses[0].email_address='hi'
# insure that upon commit, the new mapper with the address relation is used
- self.assert_sql(db, lambda: objectstore.commit(),
+ self.assert_sql(db, lambda: ctx.current.flush(),
[
(
"INSERT INTO email_addresses (user_id, email_address) VALUES (:user_id, :email_address)",
@@ -863,7 +723,7 @@ class SaveTest(AssertMixin):
with_sequences=[
(
"INSERT INTO email_addresses (address_id, user_id, email_address) VALUES (:address_id, :user_id, :email_address)",
- lambda:{'email_address': 'hi', 'user_id': 7, 'address_id':db.last_inserted_ids()[0]}
+ lambda ctx:{'email_address': 'hi', 'user_id': 7, 'address_id':ctx.last_inserted_ids()[0]}
),
]
)
@@ -890,7 +750,7 @@ class SaveTest(AssertMixin):
a3 = Address()
a3.email_address = 'emailaddress3'
- objectstore.commit()
+ ctx.current.flush()
self.echo("\n\n\n")
# modify user2 directly, append an address to user1.
@@ -899,18 +759,18 @@ class SaveTest(AssertMixin):
u2.user_name = 'user2modified'
u1.addresses.append(a3)
del u1.addresses[0]
- self.assert_sql(db, lambda: objectstore.commit(),
+ self.assert_sql(db, lambda: ctx.current.flush(),
[
(
"UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id",
- [{'users_user_id': u2.user_id, 'user_name': 'user2modified'}]
+ {'users_user_id': u2.user_id, 'user_name': 'user2modified'}
),
(
"UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id",
- [{'user_id': u1.user_id, 'email_addresses_address_id': a3.address_id}]
+ {'user_id': u1.user_id, 'email_addresses_address_id': a3.address_id}
),
("UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id",
- [{'user_id': None, 'email_addresses_address_id': a1.address_id}]
+ {'user_id': None, 'email_addresses_address_id': a1.address_id}
)
])
@@ -924,12 +784,12 @@ class SaveTest(AssertMixin):
u1.user_name='user1'
a1.user = u1
- objectstore.commit()
+ ctx.current.flush()
self.echo("\n\n\n")
- objectstore.delete(u1)
+ ctx.current.delete(u1)
a1.user = None
- objectstore.commit()
+ ctx.current.flush()
def _testalias(self):
"""tests that an alias of a table can be used in a mapper.
@@ -971,12 +831,13 @@ class SaveTest(AssertMixin):
def testmanytomany(self):
items = orderitems
+ keywordmapper = mapper(Keyword, keywords)
+
items.select().execute()
m = mapper(Item, items, properties = dict(
- keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = False),
+ keywords = relation(keywordmapper, itemkeywords, lazy = False),
))
- keywordmapper = mapper(Keyword, keywords)
data = [Item,
{'item_name': 'mm_item1', 'keywords' : (Keyword,[{'name': 'big'},{'name': 'green'}, {'name': 'purple'},{'name': 'round'}])},
@@ -1007,7 +868,7 @@ class SaveTest(AssertMixin):
k.name = kname
item.keywords.append(k)
- objectstore.uow().commit()
+ ctx.current.flush()
l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name])
self.assert_result(l, *data)
@@ -1016,48 +877,48 @@ class SaveTest(AssertMixin):
k = Keyword()
k.name = 'yellow'
objects[5].keywords.append(k)
- self.assert_sql(db, lambda:objectstore.commit(), [
+ self.assert_sql(db, lambda:ctx.current.flush(), [
{
"UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
- [{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}]
+ {'item_name': 'item4updated', 'items_item_id': objects[4].item_id}
,
"INSERT INTO keywords (name) VALUES (:name)":
{'name': 'yellow'}
},
("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)",
- lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}]
+ lambda ctx: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}]
)
],
with_sequences = [
{
"UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
- [{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}]
+ {'item_name': 'item4updated', 'items_item_id': objects[4].item_id}
,
"INSERT INTO keywords (keyword_id, name) VALUES (:keyword_id, :name)":
- lambda: {'name': 'yellow', 'keyword_id':db.last_inserted_ids()[0]}
+ lambda ctx: {'name': 'yellow', 'keyword_id':ctx.last_inserted_ids()[0]}
},
("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)",
- lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}]
+ lambda ctx: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}]
)
]
)
objects[2].keywords.append(k)
dkid = objects[5].keywords[1].keyword_id
del objects[5].keywords[1]
- self.assert_sql(db, lambda:objectstore.commit(), [
+ self.assert_sql(db, lambda:ctx.current.flush(), [
(
"DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id",
[{'item_id': objects[5].item_id, 'keyword_id': dkid}]
),
(
"INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)",
- lambda: [{'item_id': objects[2].item_id, 'keyword_id': k.keyword_id}]
+ lambda ctx: [{'item_id': objects[2].item_id, 'keyword_id': k.keyword_id}]
)
])
- objectstore.delete(objects[3])
- objectstore.commit()
+ ctx.current.delete(objects[3])
+ ctx.current.flush()
def testassociation(self):
class IKAssociation(object):
@@ -1072,7 +933,7 @@ class SaveTest(AssertMixin):
# the reorganization of mapper construction affected this, but was fixed again
m = mapper(Item, items, properties = dict(
keywords = relation(mapper(IKAssociation, itemkeywords, properties = dict(
- keyword = relation(mapper(Keyword, keywords), lazy = False, uselist = False)
+ keyword = relation(mapper(Keyword, keywords, non_primary=True), lazy = False, uselist = False)
), primary_key = [itemkeywords.c.item_id, itemkeywords.c.keyword_id]),
lazy = False)
))
@@ -1117,8 +978,8 @@ class SaveTest(AssertMixin):
ik.keyword = k
item.keywords.append(ik)
- objectstore.uow().commit()
- objectstore.clear()
+ ctx.current.flush()
+ ctx.current.clear()
l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name])
self.assert_result(l, *data)
@@ -1136,7 +997,7 @@ class SaveTest(AssertMixin):
a = Address()
a.email_address = 'testaddress'
a.user = u
- objectstore.commit()
+ ctx.current.flush()
print repr(u.addresses)
x = False
try:
@@ -1148,8 +1009,8 @@ class SaveTest(AssertMixin):
if x:
self.assert_(False, "User addresses element should be scalar based")
- objectstore.delete(u)
- objectstore.commit()
+ ctx.current.delete(u)
+ ctx.current.flush()
def testdoublerelation(self):
m2 = mapper(Address, addresses)
@@ -1169,13 +1030,13 @@ class SaveTest(AssertMixin):
u.boston_addresses.append(a)
u.newyork_addresses.append(b)
- objectstore.commit()
+ ctx.current.flush()
-class SaveTest2(AssertMixin):
+class SaveTest2(SessionTest):
def setUp(self):
db.echo = False
- objectstore.clear()
+ ctx.current.clear()
clear_mappers()
self.users = Table('users', db,
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
@@ -1201,6 +1062,7 @@ class SaveTest2(AssertMixin):
self.addresses.drop()
self.users.drop()
db.echo = testbase.echo
+ SessionTest.tearDown(self)
def testbackwardsnonmatch(self):
m = mapper(Address, self.addresses, properties = dict(
@@ -1217,7 +1079,7 @@ class SaveTest2(AssertMixin):
a.user = User()
a.user.user_name = elem['user_name']
objects.append(a)
- self.assert_sql(db, lambda: objectstore.commit(), [
+ self.assert_sql(db, lambda: ctx.current.flush(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'thesub'}
@@ -1239,19 +1101,19 @@ class SaveTest2(AssertMixin):
with_sequences = [
(
"INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)",
- lambda: {'user_name': 'thesub', 'user_id':db.last_inserted_ids()[0]}
+ lambda ctx: {'user_name': 'thesub', 'user_id':ctx.last_inserted_ids()[0]}
),
(
"INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)",
- lambda: {'user_name': 'assdkfj', 'user_id':db.last_inserted_ids()[0]}
+ lambda ctx: {'user_name': 'assdkfj', 'user_id':ctx.last_inserted_ids()[0]}
),
(
"INSERT INTO email_addresses (address_id, rel_user_id, email_address) VALUES (:address_id, :rel_user_id, :email_address)",
- lambda:{'rel_user_id': 1, 'email_address': 'bar@foo.com', 'address_id':db.last_inserted_ids()[0]}
+ lambda ctx:{'rel_user_id': 1, 'email_address': 'bar@foo.com', 'address_id':ctx.last_inserted_ids()[0]}
),
(
"INSERT INTO email_addresses (address_id, rel_user_id, email_address) VALUES (:address_id, :rel_user_id, :email_address)",
- lambda:{'rel_user_id': 2, 'email_address': 'thesdf@asdf.com', 'address_id':db.last_inserted_ids()[0]}
+ lambda ctx:{'rel_user_id': 2, 'email_address': 'thesdf@asdf.com', 'address_id':ctx.last_inserted_ids()[0]}
)
]
)
diff --git a/test/onetoone.py b/test/onetoone.py
index 9ff330c92..5dc5b1204 100644
--- a/test/onetoone.py
+++ b/test/onetoone.py
@@ -1,5 +1,6 @@
from sqlalchemy import *
import testbase
+from sqlalchemy.ext.sessioncontext import SessionContext
class Jack(object):
def __repr__(self):
@@ -23,8 +24,10 @@ class Port(object):
class O2OTest(testbase.AssertMixin):
def setUpAll(self):
- global jack, port
- jack = Table('jack', testbase.db,
+ global jack, port, metadata, ctx
+ metadata = BoundMetaData(testbase.db)
+ ctx = SessionContext(create_session)
+ jack = Table('jack', metadata,
Column('id', Integer, primary_key=True),
#Column('room_id', Integer, ForeignKey("room.id")),
Column('number', String(50)),
@@ -33,54 +36,54 @@ class O2OTest(testbase.AssertMixin):
)
- port = Table('port', testbase.db,
+ port = Table('port', metadata,
Column('id', Integer, primary_key=True),
#Column('device_id', Integer, ForeignKey("device.id")),
Column('name', String(30)),
Column('description', String(100)),
Column('jack_id', Integer, ForeignKey("jack.id")),
)
- jack.create()
- port.create()
+ metadata.create_all()
def setUp(self):
- objectstore.clear()
+ pass
def tearDown(self):
clear_mappers()
def tearDownAll(self):
- port.drop()
- jack.drop()
+ metadata.drop_all()
def test1(self):
- assign_mapper(Port, port)
- assign_mapper(Jack, jack, order_by=[jack.c.number],properties = {
- 'port': relation(Port.mapper, backref='jack', uselist=False, lazy=True),
- })
+ mapper(Port, port, extension=ctx.mapper_extension)
+ mapper(Jack, jack, order_by=[jack.c.number],properties = {
+ 'port': relation(Port, backref='jack', uselist=False, lazy=True),
+ }, extension=ctx.mapper_extension)
j=Jack(number='101')
p=Port(name='fa0/1')
j.port=p
- objectstore.commit()
+ ctx.current.flush()
jid = j.id
pid = p.id
- j=Jack.get(jid)
- p=Port.get(pid)
+ j=ctx.current.query(Jack).get(jid)
+ p=ctx.current.query(Port).get(pid)
print p.jack
- print j.port
+ assert p.jack is not None
+ assert p.jack is j
+ assert j.port is not None
p.jack=None
assert j.port is None #works
- objectstore.clear()
+ ctx.current.clear()
- j=Jack.get(jid)
- p=Port.get(pid)
+ j=ctx.current.query(Jack).get(jid)
+ p=ctx.current.query(Port).get(pid)
j.port=None
self.assert_(p.jack is None)
- objectstore.commit()
+ ctx.current.flush()
- j.delete()
- objectstore.commit()
+ ctx.current.delete(j)
+ ctx.current.flush()
if __name__ == "__main__":
testbase.main()
diff --git a/test/parseconnect.py b/test/parseconnect.py
new file mode 100644
index 000000000..e1f50e8c9
--- /dev/null
+++ b/test/parseconnect.py
@@ -0,0 +1,29 @@
+from testbase import PersistTest
+import sqlalchemy.engine.url as url
+import unittest
+
+class ParseConnectTest(PersistTest):
+ def testrfc1738(self):
+ for text in (
+ 'dbtype://username:password@hostspec:110//usr/db_file.db',
+ 'dbtype://username:password@hostspec/database',
+ 'dbtype://username:password@hostspec',
+ 'dbtype://username:password@/database',
+ 'dbtype://username@hostspec',
+ 'dbtype://username:password@127.0.0.1:1521',
+ 'dbtype://hostspec/database',
+ 'dbtype://hostspec',
+ 'dbtype:///database',
+ 'dbtype:///:memory:',
+ 'dbtype:///foo/bar/im/a/file',
+ 'dbtype:///E:/work/src/LEM/db/hello.db',
+ 'dbtype://'
+ ):
+ u = url.make_url(text)
+ print u, text
+ assert str(u) == text
+
+
+if __name__ == "__main__":
+ unittest.main()
+ \ No newline at end of file
diff --git a/test/polymorph.py b/test/polymorph.py
new file mode 100644
index 000000000..ec7e95a0a
--- /dev/null
+++ b/test/polymorph.py
@@ -0,0 +1,169 @@
+import testbase
+from sqlalchemy import *
+import sets
+
+# test classes
+class Person(object):
+ def __init__(self, **kwargs):
+ for key, value in kwargs.iteritems():
+ setattr(self, key, value)
+ def get_name(self):
+ try:
+ return getattr(self, 'person_name')
+ except AttributeError:
+ return getattr(self, 'name')
+ def __repr__(self):
+ return "Ordinary person %s" % self.get_name()
+class Engineer(Person):
+ def __repr__(self):
+ return "Engineer %s, status %s, engineer_name %s, primary_language %s" % (self.get_name(), self.status, self.engineer_name, self.primary_language)
+class Manager(Person):
+ def __repr__(self):
+ return "Manager %s, status %s, manager_name %s" % (self.get_name(), self.status, self.manager_name)
+class Company(object):
+ def __init__(self, **kwargs):
+ for key, value in kwargs.iteritems():
+ setattr(self, key, value)
+ def __repr__(self):
+ return "Company %s" % self.name
+
+class MultipleTableTest(testbase.PersistTest):
+ def setUpAll(self, use_person_column=False):
+ global companies, people, engineers, managers, metadata
+ metadata = BoundMetaData(testbase.db)
+
+ # a table to store companies
+ companies = Table('companies', metadata,
+ Column('company_id', Integer, primary_key=True),
+ Column('name', String(50)))
+
+ # we will define an inheritance relationship between the table "people" and "engineers",
+ # and a second inheritance relationship between the table "people" and "managers"
+ people = Table('people', metadata,
+ Column('person_id', Integer, primary_key=True),
+ Column('company_id', Integer, ForeignKey('companies.company_id')),
+ Column('name', String(50)),
+ Column('type', String(30)))
+
+ engineers = Table('engineers', metadata,
+ Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+ Column('status', String(30)),
+ Column('engineer_name', String(50)),
+ Column('primary_language', String(50)),
+ )
+
+ managers = Table('managers', metadata,
+ Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+ Column('status', String(30)),
+ Column('manager_name', String(50))
+ )
+
+ metadata.create_all()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ def tearDown(self):
+ clear_mappers()
+ for t in metadata.table_iterator(reverse=True):
+ t.delete().execute()
+
+ def test_f_f_f(self):
+ self.do_test(False, False, False)
+ def test_f_f_t(self):
+ self.do_test(False, False, True)
+ def test_f_t_f(self):
+ self.do_test(False, True, False)
+ def test_f_t_t(self):
+ self.do_test(False, True, True)
+ def test_t_f_f(self):
+ self.do_test(True, False, False)
+ def test_t_f_t(self):
+ self.do_test(True, False, True)
+ def test_t_t_f(self):
+ self.do_test(True, True, False)
+ def test_t_t_t(self):
+ self.do_test(True, True, True)
+
+
+ def do_test(self, include_base=False, lazy_relation=True, redefine_colprop=False):
+ """tests the polymorph.py example, with several options:
+
+ include_base - whether or not to include the base 'person' type in the union.
+ lazy_relation - whether or not the Company relation to People is lazy or eager.
+ redefine_colprop - if we redefine the 'name' column to be 'people_name' on the base Person class
+ """
+ # create a union that represents both types of joins.
+ if include_base:
+ person_join = polymorphic_union(
+ {
+ 'engineer':people.join(engineers),
+ 'manager':people.join(managers),
+ 'person':people.select(people.c.type=='person'),
+ }, None, 'pjoin')
+ else:
+ person_join = polymorphic_union(
+ {
+ 'engineer':people.join(engineers),
+ 'manager':people.join(managers),
+ }, None, 'pjoin')
+
+ if redefine_colprop:
+ person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
+ else:
+ person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person')
+
+ mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
+ mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
+
+ mapper(Company, companies, properties={
+ 'employees': relation(Person, lazy=lazy_relation, private=True, backref='company')
+ })
+
+ if redefine_colprop:
+ person_attribute_name = 'person_name'
+ else:
+ person_attribute_name = 'name'
+
+ session = create_session()
+ c = Company(name='company1')
+ c.employees.append(Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'}))
+ c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'}))
+ if include_base:
+ c.employees.append(Person(status='HHH', **{person_attribute_name:'joesmith'}))
+ c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'}))
+ c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'}))
+ session.save(c)
+ print session.new
+ session.flush()
+ session.clear()
+ id = c.company_id
+ c = session.query(Company).get(id)
+ for e in c.employees:
+ print e, e._instance_key, e.company
+ if include_base:
+ assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith'])
+ else:
+ assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'wally', 'jsmith'])
+ print "\n"
+
+
+ dilbert = session.query(Person).selectfirst(person_join.c.name=='dilbert')
+ dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert')
+ assert dilbert is dilbert2
+
+ dilbert.engineer_name = 'hes dibert!'
+
+ session.flush()
+ session.clear()
+
+ c = session.query(Company).get(id)
+ for e in c.employees:
+ print e, e._instance_key
+
+ session.delete(c)
+ session.flush()
+
+if __name__ == "__main__":
+ testbase.main()
+
diff --git a/test/pool.py b/test/pool.py
index 2737a33b1..d8c984aa8 100644
--- a/test/pool.py
+++ b/test/pool.py
@@ -1,5 +1,5 @@
from testbase import PersistTest
-import unittest, sys, os
+import unittest, sys, os, time
from pysqlite2 import dbapi2 as sqlite
import sqlalchemy.pool as pool
@@ -40,7 +40,14 @@ class PoolTest(PersistTest):
self.assert_(connection.cursor() is not None)
self.assert_(connection is not connection2)
- def testqueuepool(self):
+ def testqueuepool_del(self):
+ self._do_testqueuepool(useclose=False)
+
+ def testqueuepool_close(self):
+ self._do_testqueuepool(useclose=True)
+
+ def _do_testqueuepool(self, useclose=False):
+
p = pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = False, echo = False)
def status(pool):
@@ -60,30 +67,73 @@ class PoolTest(PersistTest):
self.assert_(status(p) == (3,0,2,5))
c6 = p.connect()
self.assert_(status(p) == (3,0,3,6))
- c4 = c3 = c2 = None
+ if useclose:
+ c4.close()
+ c3.close()
+ c2.close()
+ else:
+ c4 = c3 = c2 = None
self.assert_(status(p) == (3,3,3,3))
- c1 = c5 = c6 = None
+ if useclose:
+ c1.close()
+ c5.close()
+ c6.close()
+ else:
+ c1 = c5 = c6 = None
self.assert_(status(p) == (3,3,0,0))
c1 = p.connect()
c2 = p.connect()
self.assert_(status(p) == (3, 1, 0, 2))
- c2 = None
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
self.assert_(status(p) == (3, 2, 0, 1))
- def testthreadlocal(self):
+ def test_timeout(self):
+ p = pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = 0, use_threadlocal = False, echo = False, timeout=2)
+ c1 = p.get()
+ c2 = p.get()
+ c3 = p.get()
+ now = time.time()
+ c4 = p.get()
+ assert int(time.time() - now) == 2
+
+ def testthreadlocal_del(self):
+ self._do_testthreadlocal(useclose=False)
+
+ def testthreadlocal_close(self):
+ self._do_testthreadlocal(useclose=True)
+
+ def _do_testthreadlocal(self, useclose=False):
for p in (
pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True, echo = False),
pool.SingletonThreadPool(creator = lambda: sqlite.connect('foo.db'), use_threadlocal = True)
- ):
+ ):
c1 = p.connect()
c2 = p.connect()
self.assert_(c1 is c2)
c3 = p.unique_connection()
self.assert_(c3 is not c1)
- c2 = None
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
c2 = p.connect()
self.assert_(c1 is c2)
self.assert_(c3 is not c1)
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+
+ if useclose:
+ c1 = p.connect()
+ c2 = p.connect()
+ c3 = p.connect()
+ c3.close()
+ c2.close()
+ self.assert_(c1.connection is not None)
def tearDown(self):
pool.clear_managers()
diff --git a/test/proxy_engine.py b/test/proxy_engine.py
index 2a2cebc5b..df0c64398 100644
--- a/test/proxy_engine.py
+++ b/test/proxy_engine.py
@@ -1,9 +1,10 @@
+import os
+
from sqlalchemy import *
from sqlalchemy.ext.proxy import ProxyEngine
from testbase import PersistTest
import testbase
-import os
#
# Define an engine, table and mapper at the module level, to show that the
@@ -11,20 +12,31 @@ import os
#
-module_engine = ProxyEngine(echo=testbase.echo)
-users = Table('users', module_engine,
- Column('user_id', Integer, primary_key=True),
- Column('user_name', String(16)),
- Column('password', String(20))
- )
+class ProxyTestBase(PersistTest):
+ def setUpAll(self):
+
+ global users, User, module_engine, module_metadata
+
+ module_engine = ProxyEngine(echo=testbase.echo)
+ module_metadata = MetaData()
-class User(object):
- pass
+ users = Table('users', module_metadata,
+ Column('user_id', Integer, primary_key=True),
+ Column('user_name', String(16)),
+ Column('password', String(20))
+ )
+ class User(object):
+ pass
-class ConstructTest(PersistTest):
+ User.mapper = mapper(User, users)
+ def tearDownAll(self):
+ clear_mappers()
+
+class ConstructTest(ProxyTestBase):
"""tests that we can build SQL constructs without engine-specific parameters, particulary
oid_column, being needed, as the proxy engine is usually not connected yet."""
+
def test_join(self):
engine = ProxyEngine()
t = Table('table1', engine,
@@ -33,47 +45,46 @@ class ConstructTest(PersistTest):
Column('col2', Integer, ForeignKey('table1.col1')))
j = join(t, t2)
-class ProxyEngineTest1(PersistTest):
- def setUp(self):
- clear_mappers()
- objectstore.clear()
-
+class ProxyEngineTest1(ProxyTestBase):
+
def test_engine_connect(self):
# connect to a real engine
module_engine.connect(testbase.db_uri)
- users.create()
- assign_mapper(User, users)
+ module_metadata.create_all(module_engine)
+
+ session = create_session(bind_to=module_engine)
try:
- trans = objectstore.begin()
user = User()
user.user_name='fred'
user.password='*'
- trans.commit()
+
+ session.save(user)
+ session.flush()
+
+ query = session.query(User)
# select
- sqluser = User.select_by(user_name='fred')[0]
+ sqluser = query.select_by(user_name='fred')[0]
assert sqluser.user_name == 'fred'
# modify
sqluser.user_name = 'fred jones'
- # commit - saves everything that changed
- objectstore.commit()
+ # flush - saves everything that changed
+ session.flush()
- allusers = [ user.user_name for user in User.select() ]
- assert allusers == [ 'fred jones' ]
+ allusers = [ user.user_name for user in query.select() ]
+ assert allusers == ['fred jones']
+
finally:
- users.drop()
+ module_metadata.drop_all(module_engine)
+
+
+class ThreadProxyTest(ProxyTestBase):
-class ThreadProxyTest(PersistTest):
- def setUp(self):
- assign_mapper(User, users)
- def tearDown(self):
- clear_mappers()
def tearDownAll(self):
- pass
os.remove('threadtesta.db')
os.remove('threadtestb.db')
@@ -92,23 +103,26 @@ class ThreadProxyTest(PersistTest):
try:
module_engine.connect(db_uri)
- users.create()
+ module_metadata.create_all(module_engine)
try:
- trans = objectstore.begin()
+ session = create_session(bind_to=module_engine)
+
+ query = session.query(User)
- all = User.select()[:]
+ all = list(query.select())
assert all == []
u = User()
u.user_name = uname
u.password = 'whatever'
- trans.commit()
- names = [ us.user_name for us in User.select() ]
- assert names == [ uname ]
+ session.save(u)
+ session.flush()
+
+ names = [u.user_name for u in query.select()]
+ assert names == [uname]
finally:
- users.drop()
- module_engine.dispose()
+ module_metadata.drop_all(module_engine)
except Exception, e:
import traceback
traceback.print_exc()
@@ -119,8 +133,8 @@ class ThreadProxyTest(PersistTest):
# NOTE: I'm not sure how to give the test runner the option to
# override these uris, or how to safely clear them after test runs
- a = Thread(target=run('sqlite://filename=threadtesta.db', 'jim', qa))
- b = Thread(target=run('sqlite://filename=threadtestb.db', 'joe', qb))
+ a = Thread(target=run('sqlite:///threadtesta.db', 'jim', qa))
+ b = Thread(target=run('sqlite:///threadtestb.db', 'joe', qb))
a.start()
b.start()
@@ -134,11 +148,8 @@ class ThreadProxyTest(PersistTest):
if res != False:
raise res
-class ProxyEngineTest2(PersistTest):
- def setUp(self):
- clear_mappers()
- objectstore.clear()
+class ProxyEngineTest2(ProxyTestBase):
def test_table_singleton_a(self):
"""set up for table singleton check
@@ -153,8 +164,9 @@ class ProxyEngineTest2(PersistTest):
Column('cat_name', String))
engine.connect(testbase.db_uri)
- cats.create()
- cats.drop()
+
+ cats.create(engine)
+ cats.drop(engine)
ProxyEngineTest2.cats_table_a = cats
assert isinstance(cats, Table)
@@ -179,141 +191,8 @@ class ProxyEngineTest2(PersistTest):
# this will fail because the old reference's local storage will
# not have the default attributes
engine.connect(testbase.db_uri)
- cats.create()
- cats.drop()
-
- def test_type_engine_caching(self):
- from sqlalchemy.engine import SQLEngine
- import sqlalchemy.types as sqltypes
-
- class EngineA(SQLEngine):
- def __init__(self):
- pass
-
- def hash_key(self):
- return 'a'
-
- def type_descriptor(self, typeobj):
- if isinstance(typeobj, types.Integer):
- return TypeEngineX2()
- else:
- return TypeEngineSTR()
-
- class EngineB(SQLEngine):
- def __init__(self):
- pass
-
- def hash_key(self):
- return 'b'
-
- def type_descriptor(self, typeobj):
- return TypeEngineMonkey()
-
- class TypeEngineX2(sqltypes.TypeEngine):
- def convert_bind_param(self, value, engine):
- return value * 2
-
- class TypeEngineSTR(sqltypes.TypeEngine):
- def convert_bind_param(self, value, engine):
- return repr(str(value))
-
- class TypeEngineMonkey(sqltypes.TypeEngine):
- def convert_bind_param(self, value, engine):
- return 'monkey'
-
- engine = ProxyEngine()
- engine.storage.engine = EngineA()
-
- a = sqltypes.Integer().engine_impl(engine)
- assert a.convert_bind_param(12, engine) == 24
- assert a.convert_bind_param([1,2,3], engine) == [1, 2, 3, 1, 2, 3]
-
- a2 = sqltypes.String().engine_impl(engine)
- assert a2.convert_bind_param(12, engine) == "'12'"
- assert a2.convert_bind_param([1,2,3], engine) == "'[1, 2, 3]'"
-
- engine.storage.engine = EngineB()
- b = sqltypes.Integer().engine_impl(engine)
- assert b.convert_bind_param(12, engine) == 'monkey'
- assert b.convert_bind_param([1,2,3], engine) == 'monkey'
-
-
- def test_type_engine_autoincrement(self):
- engine = ProxyEngine()
- dogs = Table('dogs', engine,
- Column('dog_id', Integer, primary_key=True),
- Column('breed', String),
- Column('name', String))
-
- class Dog(object):
- pass
-
- assign_mapper(Dog, dogs)
-
- engine.connect(testbase.db_uri)
- dogs.create()
- try:
- spot = Dog()
- spot.breed = 'beagle'
- spot.name = 'Spot'
+ cats.create(engine)
+ cats.drop(engine)
- rover = Dog()
- rover.breed = 'spaniel'
- rover.name = 'Rover'
-
- objectstore.commit()
-
- assert spot.dog_id > 0, "Spot did not get an id"
- assert rover.dog_id != spot.dog_id
- finally:
- dogs.drop()
-
- def test_type_proxy_schema_gen(self):
- from sqlalchemy.databases.postgres import PGSchemaGenerator
-
- engine = ProxyEngine()
- lizards = Table('lizards', engine,
- Column('id', Integer, primary_key=True),
- Column('name', String))
-
- # this doesn't really CONNECT to pg, just establishes pg as the
- # actual engine so that we can determine that it gets the right
- # answer
- engine.connect('postgres://database=test&port=5432&host=127.0.0.1&user=scott&password=tiger')
-
- sg = PGSchemaGenerator(engine)
- id_spec = sg.get_column_specification(lizards.c.id)
- assert id_spec == 'id SERIAL NOT NULL PRIMARY KEY'
-
-
if __name__ == "__main__":
testbase.main()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/test/query.py b/test/query.py
index 946180618..0c5164133 100644
--- a/test/query.py
+++ b/test/query.py
@@ -6,7 +6,6 @@ import sqlalchemy.databases.sqlite as sqllite
import tables
db = testbase.db
-#db.echo='debug'
from sqlalchemy import *
from sqlalchemy.engine import ResultProxy, RowProxy
@@ -103,8 +102,6 @@ class QueryTest(PersistTest):
def testdelete(self):
- c = db.connection()
-
self.users.insert().execute(user_id = 7, user_name = 'jack')
self.users.insert().execute(user_id = 8, user_name = 'fred')
print repr(self.users.select().execute().fetchall())
@@ -113,14 +110,6 @@ class QueryTest(PersistTest):
print repr(self.users.select().execute().fetchall())
- def testtransaction(self):
- def dostuff():
- self.users.insert().execute(user_id = 7, user_name = 'john')
- self.users.insert().execute(user_id = 8, user_name = 'jack')
-
- db.transaction(dostuff)
- print repr(self.users.select().execute().fetchall())
-
def testselectlimit(self):
self.users.insert().execute(user_id=1, user_name='john')
self.users.insert().execute(user_id=2, user_name='jack')
@@ -158,10 +147,13 @@ class QueryTest(PersistTest):
self.users.insert().execute(user_id=1, user_name='foo')
r = self.users.select().execute().fetchone()
self.assertEqual(len(r), 2)
+ r.close()
r = db.execute('select user_name, user_id from query_users', {}).fetchone()
self.assertEqual(len(r), 2)
+ r.close()
r = db.execute('select user_name from query_users', {}).fetchone()
self.assertEqual(len(r), 1)
+ r.close()
def test_column_order_with_simple_query(self):
# should return values in column definition order
@@ -180,11 +172,9 @@ class QueryTest(PersistTest):
self.assertEqual(r[1], 1)
self.assertEqual(r.keys(), ['user_name', 'user_id'])
self.assertEqual(r.values(), ['foo', 1])
-
+
+ @testbase.unsupported('oracle')
def test_column_accessor_shadow(self):
- if db.engine.__module__.endswith('oracle'):
- return
-
shadowed = Table('test_shadowed', db,
Column('shadow_id', INT, primary_key = True),
Column('shadow_name', VARCHAR(20)),
@@ -209,6 +199,7 @@ class QueryTest(PersistTest):
self.fail('Should not allow access to private attributes')
except AttributeError:
pass # expected
+ r.close()
finally:
shadowed.drop()
diff --git a/test/reflection.py b/test/reflection.py
index 718957add..20a5fd90c 100644
--- a/test/reflection.py
+++ b/test/reflection.py
@@ -1,8 +1,6 @@
import sqlalchemy.ansisql as ansisql
import sqlalchemy.databases.postgres as postgres
-import sqlalchemy.databases.oracle as oracle
-import sqlalchemy.databases.sqlite as sqllite
from sqlalchemy import *
@@ -14,7 +12,7 @@ class ReflectionTest(PersistTest):
def testbasic(self):
# really trip it up with a circular reference
- use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle')
+ use_function_defaults = testbase.db.engine.name == 'postgres' or testbase.db.engine.name == 'oracle'
use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite')
@@ -123,9 +121,10 @@ class ReflectionTest(PersistTest):
table.drop()
def testtoengine(self):
- db = ansisql.engine()
+ meta = MetaData('md1')
+ meta2 = MetaData('md2')
- table = Table('mytable', db,
+ table = Table('mytable', meta,
Column('myid', Integer, key = 'id'),
Column('name', String, key = 'name', nullable=False),
Column('description', String, key = 'description'),
@@ -133,14 +132,14 @@ class ReflectionTest(PersistTest):
print repr(table)
- pgdb = postgres.engine({})
+ table2 = table.tometadata(meta2)
- pgtable = table.toengine(pgdb)
+ print repr(table2)
- print repr(pgtable)
- assert pgtable.c.id.nullable
- assert not pgtable.c.name.nullable
- assert pgtable.c.description.nullable
+ assert table is not table2
+ assert table2.c.id.nullable
+ assert not table2.c.name.nullable
+ assert table2.c.description.nullable
def testoverride(self):
table = Table(
@@ -165,6 +164,58 @@ class ReflectionTest(PersistTest):
self.assert_(isinstance(table.c.col4.type, String))
finally:
table.drop()
+
+class CreateDropTest(PersistTest):
+ def setUpAll(self):
+ global metadata
+ metadata = MetaData()
+ users = Table('users', metadata,
+ Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+ Column('user_name', String(40)),
+ )
+
+ addresses = Table('email_addresses', metadata,
+ Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
+ Column('user_id', Integer, ForeignKey(users.c.user_id)),
+ Column('email_address', String(40)),
+
+ )
+
+ orders = Table('orders', metadata,
+ Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True),
+ Column('user_id', Integer, ForeignKey(users.c.user_id)),
+ Column('description', String(50)),
+ Column('isopen', Integer),
+
+ )
+
+ orderitems = Table('items', metadata,
+ Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
+ Column('order_id', INT, ForeignKey("orders")),
+ Column('item_name', VARCHAR(50)),
+
+ )
+
+ def test_sorter( self ):
+ tables = metadata._sort_tables(metadata.tables.values())
+ table_names = [t.name for t in tables]
+ self.assert_( table_names == ['users', 'orders', 'items', 'email_addresses'] or table_names == ['users', 'email_addresses', 'orders', 'items'])
+
+
+ def test_createdrop(self):
+ metadata.create_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), True )
+ self.assertEqual( testbase.db.has_table('email_addresses'), True )
+ metadata.create_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), True )
+
+ metadata.drop_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), False )
+ self.assertEqual( testbase.db.has_table('email_addresses'), False )
+ metadata.drop_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), False )
+
+
if __name__ == "__main__":
testbase.main()
diff --git a/test/relationships.py b/test/relationships.py
index 84a45c2dd..d9a9d6e50 100644
--- a/test/relationships.py
+++ b/test/relationships.py
@@ -22,33 +22,34 @@ class RelationTest(testbase.PersistTest):
global tbl_b
global tbl_c
global tbl_d
- tbl_a = Table("tbl_a", db,
+ metadata = MetaData()
+ tbl_a = Table("tbl_a", metadata,
Column("id", Integer, primary_key=True),
Column("name", String),
)
- tbl_b = Table("tbl_b", db,
+ tbl_b = Table("tbl_b", metadata,
Column("id", Integer, primary_key=True),
Column("name", String),
)
- tbl_c = Table("tbl_c", db,
+ tbl_c = Table("tbl_c", metadata,
Column("id", Integer, primary_key=True),
Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False),
Column("name", String),
)
- tbl_d = Table("tbl_d", db,
+ tbl_d = Table("tbl_d", metadata,
Column("id", Integer, primary_key=True),
Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False),
Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")),
Column("name", String),
)
def setUp(self):
- tbl_a.create()
- tbl_b.create()
- tbl_c.create()
- tbl_d.create()
-
- objectstore.clear()
- clear_mappers()
+ global session
+ session = create_session(bind_to=testbase.db)
+ conn = session.connect()
+ conn.create(tbl_a)
+ conn.create(tbl_b)
+ conn.create(tbl_c)
+ conn.create(tbl_d)
class A(object):
pass
@@ -75,30 +76,31 @@ class RelationTest(testbase.PersistTest):
b = B(); b.name = "b1"
c = C(); c.name = "c1"; c.a_row = a
# we must have more than one d row or it won't fail
- d = D(); d.name = "d1"; d.b_row = b; d.c_row = c
- d = D(); d.name = "d2"; d.b_row = b; d.c_row = c
- d = D(); d.name = "d3"; d.b_row = b; d.c_row = c
-
+ d1 = D(); d1.name = "d1"; d1.b_row = b; d1.c_row = c
+ d2 = D(); d2.name = "d2"; d2.b_row = b; d2.c_row = c
+ d3 = D(); d3.name = "d3"; d3.b_row = b; d3.c_row = c
+ session.save_or_update(a)
+ session.save_or_update(b)
+
def tearDown(self):
- tbl_d.drop()
- tbl_c.drop()
- tbl_b.drop()
- tbl_a.drop()
+ conn = session.connect()
+ conn.drop(tbl_d)
+ conn.drop(tbl_c)
+ conn.drop(tbl_b)
+ conn.drop(tbl_a)
def tearDownAll(self):
- testbase.db.tables.clear()
+ testbase.metadata.tables.clear()
def testDeleteRootTable(self):
- session = objectstore.get_session()
- session.commit()
+ session.flush()
session.delete(a) # works as expected
- session.commit()
-
+ session.flush()
+
def testDeleteMiddleTable(self):
- session = objectstore.get_session()
- session.commit()
+ session.flush()
session.delete(c) # fails
- session.commit()
+ session.flush()
if __name__ == "__main__":
diff --git a/test/select.py b/test/select.py
index fb136cfec..0fc3ca60f 100644
--- a/test/select.py
+++ b/test/select.py
@@ -1,14 +1,6 @@
from sqlalchemy import *
-import sqlalchemy.ansisql as ansisql
-import sqlalchemy.databases.postgres as postgres
-import sqlalchemy.databases.oracle as oracle
-import sqlalchemy.databases.sqlite as sqlite
-import sqlalchemy.databases.mysql as mysql
-
-db = ansisql.engine()
-#db = create_engine('mssql')
-
+from sqlalchemy.databases import sqlite, postgres, mysql, oracle
from testbase import PersistTest
import unittest, re
@@ -34,8 +26,9 @@ table3 = table(
column('otherstuff'),
)
+metadata = MetaData()
table4 = Table(
- 'remotetable', db,
+ 'remotetable', metadata,
Column('rem_id', Integer, primary_key=True),
Column('datatype_id', Integer),
Column('value', String(20)),
@@ -58,8 +51,8 @@ addresses = table('addresses',
)
class SQLTest(PersistTest):
- def runtest(self, clause, result, engine = None, params = None, checkparams = None):
- c = clause.compile(parameters=params, engine=engine)
+ def runtest(self, clause, result, dialect = None, params = None, checkparams = None):
+ c = clause.compile(parameters=params, dialect=dialect)
self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
cc = re.sub(r'\n', '', str(c))
self.assert_(cc == result, str(c) + "\n does not match \n" + result)
@@ -80,7 +73,6 @@ myothertable.othername FROM mytable, myothertable")
"""tests placing select statements in the column clause of another select, for the
purposes of selecting from the exported columns of that select."""
s = select([table1], table1.c.name == 'jack')
- #print [key for key in s.c.keys()]
self.runtest(
select(
[s],
@@ -151,7 +143,6 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
self.runtest(
select([users, s.c.street], from_obj=[s]),
"""SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
-
def testcolumnsubquery(self):
s = select([table1.c.myid], scalar=True, correlate=False)
@@ -213,7 +204,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
)
self.runtest(
- literal("a") + literal("b") * literal("c"), ":literal + (:liter_1 * :liter_2)", db
+ literal("a") + literal("b") * literal("c"), ":literal + (:liter_1 * :liter_2)"
)
def testmultiparam(self):
@@ -234,9 +225,9 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
)
def testoraclelimit(self):
- e = create_engine('oracle')
- users = Table('users', e, Column('name', String(10), key='username'))
- self.runtest(select([users.c.username], limit=5), "SELECT name FROM (SELECT users.name AS name, ROW_NUMBER() OVER (ORDER BY users.rowid ASC) AS ora_rn FROM users) WHERE ora_rn<=5", engine=e)
+ metadata = MetaData()
+ users = Table('users', metadata, Column('name', String(10), key='username'))
+ self.runtest(select([users.c.username], limit=5), "SELECT name FROM (SELECT users.name AS name, ROW_NUMBER() OVER (ORDER BY users.rowid) AS ora_rn FROM users) WHERE ora_rn<=5", dialect=oracle.dialect())
def testgroupby_and_orderby(self):
self.runtest(
@@ -276,15 +267,13 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
def testtext(self):
self.runtest(
text("select * from foo where lala = bar") ,
- "select * from foo where lala = bar",
- engine = db
+ "select * from foo where lala = bar"
)
self.runtest(select(
["foobar(a)", "pk_foo_bar(syslaal)"],
"a = 12",
- from_obj = ["foobar left outer join lala on foobar.foo = lala.foo"],
- engine = db
+ from_obj = ["foobar left outer join lala on foobar.foo = lala.foo"]
),
"SELECT foobar(a), pk_foo_bar(syslaal) FROM foobar left outer join lala on foobar.foo = lala.foo WHERE a = 12")
@@ -296,33 +285,32 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
s.append_whereclause("column2=19")
s.order_by("column1")
s.append_from("table1")
- self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1", db)
+ self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1")
def testtextbinds(self):
self.runtest(
- db.text("select * from foo where lala=:bar and hoho=:whee"),
+ text("select * from foo where lala=:bar and hoho=:whee"),
"select * from foo where lala=:bar and hoho=:whee",
checkparams={'bar':4, 'whee': 7},
params={'bar':4, 'whee': 7, 'hoho':10},
- engine=db
)
- engine = postgres.engine({})
+ dialect = postgres.dialect()
self.runtest(
- engine.text("select * from foo where lala=:bar and hoho=:whee"),
+ text("select * from foo where lala=:bar and hoho=:whee"),
"select * from foo where lala=%(bar)s and hoho=%(whee)s",
checkparams={'bar':4, 'whee': 7},
params={'bar':4, 'whee': 7, 'hoho':10},
- engine=engine
+ dialect=dialect
)
- engine = sqlite.engine({})
+ dialect = sqlite.dialect()
self.runtest(
- engine.text("select * from foo where lala=:bar and hoho=:whee"),
+ text("select * from foo where lala=:bar and hoho=:whee"),
"select * from foo where lala=? and hoho=?",
checkparams=[4, 7],
params={'bar':4, 'whee': 7, 'hoho':10},
- engine=engine
+ dialect=dialect
)
def testtextmix(self):
@@ -393,7 +381,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
"SELECT foo.bar.lala(:lala)")
# test a dotted func off the engine itself
- self.runtest(db.func.lala.hoho(7), "lala.hoho(:hoho)")
+ self.runtest(func.lala.hoho(7), "lala.hoho(:hoho)")
def testjoin(self):
self.runtest(
@@ -461,6 +449,13 @@ FROM mytable WHERE mytable.myid = :mytable_my_1 ORDER BY mytable.myid")
FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \
FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable")
+ u = union(
+ select([table1]),
+ select([table2]),
+ select([table3])
+ )
+ assert u.corresponding_column(table2.c.otherid) is u.c.otherid
+
def testouterjoin(self):
# test an outer join. the oracle module should take the ON clause of the join and
@@ -485,19 +480,19 @@ FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid
WHERE mytable.name = %(mytable_name)s AND mytable.myid = %(mytable_myid)s AND \
myothertable.othername != %(myothertable_othername)s AND \
EXISTS (select yay from foo where boo = lar)",
- engine = postgres.engine({}))
-
+ dialect=postgres.dialect()
+ )
self.runtest(query,
"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \
mytable.name = :mytable_name AND mytable.myid = :mytable_myid AND \
myothertable.othername != :myothertable_othername AND EXISTS (select yay from foo where boo = lar)",
- engine = oracle.engine({}, use_ansi = False))
+ dialect=oracle.OracleDialect(use_ansi = False))
query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON thirdtable.userid = myothertable.otherid")
- self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", engine=oracle.engine({}, use_ansi=False))
+ self.runtest(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
def testbindparam(self):
self.runtest(select(
@@ -513,7 +508,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
# check that the bind params sent along with a compile() call
# get preserved when the params are retreived later
s = select([table1], table1.c.myid == bindparam('test'))
- c = s.compile(parameters = {'test' : 7}, engine=db)
+ c = s.compile(parameters = {'test' : 7})
self.assert_(c.get_params() == {'test' : 7})
@@ -542,28 +537,26 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
Column('ts', TIMESTAMP),
)
- def check_results(engine, expected_results, literal):
+ def check_results(dialect, expected_results, literal):
self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results')
- self.assertEqual(str(cast(tbl.c.v1, Numeric, engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
- self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9), engine=engine)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
- self.assertEqual(str(cast(tbl.c.ts, Date, engine=engine)), 'CAST(casttest.ts AS %s)' %expected_results[2])
- self.assertEqual(str(cast(1234, TEXT, engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
- self.assertEqual(str(cast('test', String(20), engine=engine)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
-
- sel = select([tbl, cast(tbl.c.v1, Numeric)], engine=engine)
- self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest")
-
+ self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
+ self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
+ self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
+ self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
+ self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
+ sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect)
+ self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest")
# first test with Postgres engine
- check_results(postgres.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s')
+ check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s')
# then the Oracle engine
- check_results(oracle.engine({}, use_ansi = False), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal')
+# check_results(oracle.OracleDialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':literal')
# then the sqlite engine
- check_results(sqlite.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
+ check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
# and the MySQL engine
- check_results(mysql.engine({}), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s')
+ check_results(mysql.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%s')
class CRUDTest(SQLTest):
def testinsert(self):
@@ -601,8 +594,7 @@ class CRUDTest(SQLTest):
self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'})
s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'})
- c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}, engine=db)
- print str(c)
+ c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'})
self.assert_(str(s) == str(c))
def testupdateexpression(self):
@@ -623,7 +615,7 @@ class CRUDTest(SQLTest):
s = select([table2], table2.c.otherid == table1.c.myid)
u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s})
self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name")
-
+
# test a correlated WHERE clause
s = select([table2.c.othername], table2.c.otherid == 7)
u = update(table1, table1.c.name==s)
diff --git a/test/selectable.py b/test/selectable.py
index c1a3f28e5..59f57331b 100755
--- a/test/selectable.py
+++ b/test/selectable.py
@@ -15,6 +15,7 @@ table = Table('table1', db,
Column('col1', Integer, primary_key=True),
Column('col2', String(20)),
Column('col3', Integer),
+ Column('colx', Integer),
redefine=True
)
@@ -22,16 +23,10 @@ table2 = Table('table2', db,
Column('col1', Integer, primary_key=True),
Column('col2', Integer, ForeignKey('table1.col1')),
Column('col3', String(20)),
+ Column('coly', Integer),
redefine=True
)
-table3 = Table('table3', db,
- Column('col1', Integer, ForeignKey('table1.col1'), primary_key=True),
- Column('col2', Integer),
- Column('col3', String(20)),
- redefine=True
- )
-
class SelectableTest(testbase.AssertMixin):
def testtablealias(self):
a = table.alias('a')
@@ -43,16 +38,51 @@ class SelectableTest(testbase.AssertMixin):
print str(j)
self.assert_(criterion.compare(j.onclause))
- def testjoinpks(self):
- a = join(table, table3)
- b = join(table, table3, table.c.col1==table3.c.col2)
- c = join(table, table3, table.c.col2==table3.c.col2)
- d = join(table, table3, table.c.col2==table3.c.col1)
-
- self.assert_(a.primary_key==[table.c.col1])
- self.assert_(b.primary_key==[table.c.col1, table3.c.col1])
- self.assert_(c.primary_key==[table.c.col1, table3.c.col1])
- self.assert_(d.primary_key==[table.c.col1, table3.c.col1])
+ def testunion(self):
+ # tests that we can correspond a column in a Select statement with a certain Table, against
+ # a column in a Union where one of its underlying Selects matches to that same Table
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ )
+ s1 = table.select(use_labels=True)
+ s2 = table2.select(use_labels=True)
+ print ["%d %s" % (id(c),c.key) for c in u.c]
+ c = u.corresponding_column(s1.c.table1_col2)
+ print "%d %s" % (id(c), c.key)
+ assert u.corresponding_column(s1.c.table1_col2) is u.c.col2
+ assert u.corresponding_column(s2.c.table2_col2) is u.c.col2
+
+ def testaliasunion(self):
+ # same as testunion, except its an alias of the union
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ ).alias('analias')
+ s1 = table.select(use_labels=True)
+ s2 = table2.select(use_labels=True)
+ assert u.corresponding_column(s1.c.table1_col2) is u.c.col2
+ assert u.corresponding_column(s2.c.table2_col2) is u.c.col2
+ assert u.corresponding_column(s2.c.table2_coly) is u.c.coly
+ assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly
+
+ def testselectunion(self):
+ # like testaliasunion, but off a Select off the union.
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ ).alias('analias')
+ s = select([u])
+ s1 = table.select(use_labels=True)
+ s2 = table2.select(use_labels=True)
+ assert s.corresponding_column(s1.c.table1_col2) is s.c.col2
+ assert s.corresponding_column(s2.c.table2_col2) is s.c.col2
+
+ def testunionagainstjoin(self):
+ # same as testunion, except its an alias of the union
+ u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
+ select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
+ ).alias('analias')
+ j1 = table.join(table2)
+ assert u.corresponding_column(j1.c.table1_colx) is u.c.colx
+ assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx
def testjoin(self):
a = join(table, table2)
diff --git a/test/selectresults.py b/test/selectresults.py
index d04918683..6bd619f3a 100644
--- a/test/selectresults.py
+++ b/test/selectresults.py
@@ -3,31 +3,40 @@ import testbase
from sqlalchemy import *
-from sqlalchemy.mods.selectresults import SelectResultsExt
+from sqlalchemy.ext.selectresults import SelectResultsExt
class Foo(object):
pass
class SelectResultsTest(PersistTest):
def setUpAll(self):
+ self.install_threadlocal()
global foo
foo = Table('foo', testbase.db,
Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
- Column('bar', Integer))
+ Column('bar', Integer),
+ Column('range', Integer))
assign_mapper(Foo, foo, extension=SelectResultsExt())
foo.create()
for i in range(100):
- Foo(bar=i)
- objectstore.commit()
+ Foo(bar=i, range=i%10)
+ objectstore.flush()
def setUp(self):
- self.orig = Foo.mapper.select_whereclause()
- self.res = Foo.select()
+ self.query = Foo.mapper.query()
+ self.orig = self.query.select_whereclause()
+ self.res = self.query.select()
def tearDownAll(self):
global foo
foo.drop()
+ self.uninstall_threadlocal()
+
+ def test_selectby(self):
+ res = self.query.select_by(range=5)
+ assert res.order_by([Foo.c.bar])[0].bar == 5
+ assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
def test_slice(self):
assert self.res[1] == self.orig[1]
diff --git a/test/session.py b/test/session.py
new file mode 100644
index 000000000..9ed7f0f7a
--- /dev/null
+++ b/test/session.py
@@ -0,0 +1,7 @@
+
+# test merging a composed object.
+
+# test that when cascading an operation, like "merge", lazy-loaded scalar and list attributes that werent already loaded on the given object remain not loaded.
+
+# test putting an object in session A, "moving" it to session B, insure its in B and not in A
+
diff --git a/test/sessioncontext.py b/test/sessioncontext.py
new file mode 100644
index 000000000..83bc2f2bf
--- /dev/null
+++ b/test/sessioncontext.py
@@ -0,0 +1,47 @@
+from testbase import PersistTest, AssertMixin
+import unittest, sys, os
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.session import object_session, Session
+from sqlalchemy import *
+import testbase
+
+metadata = MetaData()
+users = Table('users', metadata,
+ Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+ Column('user_name', String(40)),
+ mysql_engine='innodb'
+)
+
+class SessionContextTest(AssertMixin):
+ def setUp(self):
+ clear_mappers()
+
+ def do_test(self, class_, context):
+ """test session assignment on object creation"""
+ obj = class_()
+ assert context.current == object_session(obj)
+
+ # keep a reference so the old session doesn't get gc'd
+ old_session = context.current
+
+ context.current = Session()
+ assert context.current != object_session(obj)
+ assert old_session == object_session(obj)
+
+ new_session = context.current
+ del context.current
+ assert context.current != new_session
+ assert old_session == object_session(obj)
+
+ obj2 = class_()
+ assert context.current == object_session(obj2)
+
+ def test_mapper_extension(self):
+ context = SessionContext(Session)
+ class User(object): pass
+ User.mapper = mapper(User, users, extension=context.mapper_extension)
+ self.do_test(User, context)
+
+
+if __name__ == "__main__":
+ testbase.main()
diff --git a/test/tables.py b/test/tables.py
index f1e1a845b..2bfc75868 100644
--- a/test/tables.py
+++ b/test/tables.py
@@ -9,22 +9,22 @@ __all__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'item
ECHO = testbase.echo
db = testbase.db
+metadata = BoundMetaData(db)
-
-users = Table('users', db,
+users = Table('users', metadata,
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
Column('user_name', String(40)),
mysql_engine='innodb'
)
-addresses = Table('email_addresses', db,
+addresses = Table('email_addresses', metadata,
Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
Column('user_id', Integer, ForeignKey(users.c.user_id)),
Column('email_address', String(40)),
)
-orders = Table('orders', db,
+orders = Table('orders', metadata,
Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True),
Column('user_id', Integer, ForeignKey(users.c.user_id)),
Column('description', String(50)),
@@ -32,51 +32,32 @@ orders = Table('orders', db,
)
-orderitems = Table('items', db,
+orderitems = Table('items', metadata,
Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
Column('order_id', INT, ForeignKey("orders")),
Column('item_name', VARCHAR(50)),
)
-keywords = Table('keywords', db,
+keywords = Table('keywords', metadata,
Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True),
Column('name', VARCHAR(50)),
)
-itemkeywords = Table('itemkeywords', db,
+itemkeywords = Table('itemkeywords', metadata,
Column('item_id', INT, ForeignKey("items")),
Column('keyword_id', INT, ForeignKey("keywords")),
)
def create():
- users.create()
- addresses.create()
- orders.create()
- orderitems.create()
- keywords.create()
- itemkeywords.create()
-
+ metadata.create_all()
def drop():
- itemkeywords.drop()
- keywords.drop()
- orderitems.drop()
- orders.drop()
- addresses.drop()
- users.drop()
- db.commit()
-
+ metadata.drop_all()
def delete():
- itemkeywords.delete().execute()
- keywords.delete().execute()
- orderitems.delete().execute()
- orders.delete().execute()
- addresses.delete().execute()
- users.delete().execute()
- db.commit()
-
+ for t in metadata.table_iterator(reverse=True):
+ t.delete().execute()
def user_data():
users.insert().execute(
dict(user_id = 7, user_name = 'jack'),
@@ -85,7 +66,6 @@ def user_data():
)
def delete_user_data():
users.delete().execute()
- db.commit()
def data():
delete()
@@ -144,8 +124,6 @@ def data():
dict(keyword_id=7, item_id=2),
dict(keyword_id=6, item_id=3)
)
-
- db.commit()
class User(object):
def __init__(self):
diff --git a/test/testbase.py b/test/testbase.py
index 8ef63ef27..04972779d 100644
--- a/test/testbase.py
+++ b/test/testbase.py
@@ -2,63 +2,96 @@ import unittest
import StringIO
import sqlalchemy.engine as engine
import sqlalchemy.ext.proxy as proxy
-import sqlalchemy.schema as schema
+import sqlalchemy.pool as pool
+#import sqlalchemy.schema as schema
import re, sys
+import sqlalchemy
+import optparse
+
-echo = True
-#echo = False
-#echo = 'debug'
db = None
+metadata = None
db_uri = None
+echo = True
+
+# redefine sys.stdout so all those print statements go to the echo func
+local_stdout = sys.stdout
+class Logger(object):
+ def write(self, msg):
+ if echo:
+ local_stdout.write(msg)
+sys.stdout = Logger()
+
+def echo_text(text):
+ print text
def parse_argv():
# we are using the unittest main runner, so we are just popping out the
# arguments we need instead of using our own getopt type of thing
- global db, db_uri
+ global db, db_uri, metadata
DBTYPE = 'sqlite'
PROXY = False
+
+
+ parser = optparse.OptionParser(usage = "usage: %prog [options] files...")
+ parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)")
+ parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (sqlite, sqlite_file, postgres, mysql, oracle, oracle8, mssql)")
+ parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool")
+ parser.add_option("--verbose", action="store_true", dest="verbose", help="full debug echoing")
+ parser.add_option("--quiet", action="store_true", dest="quiet", help="be totally quiet")
+ parser.add_option("--nothreadlocal", action="store_true", dest="nothreadlocal", help="dont use thread-local mod")
+ parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to SA default)")
+
+ (options, args) = parser.parse_args()
+ sys.argv[1:] = args
- if len(sys.argv) >= 3:
- if sys.argv[1] == '--dburi':
- (param, db_uri) = (sys.argv.pop(1), sys.argv.pop(1))
- elif sys.argv[1] == '--db':
- (param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1))
+ if options.dburi:
+ db_uri = param = options.dburi
+ elif options.db:
+ DBTYPE = param = options.db
+
opts = {}
if (None == db_uri):
- p = DBTYPE.split('.')
- if len(p) > 1:
- arg = p[0]
- DBTYPE = p[1]
- if arg == 'proxy':
- PROXY = True
if DBTYPE == 'sqlite':
- db_uri = 'sqlite://filename=:memory:'
+ db_uri = 'sqlite:///:memory:'
elif DBTYPE == 'sqlite_file':
- db_uri = 'sqlite://filename=querytest.db'
+ db_uri = 'sqlite:///querytest.db'
elif DBTYPE == 'postgres':
- db_uri = 'postgres://database=test&port=5432&host=127.0.0.1&user=scott&password=tiger'
+ db_uri = 'postgres://scott:tiger@127.0.0.1:5432/test'
elif DBTYPE == 'mysql':
- db_uri = 'mysql://database=test&host=127.0.0.1&user=scott&password=tiger'
+ db_uri = 'mysql://scott:tiger@127.0.0.1/test'
elif DBTYPE == 'oracle':
- db_uri = 'oracle://user=scott&password=tiger'
+ db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
elif DBTYPE == 'oracle8':
- db_uri = 'oracle://user=scott&password=tiger'
+ db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
opts = {'use_ansi':False}
elif DBTYPE == 'mssql':
- db_uri = 'mssql://database=test&user=scott&password=tiger'
+ db_uri = 'mssql://scott:tiger@/test'
if not db_uri:
raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql> to test runner."
- if PROXY:
- db = proxy.ProxyEngine(echo=echo, default_ordering=True, **opts)
- db.connect(db_uri)
+ if not options.nothreadlocal:
+ __import__('sqlalchemy.mods.threadlocal')
+ sqlalchemy.mods.threadlocal.uninstall_plugin()
+
+ global echo
+ echo = options.verbose and not options.quiet
+
+ global quiet
+ quiet = options.quiet
+
+ if options.enginestrategy is not None:
+ opts['strategy'] = options.enginestrategy
+ if options.mockpool:
+ db = engine.create_engine(db_uri, echo=True, default_ordering=True, poolclass=MockPool, **opts)
else:
- db = engine.create_engine(db_uri, echo=echo, default_ordering=True, **opts)
+ db = engine.create_engine(db_uri, echo=True, default_ordering=True, **opts)
db = EngineAssert(db)
-
+ metadata = sqlalchemy.BoundMetaData(db)
+
def unsupported(*dbs):
"""a decorator that marks a test as unsupported by one or more database implementations"""
def decorate(func):
@@ -87,21 +120,49 @@ def supported(*dbs):
return lala
return decorate
-def echo_text(text):
- print text
class PersistTest(unittest.TestCase):
"""persist base class, provides default setUpAll, tearDownAll and echo functionality"""
def __init__(self, *args, **params):
unittest.TestCase.__init__(self, *args, **params)
def echo(self, text):
- if echo:
- echo_text(text)
+ echo_text(text)
+ def install_threadlocal(self):
+ sqlalchemy.mods.threadlocal.install_plugin()
+ def uninstall_threadlocal(self):
+ sqlalchemy.mods.threadlocal.uninstall_plugin()
def setUpAll(self):
pass
def tearDownAll(self):
pass
+ def shortDescription(self):
+ """overridden to not return docstrings"""
+ return None
+
+class MockPool(pool.Pool):
+ """this pool is hardcore about only one connection being used at a time."""
+ def __init__(self, creator, **params):
+ pool.Pool.__init__(self, **params)
+ self.connection = creator()
+ self._conn = self.connection
+
+ def status(self):
+ return "MockPool"
+
+ def do_return_conn(self, conn):
+ assert conn is self._conn and self.connection is None
+ self.connection = conn
+
+ def do_return_invalid(self):
+ raise "Invalid"
+ def do_get(self):
+ if getattr(self, 'breakpoint', False):
+ raise "breakpoint"
+ assert self.connection is not None
+ c = self.connection
+ self.connection = None
+ return c
class AssertMixin(PersistTest):
"""given a list-based structure of keys/properties which represent information within an object structure, and
@@ -145,8 +206,10 @@ class EngineAssert(proxy.BaseProxyEngine):
"""decorates a SQLEngine object to match the incoming queries against a set of assertions."""
def __init__(self, engine):
self._engine = engine
- self.realexec = engine.post_exec
- self.realexec.im_self.post_exec = self.post_exec
+
+ self.real_execution_context = engine.dialect.create_execution_context
+ engine.dialect.create_execution_context = self.execution_context
+
self.logger = engine.logger
self.set_assert_list(None, None)
self.sql_count = 0
@@ -154,8 +217,6 @@ class EngineAssert(proxy.BaseProxyEngine):
return self._engine
def set_engine(self, e):
self._engine = e
-# def __getattr__(self, key):
- # return getattr(self.engine, key)
def set_assert_list(self, unittest, list):
self.unittest = unittest
self.assert_list = list
@@ -164,46 +225,55 @@ class EngineAssert(proxy.BaseProxyEngine):
def _set_echo(self, echo):
self.engine.echo = echo
echo = property(lambda s: s.engine.echo, _set_echo)
- def post_exec(self, proxy, compiled, parameters, **kwargs):
- self.engine.logger = self.logger
- statement = str(compiled)
- statement = re.sub(r'\n', '', statement)
-
- if self.assert_list is not None:
- item = self.assert_list[-1]
- if not isinstance(item, dict):
- item = self.assert_list.pop()
- else:
- # asserting a dictionary of statements->parameters
- # this is to specify query assertions where the queries can be in
- # multiple orderings
- if not item.has_key('_converted'):
- for key in item.keys():
- ckey = self.convert_statement(key)
- item[ckey] = item[key]
- if ckey != key:
- del item[key]
- item['_converted'] = True
- try:
- entry = item.pop(statement)
- if len(item) == 1:
- self.assert_list.pop()
- item = (statement, entry)
- except KeyError:
- self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
-
- (query, params) = item
- if callable(params):
- params = params()
-
- query = self.convert_statement(query)
-
- self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- self.sql_count += 1
- return self.realexec(proxy, compiled, parameters, **kwargs)
+
+ def execution_context(self):
+ def post_exec(engine, proxy, compiled, parameters, **kwargs):
+ ctx = e
+ self.engine.logger = self.logger
+ statement = str(compiled)
+ statement = re.sub(r'\n', '', statement)
+ if self.assert_list is not None:
+ item = self.assert_list[-1]
+ if not isinstance(item, dict):
+ item = self.assert_list.pop()
+ else:
+ # asserting a dictionary of statements->parameters
+ # this is to specify query assertions where the queries can be in
+ # multiple orderings
+ if not item.has_key('_converted'):
+ for key in item.keys():
+ ckey = self.convert_statement(key)
+ item[ckey] = item[key]
+ if ckey != key:
+ del item[key]
+ item['_converted'] = True
+ try:
+ entry = item.pop(statement)
+ if len(item) == 1:
+ self.assert_list.pop()
+ item = (statement, entry)
+ except KeyError:
+ self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
+
+ (query, params) = item
+ if callable(params):
+ params = params(ctx)
+ if params is not None and isinstance(params, list) and len(params) == 1:
+ params = params[0]
+
+ query = self.convert_statement(query)
+ self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
+ self.sql_count += 1
+ return realexec(ctx, proxy, compiled, parameters, **kwargs)
+
+ e = self.real_execution_context()
+ realexec = e.post_exec
+ realexec.im_self.post_exec = post_exec
+ return e
+
def convert_statement(self, query):
- paramstyle = self.engine.paramstyle
+ paramstyle = self.engine.dialect.paramstyle
if paramstyle == 'named':
pass
elif paramstyle =='pyformat':
@@ -275,10 +345,11 @@ parse_argv()
def runTests(suite):
- runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
+ runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
runner.run(suite)
def main():
- unittest.main()
+ suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__'))
+ runTests(suite)
diff --git a/test/testtypes.py b/test/testtypes.py
index ae58f0f54..db1fbca20 100644
--- a/test/testtypes.py
+++ b/test/testtypes.py
@@ -2,7 +2,8 @@ from sqlalchemy import *
import string,datetime, re, sys
from testbase import PersistTest, AssertMixin
import testbase
-
+import sqlalchemy.engine.url as url
+
db = testbase.db
class MyType(types.TypeEngine):
@@ -34,15 +35,15 @@ class MyUnicodeType(types.Unicode):
class AdaptTest(PersistTest):
def testadapt(self):
- e1 = create_engine('postgres://')
- e2 = create_engine('sqlite://')
- e3 = create_engine('mysql://')
+ e1 = url.URL('postgres').get_module().dialect()
+ e2 = url.URL('mysql').get_module().dialect()
+ e3 = url.URL('sqlite').get_module().dialect()
type = String(40)
- t1 = type.engine_impl(e1)
- t2 = type.engine_impl(e2)
- t3 = type.engine_impl(e3)
+ t1 = type.dialect_impl(e1)
+ t2 = type.dialect_impl(e2)
+ t3 = type.dialect_impl(e3)
assert t1 != t2
assert t2 != t3
assert t3 != t1
@@ -116,7 +117,7 @@ class ColumnsTest(AssertMixin):
)
for aCol in testTable.c:
- self.assertEquals(expectedResults[aCol.name], db.schemagenerator().get_column_specification(aCol))
+ self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None).get_column_specification(aCol))
class UnicodeTest(AssertMixin):
"""tests the Unicode type. also tests the TypeDecorator with instances in the types package."""
@@ -130,13 +131,6 @@ class UnicodeTest(AssertMixin):
unicode_table.create()
def tearDownAll(self):
unicode_table.drop()
- def testwhereclause(self):
- l = unicode_table.select(unicode_table.c.unicode_data==u'this is also unicode').execute()
- def testmapperwhere(self):
- class Foo(object):pass
- m = mapper(Foo, unicode_table)
- l = m.get_by(unicode_data=unicode('this is also unicode'))
- l = m.get_by(plain_data=unicode('this is also unicode'))
def testbasic(self):
rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
unicodedata = rawdata.decode('utf-8')
@@ -147,15 +141,15 @@ class UnicodeTest(AssertMixin):
self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata)
if isinstance(x['plain_data'], unicode):
# SQLLite returns even non-unicode data as unicode
- self.assert_(sys.modules[db.engine.__module__].descriptor()['name'] == 'sqlite')
+ self.assert_(db.name == 'sqlite')
self.echo("its sqlite !")
else:
self.assert_(not isinstance(x['plain_data'], unicode) and x['plain_data'] == rawdata)
def testengineparam(self):
"""tests engine-wide unicode conversion"""
- prev_unicode = db.engine.convert_unicode
+ prev_unicode = db.engine.dialect.convert_unicode
try:
- db.engine.convert_unicode = True
+ db.engine.dialect.convert_unicode = True
rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
unicodedata = rawdata.decode('utf-8')
unicode_table.insert().execute(unicode_data=unicodedata, plain_data=rawdata)
@@ -165,8 +159,7 @@ class UnicodeTest(AssertMixin):
self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata)
self.assert_(isinstance(x['plain_data'], unicode) and x['plain_data'] == unicodedata)
finally:
- db.engine.convert_unicode = prev_unicode
-
+ db.engine.dialect.convert_unicode = prev_unicode
class Foo(object):
def __init__(self, moredata):
@@ -175,7 +168,7 @@ class Foo(object):
self.moredata = moredata
def __eq__(self, other):
return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata
-
+
class BinaryTest(AssertMixin):
def setUpAll(self):
global binary_table
@@ -184,21 +177,20 @@ class BinaryTest(AssertMixin):
Column('data', Binary),
Column('data_slice', Binary(100)),
Column('misc', String(30)),
- Column('pickled', PickleType))
+ Column('pickled', PickleType)
+ )
binary_table.create()
def tearDownAll(self):
binary_table.drop()
def testbinary(self):
testobj1 = Foo('im foo 1')
testobj2 = Foo('im foo 2')
-
+
stream1 =self.get_module_stream('sqlalchemy.sql')
- stream2 =self.get_module_stream('sqlalchemy.engine')
+ stream2 =self.get_module_stream('sqlalchemy.schema')
binary_table.insert().execute(primary_id=1, misc='sql.pyc', data=stream1, data_slice=stream1[0:100], pickled=testobj1)
- binary_table.insert().execute(primary_id=2, misc='engine.pyc', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
+ binary_table.insert().execute(primary_id=2, misc='schema.pyc', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
l = binary_table.select().execute().fetchall()
- print type(l[0]['data'])
- return
print len(stream1), len(l[0]['data']), len(l[0]['data_slice'])
self.assert_(list(stream1) == list(l[0]['data']))
self.assert_(list(stream1[0:100]) == list(l[0]['data_slice']))
@@ -231,7 +223,7 @@ class DateTest(AssertMixin):
collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime),
Column('user_date', Date), Column('user_time', Time)]
- if db.engine.__module__.endswith('mysql') or db.engine.__module__.endswith('mssql'):
+ if db.engine.name == 'mysql' or db.engine.name == 'mssql':
# strip microseconds -- not supported by this engine (should be an easier way to detect this)
for d in insert_data:
if d[2] is not None:
diff --git a/test/transaction.py b/test/transaction.py
new file mode 100644
index 000000000..d76d7e0f4
--- /dev/null
+++ b/test/transaction.py
@@ -0,0 +1,63 @@
+
+import testbase
+import unittest, sys, datetime
+import tables
+db = testbase.db
+from sqlalchemy import *
+
+class TransactionTest(testbase.PersistTest):
+ def setUpAll(self):
+ global users, metadata
+ metadata = MetaData()
+ users = Table('query_users', metadata,
+ Column('user_id', INT, primary_key = True),
+ Column('user_name', VARCHAR(20)),
+ )
+ users.create(testbase.db)
+
+ def tearDown(self):
+ testbase.db.connect().execute(users.delete())
+ def tearDownAll(self):
+ users.drop(testbase.db)
+
+ @testbase.unsupported('mysql')
+ def testrollback(self):
+ """test a basic rollback"""
+ connection = testbase.db.connect()
+ transaction = connection.begin()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ transaction.rollback()
+
+ result = connection.execute("select * from query_users")
+ assert len(result.fetchall()) == 0
+ connection.close()
+
+class AutoRollbackTest(testbase.PersistTest):
+ def setUpAll(self):
+ global metadata
+ metadata = MetaData()
+
+ def tearDownAll(self):
+ metadata.drop_all(testbase.db)
+
+ def testrollback_deadlock(self):
+ """test that returning connections to the pool clears any object locks."""
+ conn1 = testbase.db.connect()
+ conn2 = testbase.db.connect()
+ users = Table('deadlock_users', metadata,
+ Column('user_id', INT, primary_key = True),
+ Column('user_name', VARCHAR(20)),
+ )
+ users.create(conn1)
+ conn1.execute("select * from deadlock_users")
+ conn1.close()
+ # without auto-rollback in the connection pool's return() logic, this deadlocks in Postgres,
+ # because conn1 is returned to the pool but still has a lock on "deadlock_users"
+ # comment out the rollback in pool/ConnectionFairy._close() to see !
+ users.drop(conn2)
+ conn2.close()
+
+if __name__ == "__main__":
+ testbase.main()