diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-06-10 21:18:24 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-06-10 21:18:24 +0000 |
| commit | 45cec095b4904ba71425d2fe18c143982dd08f43 (patch) | |
| tree | af5e540fdcbf1cb2a3337157d69d4b40be010fa8 /test/orm | |
| parent | 698a3c1ac665e7cd2ef8d5ad3ebf51b7fe6661f4 (diff) | |
| download | sqlalchemy-45cec095b4904ba71425d2fe18c143982dd08f43.tar.gz | |
- unit tests have been migrated from unittest to nose.
See README.unittests for information on how to run
the tests. [ticket:970]
Diffstat (limited to 'test/orm')
| -rw-r--r-- | test/orm/_base.py | 108 | ||||
| -rw-r--r-- | test/orm/_fixtures.py | 43 | ||||
| -rw-r--r-- | test/orm/alltests.py | 60 | ||||
| -rw-r--r-- | test/orm/inheritance/alltests.py | 31 | ||||
| -rw-r--r-- | test/orm/inheritance/test_abc_inheritance.py (renamed from test/orm/inheritance/abc_inheritance.py) | 21 | ||||
| -rw-r--r-- | test/orm/inheritance/test_abc_polymorphic.py (renamed from test/orm/inheritance/abc_polymorphic.py) | 12 | ||||
| -rw-r--r-- | test/orm/inheritance/test_basic.py (renamed from test/orm/inheritance/basic.py) | 80 | ||||
| -rw-r--r-- | test/orm/inheritance/test_concrete.py (renamed from test/orm/inheritance/concrete.py) | 46 | ||||
| -rw-r--r-- | test/orm/inheritance/test_magazine.py (renamed from test/orm/inheritance/magazine.py) | 13 | ||||
| -rw-r--r-- | test/orm/inheritance/test_manytomany.py (renamed from test/orm/inheritance/manytomany.py) | 19 | ||||
| -rw-r--r-- | test/orm/inheritance/test_poly_linked_list.py (renamed from test/orm/inheritance/poly_linked_list.py) | 20 | ||||
| -rw-r--r-- | test/orm/inheritance/test_polymorph.py (renamed from test/orm/inheritance/polymorph.py) | 40 | ||||
| -rw-r--r-- | test/orm/inheritance/test_polymorph2.py (renamed from test/orm/inheritance/polymorph2.py) | 80 | ||||
| -rw-r--r-- | test/orm/inheritance/test_productspec.py (renamed from test/orm/inheritance/productspec.py) | 10 | ||||
| -rw-r--r-- | test/orm/inheritance/test_query.py (renamed from test/orm/inheritance/query.py) | 282 | ||||
| -rw-r--r-- | test/orm/inheritance/test_selects.py (renamed from test/orm/inheritance/selects.py) | 12 | ||||
| -rw-r--r-- | test/orm/inheritance/test_single.py (renamed from test/orm/inheritance/single.py) | 82 | ||||
| -rw-r--r-- | test/orm/sharding/alltests.py | 18 | ||||
| -rw-r--r-- | test/orm/sharding/test_shard.py (renamed from test/orm/sharding/shard.py) | 23 | ||||
| -rw-r--r-- | test/orm/test_association.py (renamed from test/orm/association.py) | 30 | ||||
| -rw-r--r-- | test/orm/test_assorted_eager.py (renamed from test/orm/assorted_eager.py) | 110 | ||||
| -rw-r--r-- | test/orm/test_attributes.py (renamed from test/orm/attributes.py) | 15 | ||||
| -rw-r--r-- | test/orm/test_bind.py (renamed from test/orm/bind.py) | 26 | ||||
| -rw-r--r-- | test/orm/test_cascade.py (renamed from test/orm/cascade.py) | 97 | ||||
| -rw-r--r-- | test/orm/test_collection.py (renamed from test/orm/collection.py) | 49 | ||||
| -rw-r--r-- | test/orm/test_compile.py (renamed from test/orm/compile.py) | 9 | ||||
| -rw-r--r-- | test/orm/test_cycles.py (renamed from test/orm/cycles.py) | 89 | ||||
| -rw-r--r-- | test/orm/test_defaults.py (renamed from test/orm/defaults.py) | 28 | ||||
| -rw-r--r-- | test/orm/test_deprecations.py (renamed from test/orm/deprecations.py) | 25 | ||||
| -rw-r--r-- | test/orm/test_dynamic.py (renamed from test/orm/dynamic.py) | 32 | ||||
| -rw-r--r-- | test/orm/test_eager_relations.py (renamed from test/orm/eager_relations.py) | 65 | ||||
| -rw-r--r-- | test/orm/test_evaluator.py (renamed from test/orm/evaluator.py) | 25 | ||||
| -rw-r--r-- | test/orm/test_expire.py (renamed from test/orm/expire.py) | 52 | ||||
| -rw-r--r-- | test/orm/test_extendedattr.py (renamed from test/orm/extendedattr.py) | 51 | ||||
| -rw-r--r-- | test/orm/test_generative.py (renamed from test/orm/generative.py) | 47 | ||||
| -rw-r--r-- | test/orm/test_instrumentation.py (renamed from test/orm/instrumentation.py) | 45 | ||||
| -rw-r--r-- | test/orm/test_lazy_relations.py (renamed from test/orm/lazy_relations.py) | 23 | ||||
| -rw-r--r-- | test/orm/test_lazytest1.py (renamed from test/orm/lazytest1.py) | 20 | ||||
| -rw-r--r-- | test/orm/test_manytomany.py (renamed from test/orm/manytomany.py) | 34 | ||||
| -rw-r--r-- | test/orm/test_mapper.py (renamed from test/orm/mapper.py) | 94 | ||||
| -rw-r--r-- | test/orm/test_merge.py (renamed from test/orm/merge.py) | 17 | ||||
| -rw-r--r-- | test/orm/test_naturalpks.py (renamed from test/orm/naturalpks.py) | 85 | ||||
| -rw-r--r-- | test/orm/test_onetoone.py (renamed from test/orm/onetoone.py) | 20 | ||||
| -rw-r--r-- | test/orm/test_pickled.py (renamed from test/orm/pickled.py) | 54 | ||||
| -rw-r--r-- | test/orm/test_query.py (renamed from test/orm/query.py) | 426 | ||||
| -rw-r--r-- | test/orm/test_relationships.py (renamed from test/orm/relationships.py) | 181 | ||||
| -rw-r--r-- | test/orm/test_scoping.py (renamed from test/orm/scoping.py) | 65 | ||||
| -rw-r--r-- | test/orm/test_selectable.py (renamed from test/orm/selectable.py) | 25 | ||||
| -rw-r--r-- | test/orm/test_session.py (renamed from test/orm/session.py) | 94 | ||||
| -rw-r--r-- | test/orm/test_transaction.py (renamed from test/orm/transaction.py) | 61 | ||||
| -rw-r--r-- | test/orm/test_unitofwork.py (renamed from test/orm/unitofwork.py) | 178 | ||||
| -rw-r--r-- | test/orm/test_utils.py (renamed from test/orm/utils.py) | 14 |
52 files changed, 1672 insertions, 1514 deletions
diff --git a/test/orm/_base.py b/test/orm/_base.py index 9e599a6f1..8d695e912 100644 --- a/test/orm/_base.py +++ b/test/orm/_base.py @@ -2,9 +2,10 @@ import gc import inspect import sys import types -from testlib import config, sa, testing -from testlib.testing import resolve_artifact_names, adict -from testlib.compat import _function_named +import sqlalchemy as sa +from sqlalchemy.test import config, testing +from sqlalchemy.test.testing import resolve_artifact_names, adict +from sqlalchemy.util import function_named _repr_stack = set() @@ -95,7 +96,8 @@ class ComparableEntity(BasicEntity): class ORMTest(testing.TestBase, testing.AssertsExecutionResults): __requires__ = ('subqueries',) - def tearDownAll(self): + @classmethod + def teardown_class(cls): sa.orm.session.Session.close_all() sa.orm.clear_mappers() # TODO: ensure mapper registry is empty @@ -124,18 +126,18 @@ class MappedTest(ORMTest): classes = None other_artifacts = None - def setUpAll(self): - if self.run_setup_classes == 'each': - assert self.run_setup_mappers != 'once' + @classmethod + def setup_class(cls): + if cls.run_setup_classes == 'each': + assert cls.run_setup_mappers != 'once' - assert self.run_deletes in (None, 'each') - if self.run_inserts == 'once': - assert self.run_deletes is None + assert cls.run_deletes in (None, 'each') + if cls.run_inserts == 'once': + assert cls.run_deletes is None - assert not hasattr(self, 'keep_mappers') - assert not hasattr(self, 'keep_data') + assert not hasattr(cls, 'keep_mappers') + assert not hasattr(cls, 'keep_data') - cls = self.__class__ if cls.tables is None: cls.tables = adict() if cls.classes is None: @@ -143,35 +145,32 @@ class MappedTest(ORMTest): if cls.other_artifacts is None: cls.other_artifacts = adict() - if self.metadata is None: - setattr(type(self), 'metadata', sa.MetaData()) + if cls.metadata is None: + setattr(cls, 'metadata', sa.MetaData()) - if self.metadata.bind is None: - self.metadata.bind = getattr(self, 'engine', config.db) + if cls.metadata.bind is None: + cls.metadata.bind = getattr(cls, 'engine', config.db) - if self.run_define_tables: - self.define_tables(self.metadata) - self.metadata.create_all() - self.tables.update(self.metadata.tables) + if cls.run_define_tables == 'once': + cls.define_tables(cls.metadata) + cls.metadata.create_all() + cls.tables.update(cls.metadata.tables) - if self.run_setup_classes: + if cls.run_setup_classes == 'once': baseline = subclasses(BasicEntity) - self.setup_classes() - self._register_new_class_artifacts(baseline) + cls.setup_classes() + cls._register_new_class_artifacts(baseline) - if self.run_setup_mappers: + if cls.run_setup_mappers == 'once': baseline = subclasses(BasicEntity) - self.setup_mappers() - self._register_new_class_artifacts(baseline) + cls.setup_mappers() + cls._register_new_class_artifacts(baseline) - if self.run_inserts: - self._load_fixtures() - self.insert_data() - - def setUp(self): - if self._sa_first_test: - return + if cls.run_inserts == 'once': + cls._load_fixtures() + cls.insert_data() + def setup(self): if self.run_define_tables == 'each': self.tables.clear() self.metadata.drop_all() @@ -195,7 +194,7 @@ class MappedTest(ORMTest): self._load_fixtures() self.insert_data() - def tearDown(self): + def teardown(self): sa.orm.session.Session.close_all() # some tests create mappers in the test bodies @@ -213,26 +212,32 @@ class MappedTest(ORMTest): print >> sys.stderr, "Error emptying table %s: %r" % ( table, ex) - def tearDownAll(self): - for cls in self.classes.values(): - self.unregister_class(cls) - ORMTest.tearDownAll(self) - self.metadata.drop_all() - self.metadata.bind = None + @classmethod + def teardown_class(cls): + for cl in cls.classes.values(): + cls.unregister_class(cl) + ORMTest.teardown_class() + cls.metadata.drop_all() + cls.metadata.bind = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): raise NotImplementedError() - def setup_classes(self): + @classmethod + def setup_classes(cls): pass - def setup_mappers(self): + @classmethod + def setup_mappers(cls): pass - def fixtures(self): + @classmethod + def fixtures(cls): return {} - def insert_data(self): + @classmethod + def insert_data(cls): pass def sql_count_(self, count, fn): @@ -260,15 +265,16 @@ class MappedTest(ORMTest): if name[0].isupper: delattr(cls, name) del cls.classes[name] - - def _load_fixtures(self): + + @classmethod + def _load_fixtures(cls): headers, rows = {}, {} - for table, data in self.fixtures().iteritems(): + for table, data in cls.fixtures().iteritems(): if isinstance(table, basestring): - table = self.tables[table] + table = cls.tables[table] headers[table] = data[0] rows[table] = data[1:] - for table in self.metadata.sorted_tables: + for table in cls.metadata.sorted_tables: if table not in headers: continue table.bind.execute( diff --git a/test/orm/_fixtures.py b/test/orm/_fixtures.py index f036b92b2..14709ec43 100644 --- a/test/orm/_fixtures.py +++ b/test/orm/_fixtures.py @@ -1,7 +1,9 @@ -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import attributes -from testlib.testing import fixture -from orm import _base +from sqlalchemy import MetaData, Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import attributes +from sqlalchemy.test.testing import fixture +from test.orm import _base __all__ = () @@ -227,34 +229,21 @@ class FixtureTest(_base.MappedTest): Address=Address, Dingaling=Dingaling) - def setUpAll(self): - assert not hasattr(self, 'refresh_data') - assert not hasattr(self, 'only_tables') - #refresh_data = False - #only_tables = False - - #if type(self) is not FixtureTest: - # setattr(type(self), 'classes', _base.adict(self.classes)) - - #if self.run_setup_classes: - # for cls in self.classes.values(): - # self.register_class(cls) - super(FixtureTest, self).setUpAll() - - #if not self.only_tables and self.keep_data: - # _registry.load() - - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): pass - def setup_classes(self): - for cls in self.fixture_classes.values(): - self.register_class(cls) + @classmethod + def setup_classes(cls): + for cl in cls.fixture_classes.values(): + cls.register_class(cl) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): pass - def insert_data(self): + @classmethod + def insert_data(cls): _load_fixtures() diff --git a/test/orm/alltests.py b/test/orm/alltests.py deleted file mode 100644 index 9458ca523..000000000 --- a/test/orm/alltests.py +++ /dev/null @@ -1,60 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -import inheritance.alltests as inheritance -import sharding.alltests as sharding - -def suite(): - modules_to_test = ( - 'orm.attributes', - 'orm.bind', - 'orm.extendedattr', - 'orm.instrumentation', - 'orm.query', - 'orm.lazy_relations', - 'orm.eager_relations', - 'orm.mapper', - 'orm.expire', - 'orm.selectable', - 'orm.collection', - 'orm.generative', - 'orm.lazytest1', - 'orm.assorted_eager', - - 'orm.naturalpks', - 'orm.defaults', - 'orm.unitofwork', - 'orm.session', - 'orm.transaction', - 'orm.scoping', - 'orm.cascade', - 'orm.relationships', - 'orm.association', - 'orm.merge', - 'orm.pickled', - 'orm.utils', - - 'orm.cycles', - - 'orm.compile', - 'orm.manytomany', - 'orm.onetoone', - 'orm.dynamic', - - 'orm.evaluator', - - 'orm.deprecations', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - alltests.addTest(inheritance.suite()) - alltests.addTest(sharding.suite()) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py deleted file mode 100644 index 41f0521dd..000000000 --- a/test/orm/inheritance/alltests.py +++ /dev/null @@ -1,31 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - 'orm.inheritance.basic', - 'orm.inheritance.query', - 'orm.inheritance.manytomany', - 'orm.inheritance.single', - 'orm.inheritance.concrete', - 'orm.inheritance.polymorph', - 'orm.inheritance.polymorph2', - 'orm.inheritance.poly_linked_list', - 'orm.inheritance.abc_polymorphic', - 'orm.inheritance.abc_inheritance', - 'orm.inheritance.productspec', - 'orm.inheritance.magazine', - 'orm.inheritance.selects', - - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py index ee324e381..4e55cf70e 100644 --- a/test/orm/inheritance/abc_inheritance.py +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -1,10 +1,9 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE -from testlib import testing -from orm import _base +from sqlalchemy.test import testing +from test.orm import _base def produce_test(parent, child, direction): @@ -12,9 +11,10 @@ def produce_test(parent, child, direction): relationship between two of the classes, using either one-to-many or many-to-one.""" class ABCTest(_base.MappedTest): - def define_tables(self, meta): + @classmethod + def define_tables(cls, metadata): global ta, tb, tc - ta = ["a", meta] + ta = ["a", metadata] ta.append(Column('id', Integer, primary_key=True)), ta.append(Column('a_data', String(30))) if "a"== parent and direction == MANYTOONE: @@ -23,7 +23,7 @@ def produce_test(parent, child, direction): ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) ta = Table(*ta) - tb = ["b", meta] + tb = ["b", metadata] tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, )) tb.append(Column('b_data', String(30))) @@ -34,7 +34,7 @@ def produce_test(parent, child, direction): tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) tb = Table(*tb) - tc = ["c", meta] + tc = ["c", metadata] tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, )) tc.append(Column('c_data', String(30))) @@ -45,14 +45,14 @@ def produce_test(parent, child, direction): tc.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) tc = Table(*tc) - def tearDown(self): + def teardown(self): if direction == MANYTOONE: parent_table = {"a":ta, "b":tb, "c": tc}[parent] parent_table.update(values={parent_table.c.child_id:None}).execute() elif direction == ONETOMANY: child_table = {"a":ta, "b":tb, "c": tc}[child] child_table.update(values={child_table.c.parent_id:None}).execute() - super(ABCTest, self).tearDown() + super(ABCTest, self).teardown() def test_roundtrip(self): parent_table = {"a":ta, "b":tb, "c": tc}[parent] @@ -167,5 +167,4 @@ for parent in ["a", "b", "c"]: exec("%s = testclass" % testclass.__name__) del testclass -if __name__ == "__main__": - testenv.main() +del produce_test
\ No newline at end of file diff --git a/test/orm/inheritance/abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index 6fabbb24c..8cad8ed78 100644 --- a/test/orm/inheritance/abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -1,13 +1,13 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import util from sqlalchemy.orm import * -from testlib import _function_named -from orm import _base, _fixtures +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures class ABCTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global a, b, c a = Table('a', metadata, Column('id', Integer, primary_key=True), @@ -78,7 +78,7 @@ class ABCTest(_base.MappedTest): C(cdata='c2', bdata='c2', adata='c2'), ] == sess.query(C).all() - test_roundtrip = _function_named( + test_roundtrip = function_named( test_roundtrip, 'test_%s' % fetchtype) return test_roundtrip @@ -86,5 +86,3 @@ class ABCTest(_base.MappedTest): test_none = make_test('none') -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/test_basic.py index 150874477..fc4aae17d 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1,17 +1,17 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy import exc as sa_exc, util from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc -#from testlib import * -#from testlib import fixtures -from testlib import _function_named, testing, engines -from orm import _base, _fixtures +from sqlalchemy.test import testing, engines +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures class O2MTest(_base.MappedTest): """deals with inheritance and one-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq', optional=True), @@ -69,7 +69,8 @@ class O2MTest(_base.MappedTest): self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') class FalseDiscriminatorTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False)) @@ -87,7 +88,8 @@ class FalseDiscriminatorTest(_base.MappedTest): assert isinstance(sess.query(Foo).one(), Bar) class PolymorphicSynonymTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -118,8 +120,8 @@ class PolymorphicSynonymTest(_base.MappedTest): sess.add(at2) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(T2).filter(T2.info=='at2').one(), at2) - self.assertEquals(at2.info, "THE INFO IS:at2") + eq_(sess.query(T2).filter(T2.info=='at2').one(), at2) + eq_(at2.info, "THE INFO IS:at2") class CascadeTest(_base.MappedTest): @@ -127,7 +129,8 @@ class CascadeTest(_base.MappedTest): cascading along the path of the instance's mapper, not the base mapper.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2, t3, t4 t1= Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -191,7 +194,8 @@ class CascadeTest(_base.MappedTest): sess.flush() class GetTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq', optional=True), @@ -209,7 +213,7 @@ class GetTest(_base.MappedTest): Column('bar_id', Integer, ForeignKey('bar.id')), Column('data', String(20))) - def create_test(polymorphic, name): + def _create_test(polymorphic, name): def test_get(self): class Foo(object): pass @@ -271,16 +275,17 @@ class GetTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 3) - test_get = _function_named(test_get, name) + test_get = function_named(test_get, name) return test_get - test_get_polymorphic = create_test(True, 'test_get_polymorphic') - test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic') + test_get_polymorphic = _create_test(True, 'test_get_polymorphic') + test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic') class EagerLazyTest(_base.MappedTest): """tests eager load/lazy load of child items off inheritance mappers, tests that LazyLoader constructs the right query condition.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, bar_foo foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq', optional=True), @@ -325,7 +330,8 @@ class EagerLazyTest(_base.MappedTest): class FlushTest(_base.MappedTest): """test dependency sorting among inheriting mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global users, roles, user_roles, admins users = Table('users', metadata, Column('id', Integer, primary_key=True), @@ -413,7 +419,8 @@ class FlushTest(_base.MappedTest): assert user_roles.count().scalar() == 1 class VersioningTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, subtable, stuff base = Table('base', metadata, Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), @@ -522,7 +529,8 @@ class DistinctPKTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global person_table, employee_table, Person, Employee person_table = Table("persons", metadata, @@ -542,7 +550,8 @@ class DistinctPKTest(_base.MappedTest): class Employee(Person): pass - def insert_data(self): + @classmethod + def insert_data(cls): person_insert = person_table.insert() person_insert.execute(id=1, name='alice') person_insert.execute(id=2, name='bob') @@ -593,7 +602,8 @@ class DistinctPKTest(_base.MappedTest): class SyncCompileTest(_base.MappedTest): """test that syncrules compile properly on custom inherit conds""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global _a_table, _b_table, _c_table _a_table = Table('a', metadata, @@ -660,7 +670,8 @@ class SyncCompileTest(_base.MappedTest): class OverrideColKeyTest(_base.MappedTest): """test overriding of column attributes.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, subtable base = Table('base', metadata, @@ -686,7 +697,7 @@ class OverrideColKeyTest(_base.MappedTest): # Sub gets a "base_id" property using the "base_id" # column of both tables. - self.assertEquals( + eq_( class_mapper(Sub).get_property('base_id').columns, [base.c.base_id, subtable.c.base_id] ) @@ -710,7 +721,7 @@ class OverrideColKeyTest(_base.MappedTest): 'id':[base.c.base_id, subtable.c.base_id] }) - self.assertEquals( + eq_( class_mapper(Sub).get_property('id').columns, [base.c.base_id, subtable.c.base_id] ) @@ -733,12 +744,12 @@ class OverrideColKeyTest(_base.MappedTest): }) mapper(Sub, subtable, inherits=Base) - self.assertEquals( + eq_( class_mapper(Sub).get_property('id').columns, [base.c.base_id] ) - self.assertEquals( + eq_( class_mapper(Sub).get_property('base_id').columns, [subtable.c.base_id] ) @@ -782,7 +793,7 @@ class OverrideColKeyTest(_base.MappedTest): # it has its own "id" property. Sub's "id" property # gets joined normally with the extra column. - self.assertEquals( + eq_( class_mapper(Sub).get_property('id').columns, [base.c.base_id, subtable.c.base_id] ) @@ -892,7 +903,8 @@ class OptimizedLoadTest(_base.MappedTest): a column in the join condition is not available. """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, sub base = Table('base', metadata, Column('id', Integer, primary_key=True), @@ -931,7 +943,8 @@ class OptimizedLoadTest(_base.MappedTest): assert s1.sub == 's1sub' class PKDiscriminatorTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): parents = Table('parents', metadata, Column('id', Integer, primary_key=True), Column('name', String(60))) @@ -974,7 +987,8 @@ class PKDiscriminatorTest(_base.MappedTest): class DeleteOrphanTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global single, parent single = Table('single', metadata, Column('id', Integer, primary_key=True), @@ -1007,9 +1021,7 @@ class DeleteOrphanTest(_base.MappedTest): sess = create_session() s1 = SubClass(data='s1') sess.add(s1) - self.assertRaisesMessage(orm_exc.FlushError, + assert_raises_message(orm_exc.FlushError, "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/test_concrete.py index 6cdaed7e6..4a884cb86 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -1,12 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc -from testlib import * -from testlib import sa, testing -from orm import _base +from sqlalchemy.test import * +import sqlalchemy as sa +from sqlalchemy.test import testing +from test.orm import _base from sqlalchemy.orm import attributes -from testlib.testing import eq_ +from sqlalchemy.test.testing import eq_ class Employee(object): def __init__(self, name): @@ -42,7 +43,8 @@ class Company(object): class ConcreteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global managers_table, engineers_table, hackers_table, companies, employees_table companies = Table('companies', metadata, @@ -103,7 +105,7 @@ class ConcreteTest(_base.MappedTest): manager = session.query(Manager).one() session.expire(manager, ['manager_data']) - self.assertEquals(manager.manager_data, "knows how to manage things") + eq_(manager.manager_data, "knows how to manage things") def test_multi_level_no_base(self): pjoin = polymorphic_union({ @@ -144,8 +146,8 @@ class ConcreteTest(_base.MappedTest): assert 'name' not in attributes.instance_state(hacker).expired_attributes assert 'nickname' not in attributes.instance_state(hacker).expired_attributes def go(): - self.assertEquals(jerry.name, "Jerry") - self.assertEquals(hacker.nickname, "Badass") + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") self.assert_sql_count(testing.db, go, 0) session.expunge_all() @@ -194,8 +196,8 @@ class ConcreteTest(_base.MappedTest): session.flush() def go(): - self.assertEquals(jerry.name, "Jerry") - self.assertEquals(hacker.nickname, "Badass") + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") self.assert_sql_count(testing.db, go, 0) session.expunge_all() @@ -315,7 +317,8 @@ class ConcreteTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) class PropertyInheritanceTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a_table', metadata, Column('id', Integer, primary_key=True), Column('some_c_id', Integer, ForeignKey('c_table.id')), @@ -332,7 +335,8 @@ class PropertyInheritanceTest(_base.MappedTest): ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.ComparableEntity): pass @@ -352,7 +356,7 @@ class PropertyInheritanceTest(_base.MappedTest): b = B() c = C() - self.assertRaises(AttributeError, setattr, b, 'some_c', c) + assert_raises(AttributeError, setattr, b, 'some_c', c) clear_mappers() mapper(A, a_table, properties={ @@ -361,7 +365,7 @@ class PropertyInheritanceTest(_base.MappedTest): mapper(B, b_table,inherits=A, concrete=True) mapper(C, c_table) b = B() - self.assertRaises(AttributeError, setattr, b, 'a_id', 3) + assert_raises(AttributeError, setattr, b, 'a_id', 3) clear_mappers() mapper(A, a_table, properties={ @@ -392,8 +396,8 @@ class PropertyInheritanceTest(_base.MappedTest): b1 = B(some_c=c1, bname='b1') b2 = B(some_c=c1, bname='b2') - self.assertRaises(AttributeError, setattr, b1, 'aname', 'foo') - self.assertRaises(AttributeError, getattr, A, 'bname') + assert_raises(AttributeError, setattr, b1, 'aname', 'foo') + assert_raises(AttributeError, getattr, A, 'bname') assert c2.many_a == [a2] assert c1.many_a == [a1] @@ -463,7 +467,8 @@ class PropertyInheritanceTest(_base.MappedTest): class ColKeysTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global offices_table, refugees_table refugees_table = Table('refugee', metadata, Column('refugee_fid', Integer, primary_key=True), @@ -473,7 +478,8 @@ class ColKeysTest(_base.MappedTest): Column('office_fid', Integer, primary_key=True), Column('office_name', Unicode(30), key='name')) - def insert_data(self): + @classmethod + def insert_data(cls): refugees_table.insert().execute( dict(refugee_fid=1, name=u"refugee1"), dict(refugee_fid=2, name=u"refugee2") @@ -511,5 +517,3 @@ class ColKeysTest(_base.MappedTest): eq_(sess.query(Office).get(1).name, "office1") eq_(sess.query(Office).get(2).name, "office2") -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/magazine.py b/test/orm/inheritance/test_magazine.py index 34374c887..067301251 100644 --- a/test/orm/inheritance/magazine.py +++ b/test/orm/inheritance/test_magazine.py @@ -1,9 +1,9 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing, _function_named -from orm import _base +from sqlalchemy.test import testing +from sqlalchemy.util import function_named +from test.orm import _base class BaseObject(object): def __init__(self, *args, **kwargs): @@ -70,7 +70,8 @@ class ClassifiedPage(MagazinePage): class MagazineTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global publication_table, issue_table, location_table, location_name_table, magazine_table, \ page_table, magazine_page_table, classified_page_table, page_size_table @@ -208,7 +209,7 @@ def generate_round_trip_test(use_unions=False, use_joins=False): print [page, page2, page3] assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages) - test_roundtrip = _function_named( + test_roundtrip = function_named( test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions")) setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip) @@ -216,5 +217,3 @@ for (use_union, use_join) in [(True, False), (False, True), (False, False)]: generate_round_trip_test(use_union, use_join) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/test_manytomany.py index 5dbf69ba5..f7e676bbb 100644 --- a/test/orm/inheritance/manytomany.py +++ b/test/orm/inheritance/test_manytomany.py @@ -1,14 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm import _base +from sqlalchemy.test import testing +from test.orm import _base class InheritTest(_base.MappedTest): """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global principals global users global groups @@ -67,7 +68,8 @@ class InheritTest(_base.MappedTest): class InheritTest2(_base.MappedTest): """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, foo_bar foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_id_seq', optional=True), @@ -140,7 +142,8 @@ class InheritTest2(_base.MappedTest): class InheritTest3(_base.MappedTest): """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, blub, bar_foo, blub_bar, blub_foo # the 'data' columns are to appease SQLite which cant handle a blank INSERT @@ -196,7 +199,7 @@ class InheritTest3(_base.MappedTest): l = sess.query(Bar).all() print repr(l[0]) + repr(l[0].foos) found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos])) - self.assertEqual(found, compare) + eq_(found, compare) @testing.fails_on('maxdb', 'FIXME: unknown') def testadvanced(self): @@ -244,5 +247,3 @@ class InheritTest3(_base.MappedTest): self.assert_(repr(x) == compare) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/poly_linked_list.py b/test/orm/inheritance/test_poly_linked_list.py index 2cf051949..67b543f31 100644 --- a/test/orm/inheritance/poly_linked_list.py +++ b/test/orm/inheritance/test_poly_linked_list.py @@ -1,15 +1,15 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from orm import _base -from testlib import testing +from test.orm import _base +from sqlalchemy.test import testing class PolymorphicCircularTest(_base.MappedTest): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global Table1, Table1B, Table2, Table3, Data table1 = Table('table1', metadata, Column('id', Integer, primary_key=True), @@ -115,26 +115,26 @@ class PolymorphicCircularTest(_base.MappedTest): @testing.fails_on('maxdb', 'FIXME: unknown') def testone(self): - self.do_testlist([Table1, Table2, Table1, Table2]) + self._testlist([Table1, Table2, Table1, Table2]) @testing.fails_on('maxdb', 'FIXME: unknown') def testtwo(self): - self.do_testlist([Table3]) + self._testlist([Table3]) @testing.fails_on('maxdb', 'FIXME: unknown') def testthree(self): - self.do_testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1]) + self._testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1]) @testing.fails_on('maxdb', 'FIXME: unknown') def testfour(self): - self.do_testlist([ + self._testlist([ Table2('t2', [Data('data1'), Data('data2')]), Table1('t1', []), Table3('t3', [Data('data3')]), Table1B('t1b', [Data('data4'), Data('data5')]) ]) - def do_testlist(self, classes): + def _testlist(self, classes): sess = create_session( ) # create objects in a linked list @@ -195,5 +195,3 @@ class PolymorphicCircularTest(_base.MappedTest): # everything should match ! assert original == forwards == backwards -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/test_polymorph.py index 81f6c82a1..cd3b2d89e 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/test_polymorph.py @@ -1,11 +1,12 @@ """tests basic polymorphic mapper loading/saving, minimal relations""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc -from testlib import _function_named, Column, testing -from orm import _fixtures, _base +from sqlalchemy.test import Column, testing +from sqlalchemy.util import function_named +from test.orm import _fixtures, _base class Person(_fixtures.Base): pass @@ -19,7 +20,8 @@ class Company(_fixtures.Base): pass class PolymorphTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global companies, people, engineers, managers, boss companies = Table('companies', metadata, @@ -83,7 +85,7 @@ class InsertOrderTest(PolymorphTest): session.add(c) session.flush() session.expunge_all() - self.assertEquals(session.query(Company).get(c.company_id), c) + eq_(session.query(Company).get(c.company_id), c) class RelationToSubclassTest(PolymorphTest): def test_basic(self): @@ -115,13 +117,13 @@ class RelationToSubclassTest(PolymorphTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Company).filter_by(company_id=c.company_id).one(), c) + eq_(sess.query(Company).filter_by(company_id=c.company_id).one(), c) assert c.managers[0].company is c class RoundTripTest(PolymorphTest): pass -def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic): +def _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic): """generates a round trip test. include_base - whether or not to include the base 'person' type in the union. @@ -205,15 +207,15 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with session.flush() session.expunge_all() - self.assertEquals(session.query(Person).get(dilbert.person_id), dilbert) + eq_(session.query(Person).get(dilbert.person_id), dilbert) session.expunge_all() - self.assertEquals(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert) + eq_(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert) session.expunge_all() def go(): cc = session.query(Company).get(c.company_id) - self.assertEquals(cc.employees, employees) + eq_(cc.employees, employees) if not lazy_relation: if with_polymorphic != 'none': @@ -229,14 +231,14 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with # test selecting from the query, using the base mapped table (people) as the selection criterion. # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" - self.assertEquals( + eq_( session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first(), dilbert ) assert session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first().person_id - self.assertEquals( + eq_( session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first(), dilbert ) @@ -268,22 +270,22 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with # test standalone orphans daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'}) session.add(daboss) - self.assertRaises(orm_exc.FlushError, session.flush) + assert_raises(orm_exc.FlushError, session.flush) c = session.query(Company).first() daboss.company = c manager_list = [e for e in c.employees if isinstance(e, Manager)] session.flush() session.expunge_all() - self.assertEquals(session.query(Manager).order_by(Manager.person_id).all(), manager_list) + eq_(session.query(Manager).order_by(Manager.person_id).all(), manager_list) c = session.query(Company).first() session.delete(c) session.flush() - self.assertEquals(people.count().scalar(), 0) + eq_(people.count().scalar(), 0) - test_roundtrip = _function_named( + test_roundtrip = function_named( test_roundtrip, "test_%s%s%s_%s" % ( (lazy_relation and "lazy" or "eager"), (include_base and "_inclbase" or ""), @@ -296,9 +298,7 @@ for lazy_relation in [True, False]: for with_polymorphic in ['unions', 'joins', 'auto', 'none']: if with_polymorphic == 'unions': for include_base in [True, False]: - generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic) + _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic) else: - generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic) + _generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/test_polymorph2.py index aec162b75..51b6d4970 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/test_polymorph2.py @@ -2,14 +2,15 @@ inheritance setups for which we maintain compatibility. """ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy import util from sqlalchemy.orm import * -from testlib import _function_named, TestBase, AssertsExecutionResults, testing -from orm import _base, _fixtures -from testlib.testing import eq_ +from sqlalchemy.test import TestBase, AssertsExecutionResults, testing +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ class AttrSettable(object): def __init__(self, **kwargs): @@ -20,7 +21,8 @@ class AttrSettable(object): class RelationTest1(_base.MappedTest): """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers people = Table('people', metadata, @@ -35,11 +37,11 @@ class RelationTest1(_base.MappedTest): Column('manager_name', String(50)) ) - def tearDown(self): + def teardown(self): people.update(values={people.c.manager_id:None}).execute() - super(RelationTest1, self).tearDown() + super(RelationTest1, self).teardown() - def testparentrefsdescendant(self): + def test_parent_refs_descendant(self): class Person(AttrSettable): pass class Manager(Person): @@ -55,7 +57,7 @@ class RelationTest1(_base.MappedTest): mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id) - self.assertEquals(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)]) + eq_(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)]) session = create_session() p = Person(name='some person') @@ -70,7 +72,7 @@ class RelationTest1(_base.MappedTest): print p, m, p.manager assert p.manager is m - def testdescendantrefsparent(self): + def test_descendant_refs_parent(self): class Person(AttrSettable): pass class Manager(Person): @@ -99,7 +101,8 @@ class RelationTest1(_base.MappedTest): class RelationTest2(_base.MappedTest): """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -194,7 +197,8 @@ class RelationTest2(_base.MappedTest): class RelationTest3(_base.MappedTest): """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -212,7 +216,7 @@ class RelationTest3(_base.MappedTest): Column('data', String(30)) ) -def generate_test(jointype="join1", usedata=False): +def _generate_test(jointype="join1", usedata=False): def do_test(self): class Person(AttrSettable): pass @@ -287,19 +291,20 @@ def generate_test(jointype="join1", usedata=False): assert p.data.data == 'ps data' assert m.data.data == 'ms data' - do_test = _function_named( - do_test, 'test_relationonbaseclass_%s_%s' % ( + do_test = function_named( + do_test, 'test_relation_on_base_class_%s_%s' % ( jointype, data and "nodata" or "data")) return do_test for jointype in ["join1", "join2", "join3", "join4"]: for data in (True, False): - func = generate_test(jointype, data) + func = _generate_test(jointype, data) setattr(RelationTest3, func.__name__, func) - +del func class RelationTest4(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, Column('person_id', Integer, primary_key=True), @@ -411,7 +416,8 @@ class RelationTest4(_base.MappedTest): assert c.car_id==car1.car_id class RelationTest5(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, Column('person_id', Integer, primary_key=True), @@ -472,7 +478,8 @@ class RelationTest5(_base.MappedTest): class RelationTest6(_base.MappedTest): """test self-referential relationships on a single joined-table inheritance mapper""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -514,7 +521,8 @@ class RelationTest6(_base.MappedTest): assert m.colleague is m2 class RelationTest7(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers, cars, offroad_cars cars = Table('cars', metadata, Column('car_id', Integer, primary_key=True), @@ -613,7 +621,8 @@ class RelationTest7(_base.MappedTest): assert p.car_id == p.car.car_id class RelationTest8(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global taggable, users taggable = Table('taggable', metadata, Column('id', Integer, primary_key=True), @@ -658,7 +667,8 @@ class RelationTest8(_base.MappedTest): ) class GenerativeTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): # cars---owned by--- people (abstract) --- has a --- status # | ^ ^ | # | | | | @@ -693,9 +703,10 @@ class GenerativeTest(TestBase, AssertsExecutionResults): metadata.create_all() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() - def tearDown(self): + def teardown(self): clear_mappers() for t in reversed(metadata.sorted_tables): t.delete().execute() @@ -784,7 +795,8 @@ class GenerativeTest(TestBase, AssertsExecutionResults): assert str(list(r)) == "[Engineer E4, field X, status Status dead]" class MultiLevelTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global table_Employee, table_Engineer, table_Manager table_Employee = Table( 'Employee', metadata, Column( 'name', type_= String(100), ), @@ -861,7 +873,8 @@ class MultiLevelTest(_base.MappedTest): assert session.query( Manager).all() == [c] class ManyToManyPolyTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base_item_table, item_table, base_item_collection_table, collection_table base_item_table = Table( 'base_item', metadata, @@ -911,7 +924,8 @@ class ManyToManyPolyTest(_base.MappedTest): class_mapper(BaseItem) class CustomPKTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -994,7 +1008,8 @@ class CustomPKTest(_base.MappedTest): sess.flush() class InheritingEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, employees, tags, peopleTags people = Table('people', metadata, @@ -1055,7 +1070,8 @@ class InheritingEagerTest(_base.MappedTest): assert len(instance.tags) == 2 class MissingPolymorphicOnTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global tablea, tableb, tablec, tabled tablea = Table('tablea', metadata, Column('id', Integer, primary_key=True), @@ -1101,7 +1117,5 @@ class MissingPolymorphicOnTest(_base.MappedTest): sess.add(d) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')]) + eq_(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')]) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/productspec.py b/test/orm/inheritance/test_productspec.py index b6a8c5146..b2bcb85d5 100644 --- a/test/orm/inheritance/productspec.py +++ b/test/orm/inheritance/test_productspec.py @@ -1,16 +1,16 @@ -import testenv; testenv.configure_for_tests() from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm import _base +from sqlalchemy.test import testing +from test.orm import _base class InheritTest(_base.MappedTest): """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global products_table, specification_table, documents_table global Product, Detail, Assembly, SpecLine, Document, RasterDocument @@ -316,5 +316,3 @@ class InheritTest(_base.MappedTest): print new assert orig == new == '<Assembly a1> specification=[<SpecLine 1.0 <Detail d1>>] documents=[<Document doc1>, <RasterDocument doc2>]' -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/test_query.py index 58d205455..5b57e8f45 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/test_query.py @@ -1,13 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy import exc as sa_exc from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.engine import default -from testlib import AssertsCompiledSQL, testing -from orm import _base, _fixtures -from testlib.testing import eq_ +from sqlalchemy.test import AssertsCompiledSQL, testing +from test.orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ class Company(_fixtures.Base): pass @@ -27,13 +27,14 @@ class Machine(_fixtures.Base): class Paperwork(_fixtures.Base): pass -def make_test(select_type): +def _produce_test(select_type): class PolymorphicQueryTest(_base.MappedTest, AssertsCompiledSQL): run_inserts = 'once' run_setup_mappers = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global companies, people, engineers, managers, boss, paperwork, machines companies = Table('companies', metadata, @@ -128,7 +129,8 @@ def make_test(select_type): mapper(Paperwork, paperwork) - def insert_data(self): + @classmethod + def insert_data(cls): global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2 c1 = Company(name="MegaCorp, Inc.") @@ -178,14 +180,14 @@ def make_test(select_type): sess = create_session() def go(): - self.assertEquals(sess.query(Person).all(), all_employees) + eq_(sess.query(Person).all(), all_employees) self.assert_sql_count(testing.db, go, {'':14, 'Polymorphic':9}.get(select_type, 10)) def test_primary_eager_aliasing(self): sess = create_session() def go(): - self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) + eq_(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4)) sess = create_session() @@ -194,7 +196,7 @@ def make_test(select_type): assert sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().subquery().count().scalar() == 2 def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) + eq_(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) self.assert_sql_count(testing.db, go, 3) @@ -203,9 +205,9 @@ def make_test(select_type): # for all mappers, ensure the primary key has been calculated as just the "person_id" # column - self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) - self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) - self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore")) + eq_(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) + eq_(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) + eq_(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore")) def test_multi_join(self): sess = create_session() @@ -216,8 +218,8 @@ def make_test(select_type): q = sess.query(Company, Person, c, e).join((Person, Company.employees)).join((e, c.employees)).\ filter(Person.name=='dilbert').filter(e.name=='wally') - self.assertEquals(q.count(), 1) - self.assertEquals(q.all(), [ + eq_(q.count(), 1) + eq_(q.all(), [ ( Company(company_id=1,name=u'MegaCorp, Inc.'), Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), @@ -228,99 +230,99 @@ def make_test(select_type): def test_filter_on_subclass(self): sess = create_session() - self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert")) + eq_(sess.query(Engineer).all()[0], Engineer(name="dilbert")) - self.assertEquals(sess.query(Engineer).first(), Engineer(name="dilbert")) + eq_(sess.query(Engineer).first(), Engineer(name="dilbert")) - self.assertEquals(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert")) + eq_(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert")) - self.assertEquals(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert")) + eq_(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert")) - self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + eq_(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) - self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + eq_(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) def test_join_from_polymorphic(self): sess = create_session() for aliased in (True, False): - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) - self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) + eq_(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) def test_join_from_with_polymorphic(self): sess = create_session() for aliased in (True, False): sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + eq_(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) def test_join_to_polymorphic(self): sess = create_session() - self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) + eq_(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) - self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) + eq_(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) def test_polymorphic_any(self): sess = create_session() - self.assertEquals( + eq_( sess.query(Company).\ filter(Company.employees.any(Person.name=='vlad')).all(), [c2] ) # test that the aliasing on "Person" does not bleed into the # EXISTS clause generated by any() - self.assertEquals( + eq_( sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ filter(Company.employees.any(Person.name=='wally')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ filter(Company.employees.any(Person.name=='vlad')).all(), [] ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), c2 ) calias = aliased(Company) - self.assertEquals( + eq_( sess.query(calias).filter(calias.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), c2 ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(), c1 ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(), c1 ) if select_type != '': - self.assertEquals( + eq_( sess.query(Person).filter(Engineer.machines.any(Machine.name=="Commodore 64")).all(), [e2, e3] ) - self.assertEquals( + eq_( sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1] ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(), c2 ) @@ -328,48 +330,48 @@ def make_test(select_type): def test_join_from_columns_or_subclass(self): sess = create_session() - self.assertEquals( + eq_( sess.query(Manager.name).order_by(Manager.name).all(), [(u'dogbert',), (u'pointy haired boss',)] ) - self.assertEquals( + eq_( sess.query(Manager.name).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] ) - self.assertEquals( + eq_( sess.query(Person.name).join((Paperwork, Person.paperwork)).order_by(Person.name).all(), [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)] ) - self.assertEquals( + eq_( sess.query(Person.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Person.name).all(), [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)] ) - self.assertEquals( + eq_( sess.query(Manager).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), [m1, b1] ) - self.assertEquals( + eq_( sess.query(Manager.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] ) - self.assertEquals( + eq_( sess.query(Manager.person_id).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), [(4,), (4,), (3,)] ) - self.assertEquals( + eq_( sess.query(Manager.name, Paperwork.description).join((Paperwork, Manager.person_id==Paperwork.person_id)).all(), [(u'pointy haired boss', u'review #1'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3')] ) malias = aliased(Manager) - self.assertEquals( + eq_( sess.query(malias.name).join((paperwork, malias.person_id==paperwork.c.person_id)).all(), [(u'pointy haired boss',), (u'dogbert',), (u'dogbert',)] ) @@ -391,9 +393,9 @@ def make_test(select_type): sess = create_session() - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork) - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss) - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person) + assert_raises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork) + assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss) + assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person) # compare to entities without related collections to prevent additional lazy SQL from firing on # loaded entities @@ -404,32 +406,32 @@ def make_test(select_type): Manager(name="dogbert", manager_name="dogbert", status="regular manager"), Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") ] - self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) + eq_(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) self.assert_sql_count(testing.db, go, 3) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) self.assert_sql_count(testing.db, go, 3) sess.expunge_all() def go(): # limit the polymorphic join down to just "Person", overriding select_table - self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) self.assert_sql_count(testing.db, go, 6) def test_relation_to_polymorphic(self): @@ -449,14 +451,14 @@ def make_test(select_type): def go(): # test load Companies with lazy load to 'employees' - self.assertEquals(sess.query(Company).all(), assert_result) + eq_(sess.query(Company).all(), assert_result) self.assert_sql_count(testing.db, go, {'':9, 'Polymorphic':4}.get(select_type, 5)) sess = create_session() def go(): # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer). eagerloader doesn't # pick up on the "of_type()" as of yet. - self.assertEquals(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result) + eq_(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result) # in the case of select_type='', the eagerload doesn't take in this case; # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines" @@ -466,78 +468,78 @@ def make_test(select_type): sess = create_session() def go(): # test load People with eagerload to engineers + machines - self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), + eq_(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])] ) self.assert_sql_count(testing.db, go, 1) def test_join_to_subclass(self): sess = create_session() - self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) if select_type == '': - self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) ealias = aliased(Engineer) - self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) - self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2]) - self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2]) + eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) else: - self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) - self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2]) - self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2]) + eq_(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) # non-polymorphic - self.assertEquals(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) # here's the new way - self.assertEquals(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + eq_(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) def test_join_through_polymorphic(self): sess = create_session() for aliased in (True, False): - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(), [c1, c2] ) - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(), [c1, c2] ) - self.assertEquals( + eq_( sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(), [c1, c2] @@ -549,14 +551,14 @@ def make_test(select_type): # ORMJoin using regular table foreign key connections. Engineer # is expressed as "(select * people join engineers) as anon_1" # so the join is contained. - self.assertEquals( + eq_( sess.query(Company).join(Engineer).filter(Engineer.engineer_name=='vlad').one(), c2 ) # same, using explicit join condition. Query.join() must adapt the on clause # here to match the subquery wrapped around "people join engineers". - self.assertEquals( + eq_( sess.query(Company).join((Engineer, Company.company_id==Engineer.company_id)).filter(Engineer.engineer_name=='vlad').one(), c2 ) @@ -565,17 +567,17 @@ def make_test(select_type): def test_filter_on_baseclass(self): sess = create_session() - self.assertEquals(sess.query(Person).all(), all_employees) + eq_(sess.query(Person).all(), all_employees) - self.assertEquals(sess.query(Person).first(), all_employees[0]) + eq_(sess.query(Person).first(), all_employees[0]) - self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) + eq_(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) def test_from_alias(self): sess = create_session() palias = aliased(Person) - self.assertEquals( + eq_( sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(), [e1, e2] ) @@ -586,7 +588,7 @@ def make_test(select_type): c1_employees = [e1, e2, b1, m1] palias = aliased(Person) - self.assertEquals( + eq_( sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), [ @@ -596,7 +598,7 @@ def make_test(select_type): ] ) - self.assertEquals( + eq_( sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(), [ @@ -614,30 +616,30 @@ def make_test(select_type): # the subquery and usually results in recursion overflow errors within the adaption. subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar() - self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1) + eq_(sess.query(Person).filter(Person.person_id==subq).one(), e1) def test_mixed_entities(self): sess = create_session() - self.assertEquals( + eq_( sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), [(u'Elbonia, Inc.', Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))] ) - self.assertEquals( + eq_( sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), u'Elbonia, Inc.')] ) - self.assertEquals( + eq_( sess.query(Manager.name).all(), [('pointy haired boss', ), ('dogbert',)] ) - self.assertEquals( + eq_( sess.query(Manager.name + " foo").all(), [('pointy haired boss foo', ), ('dogbert foo',)] ) @@ -647,12 +649,12 @@ def make_test(select_type): assert row.primary_language == 'java' - self.assertEquals( + eq_( sess.query(Engineer.name, Engineer.primary_language).all(), [(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')] ) - self.assertEquals( + eq_( sess.query(Boss.name, Boss.golf_swing).all(), [(u'pointy haired boss', u'fore')] ) @@ -670,18 +672,18 @@ def make_test(select_type): # [] # ) - self.assertEquals( + eq_( sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), [(u'vlad',u'Elbonia, Inc.')] ) - self.assertEquals( + eq_( sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(), [(u'java',), (u'c++',), (u'cobol',)] ) if select_type != '': - self.assertEquals( + eq_( sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(), [ (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'), @@ -690,20 +692,20 @@ def make_test(select_type): ] ) - self.assertEquals( + eq_( sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(), [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')] ) palias = aliased(Person) - self.assertEquals( + eq_( sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), u'Elbonia, Inc.', Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))] ) - self.assertEquals( + eq_( sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'Elbonia, Inc.', @@ -711,13 +713,13 @@ def make_test(select_type): ] ) - self.assertEquals( + eq_( sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), [(u'vlad', u'Elbonia, Inc.', u'dilbert')] ) palias = aliased(Person) - self.assertEquals( + eq_( sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), [(u'manager', u'dogbert', u'engineer', u'dilbert'), @@ -725,7 +727,7 @@ def make_test(select_type): (u'manager', u'dogbert', u'boss', u'pointy haired boss')] ) - self.assertEquals( + eq_( sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(), [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3'), @@ -737,17 +739,17 @@ def make_test(select_type): ) if select_type != '': - self.assertEquals( + eq_( sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(), [(1, )] ) - self.assertEquals( + eq_( sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(), [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)] ) - self.assertEquals( + eq_( sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(), [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)] ) @@ -757,7 +759,7 @@ def make_test(select_type): return PolymorphicQueryTest for select_type in ('', 'Polymorphic', 'Unions', 'AliasedJoins', 'Joins'): - testclass = make_test(select_type) + testclass = _produce_test(select_type) exec("%s = testclass" % testclass.__name__) del testclass @@ -765,7 +767,8 @@ del testclass class SelfReferentialTestJoinedToBase(_base.MappedTest): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -778,7 +781,8 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): Column('reports_to_id', Integer, ForeignKey('people.person_id')) ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') mapper(Engineer, engineers, inherits=Person, inherit_condition=engineers.c.person_id==people.c.person_id, @@ -796,7 +800,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert')) + eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert')) def test_oftype_aliases_in_exists(self): e1 = Engineer(name='dilbert', primary_language='java') @@ -805,7 +809,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): sess.add_all([e1, e2]) sess.flush() - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2) + eq_(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2) def test_join(self): p1 = Person(name='dogbert') @@ -816,14 +820,15 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), Engineer(name='dilbert')) class SelfReferentialJ2JTest(_base.MappedTest): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -840,7 +845,8 @@ class SelfReferentialJ2JTest(_base.MappedTest): Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') mapper(Manager, managers, inherits=Person, polymorphic_identity='manager') @@ -859,7 +865,7 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert')) + eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert')) def test_join(self): m1 = Manager(name='dogbert') @@ -870,7 +876,7 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), Engineer(name='dilbert')) @@ -886,17 +892,17 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.expunge_all() # filter aliasing applied to Engineer doesn't whack Manager - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(), [m1] ) - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(), [m2] ) - self.assertEquals( + eq_( sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(), [ (m2, e2), @@ -919,12 +925,12 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), [] ) - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), [m1] ) @@ -936,7 +942,8 @@ class M2MFilterTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, organizations, engineers_to_org organizations = Table('organizations', metadata, @@ -958,7 +965,8 @@ class M2MFilterTest(_base.MappedTest): Column('primary_language', String(50)), ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): global Organization class Organization(_fixtures.Base): pass @@ -970,7 +978,8 @@ class M2MFilterTest(_base.MappedTest): mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') - def insert_data(self): + @classmethod + def insert_data(cls): e1 = Engineer(name='e1') e2 = Engineer(name='e2') e3 = Engineer(name='e3') @@ -989,20 +998,21 @@ class M2MFilterTest(_base.MappedTest): e1 = sess.query(Person).filter(Engineer.name=='e1').one() # this works - self.assertEquals(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')]) + eq_(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')]) # this had a bug - self.assertEquals(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')]) + eq_(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')]) def test_any(self): sess = create_session() - self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')]) - self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')]) + eq_(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')]) + eq_(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')]) class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global Parent, Child1, Child2 Base = declarative_base(metadata=metadata) @@ -1101,5 +1111,3 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): assert q.first() is c1 -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/selects.py b/test/orm/inheritance/test_selects.py index e54a0ad13..a151af4fa 100644 --- a/test/orm/inheritance/selects.py +++ b/test/orm/inheritance/test_selects.py @@ -1,14 +1,14 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm._fixtures import Base -from orm._base import MappedTest +from sqlalchemy.test import testing +from test.orm._fixtures import Base +from test.orm._base import MappedTest class InheritingSelectablesTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, baz foo = Table('foo', metadata, Column('a', String(30), primary_key=1), @@ -49,5 +49,3 @@ class InheritingSelectablesTest(MappedTest): assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all() assert [Bar(), Bar()] == s.query(Bar).all() -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/test_single.py index 7aee25031..705826885 100644 --- a/test/orm/inheritance/single.py +++ b/test/orm/inheritance/test_single.py @@ -1,14 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm import _fixtures -from orm._base import MappedTest, ComparableEntity +from sqlalchemy.test import testing +from test.orm import _fixtures +from test.orm._base import MappedTest, ComparableEntity class SingleInheritanceTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('employees', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), @@ -22,7 +23,8 @@ class SingleInheritanceTest(MappedTest): Column('name', String(50)), ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Employee(ComparableEntity): pass class Manager(Employee): @@ -32,8 +34,9 @@ class SingleInheritanceTest(MappedTest): class JuniorEngineer(Engineer): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Employee, employees, polymorphic_on=employees.c.type) mapper(Manager, inherits=Employee, polymorphic_identity='manager') mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') @@ -57,7 +60,7 @@ class SingleInheritanceTest(MappedTest): m1 = session.query(Manager).one() session.expire(m1, ['manager_data']) - self.assertEquals(m1.manager_data, "knows how to manage things") + eq_(m1.manager_data, "knows how to manage things") row = session.query(Engineer.name, Engineer.employee_id).filter(Engineer.name=='Kurt').first() assert row.name == 'Kurt' @@ -75,32 +78,32 @@ class SingleInheritanceTest(MappedTest): session.flush() ealias = aliased(Engineer) - self.assertEquals( + eq_( session.query(Manager, ealias).all(), [(m1, e1), (m1, e2)] ) - self.assertEquals( + eq_( session.query(Manager.name).all(), [("Tom",)] ) - self.assertEquals( + eq_( session.query(Manager.name, ealias.name).all(), [("Tom", "Kurt"), ("Tom", "Ed")] ) - self.assertEquals( + eq_( session.query(func.upper(Manager.name), func.upper(ealias.name)).all(), [("TOM", "KURT"), ("TOM", "ED")] ) - self.assertEquals( + eq_( session.query(Manager).add_entity(ealias).all(), [(m1, e1), (m1, e2)] ) - self.assertEquals( + eq_( session.query(Manager.name).add_column(ealias.name).all(), [("Tom", "Kurt"), ("Tom", "Ed")] ) @@ -121,7 +124,7 @@ class SingleInheritanceTest(MappedTest): sess.add_all([m1, m2, e1, e2]) sess.flush() - self.assertEquals( + eq_( sess.query(Manager).select_from(employees.select().limit(10)).all(), [m1, m2] ) @@ -136,12 +139,12 @@ class SingleInheritanceTest(MappedTest): sess.add_all([m1, m2, e1, e2]) sess.flush() - self.assertEquals(sess.query(Manager).count(), 2) - self.assertEquals(sess.query(Engineer).count(), 2) - self.assertEquals(sess.query(Employee).count(), 4) + eq_(sess.query(Manager).count(), 2) + eq_(sess.query(Engineer).count(), 2) + eq_(sess.query(Employee).count(), 4) - self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) - self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) + eq_(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) + eq_(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) @testing.resolve_artifact_names def test_type_filtering(self): @@ -180,7 +183,8 @@ class SingleInheritanceTest(MappedTest): class RelationToSingleTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('employees', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), @@ -195,7 +199,8 @@ class RelationToSingleTest(MappedTest): Column('name', String(50)), ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Company(ComparableEntity): pass @@ -229,14 +234,14 @@ class RelationToSingleTest(MappedTest): sess.add_all([c1, c2, m1, m2, e1, e2]) sess.commit() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(), [ Company(name='c1'), ] ) - self.assertEquals( + eq_( sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(), [ Company(name='c1'), @@ -267,11 +272,11 @@ class RelationToSingleTest(MappedTest): sess.add_all([c1, c2, m1, m2, e1, e2]) sess.commit() - self.assertEquals(c1.engineers, [e2]) - self.assertEquals(c2.engineers, [e1]) + eq_(c1.engineers, [e2]) + eq_(c2.engineers, [e1]) sess.expunge_all() - self.assertEquals(sess.query(Company).order_by(Company.name).all(), + eq_(sess.query(Company).order_by(Company.name).all(), [ Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), Company(name='c2', engineers=[Engineer(name='Kurt')]) @@ -280,7 +285,7 @@ class RelationToSingleTest(MappedTest): # eager load join should limit to only "Engineer" sess.expunge_all() - self.assertEquals(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), + eq_(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), [ Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), Company(name='c2', engineers=[Engineer(name='Kurt')]) @@ -289,7 +294,7 @@ class RelationToSingleTest(MappedTest): # join() to Company.engineers, Employee as the requested entity sess.expunge_all() - self.assertEquals(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(), + eq_(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(), [ (Company(name='c1'), JuniorEngineer(name='Ed')), (Company(name='c2'), Engineer(name='Kurt')) @@ -299,7 +304,7 @@ class RelationToSingleTest(MappedTest): # join() to Company.engineers, Engineer as the requested entity. # this actually applies the IN criterion twice which is less than ideal. sess.expunge_all() - self.assertEquals(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(), + eq_(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(), [ (Company(name='c1'), JuniorEngineer(name='Ed')), (Company(name='c2'), Engineer(name='Kurt')) @@ -308,7 +313,7 @@ class RelationToSingleTest(MappedTest): # join() to Company.engineers without any Employee/Engineer entity sess.expunge_all() - self.assertEquals(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), + eq_(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), [ Company(name='c2') ] @@ -323,7 +328,7 @@ class RelationToSingleTest(MappedTest): @testing.fails_on_everything_except() def go(): sess.expunge_all() - self.assertEquals(sess.query(Company).\ + eq_(sess.query(Company).\ filter(Company.company_id==Engineer.company_id).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), [ Company(name='c2') @@ -332,7 +337,8 @@ class RelationToSingleTest(MappedTest): go() class SingleOnJoinedTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global persons_table, employees_table persons_table = Table('persons', metadata, @@ -366,31 +372,29 @@ class SingleOnJoinedTest(MappedTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [ + eq_(sess.query(Person).order_by(Person.person_id).all(), [ Person(name='p1'), Employee(name='e1', employee_data='ed1'), Manager(name='m1', employee_data='ed2', manager_data='md1') ]) sess.expunge_all() - self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [ + eq_(sess.query(Employee).order_by(Person.person_id).all(), [ Employee(name='e1', employee_data='ed1'), Manager(name='m1', employee_data='ed2', manager_data='md1') ]) sess.expunge_all() - self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [ + eq_(sess.query(Manager).order_by(Person.person_id).all(), [ Manager(name='m1', employee_data='ed2', manager_data='md1') ]) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ + eq_(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ Person(name='p1'), Employee(name='e1', employee_data='ed1'), Manager(name='m1', employee_data='ed2', manager_data='md1') ]) self.assert_sql_count(testing.db, go, 1) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/sharding/alltests.py b/test/orm/sharding/alltests.py deleted file mode 100644 index 09fa86212..000000000 --- a/test/orm/sharding/alltests.py +++ /dev/null @@ -1,18 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - 'orm.sharding.shard', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/test_shard.py index 10aaee131..89e23fb75 100644 --- a/test/orm/sharding/shard.py +++ b/test/orm/sharding/test_shard.py @@ -1,17 +1,17 @@ -import testenv; testenv.configure_for_tests() import datetime, os from sqlalchemy import * from sqlalchemy import sql from sqlalchemy.orm import * from sqlalchemy.orm.shard import ShardedSession from sqlalchemy.sql import operators -from testlib import * -from testlib.testing import eq_ +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ # TODO: ShardTest can be turned into a base for further subclasses class ShardTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global db1, db2, db3, db4, weather_locations, weather_reports db1 = create_engine('sqlite:///shard1.db') @@ -48,16 +48,18 @@ class ShardTest(TestBase): db1.execute(ids.insert(), nextid=1) - self.setup_session() - self.setup_mappers() + cls.setup_session() + cls.setup_mappers() - def tearDownAll(self): + @classmethod + def teardown_class(cls): for db in (db1, db2, db3, db4): db.connect().invalidate() for i in range(1,5): os.remove("shard%d.db" % i) - def setup_session(self): + @classmethod + def setup_session(cls): global create_session shard_lookup = { @@ -104,7 +106,8 @@ class ShardTest(TestBase): }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): global WeatherLocation, Report class WeatherLocation(object): @@ -159,5 +162,3 @@ class ShardTest(TestBase): -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/association.py b/test/orm/test_association.py index d9265ffb1..ee7fb7af9 100644 --- a/test/orm/association.py +++ b/test/orm/test_association.py @@ -1,17 +1,19 @@ -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base -from testlib.testing import eq_ +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base +from sqlalchemy.test.testing import eq_ class AssociationTest(_base.MappedTest): run_setup_classes = 'once' run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('items', metadata, Column('item_id', Integer, primary_key=True), Column('name', String(40))) @@ -23,7 +25,8 @@ class AssociationTest(_base.MappedTest): Column('keyword_id', Integer, primary_key=True), Column('name', String(40))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Item(_base.BasicEntity): def __init__(self, name): self.name = name @@ -45,9 +48,10 @@ class AssociationTest(_base.MappedTest): return "KeywordAssociation itemid=%d keyword=%r data=%s" % ( self.item_id, self.keyword, self.data) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): - items, item_keywords, keywords = self.tables.get_all( + def setup_mappers(cls): + items, item_keywords, keywords = cls.tables.get_all( 'items', 'item_keywords', 'keywords') mapper(Keyword, keywords) @@ -133,13 +137,11 @@ class AssociationTest(_base.MappedTest): item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) sess.add_all((item1, item2)) sess.flush() - eq_(self.tables.item_keywords.count().scalar(), 3) + eq_(item_keywords.count().scalar(), 3) sess.delete(item1) sess.delete(item2) sess.flush() - eq_(self.tables.item_keywords.count().scalar(), 0) + eq_(item_keywords.count().scalar(), 0) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/assorted_eager.py b/test/orm/test_assorted_eager.py index 8dc95fa5b..09f007547 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -3,21 +3,25 @@ Derived from mailing list-reported problems and trac tickets. """ -import testenv; testenv.configure_for_tests() import datetime -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, backref, create_session -from testlib.testing import eq_ -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base class EagerTest(_base.MappedTest): run_deletes = None run_inserts = "once" + run_setup_mappers = "once" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): # determine a literal value for "false" based on the dialect # FIXME: this DefaultClause setup is bogus. @@ -30,7 +34,7 @@ class EagerTest(_base.MappedTest): false = text('FALSE') else: false = str(False) - self.other_artifacts['false'] = false + cls.other_artifacts['false'] = false Table('owners', metadata , Column('id', Integer, primary_key=True, nullable=False), @@ -55,30 +59,32 @@ class EagerTest(_base.MappedTest): Column('someoption', sa.Boolean, server_default=false, nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Owner(_base.BasicEntity): pass class Category(_base.BasicEntity): pass - class Test(_base.BasicEntity): + class Thing(_base.BasicEntity): pass class Option(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Owner, owners) mapper(Category, categories) mapper(Option, options, properties=dict( owner=relation(Owner), - test=relation(Test))) + test=relation(Thing))) - mapper(Test, tests, properties=dict( + mapper(Thing, tests, properties=dict( owner=relation(Owner, backref='tests'), category=relation(Category), owner_option=relation(Option, @@ -87,16 +93,17 @@ class EagerTest(_base.MappedTest): foreign_keys=[options.c.test_id, options.c.owner_id], uselist=False))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): session = create_session() o = Owner() c = Category(name='Some Category') session.add_all(( - Test(owner=o, category=c), - Test(owner=o, category=c, owner_option=Option(someoption=True)), - Test(owner=o, category=c, owner_option=Option()))) + Thing(owner=o, category=c), + Thing(owner=o, category=c, owner_option=Option(someoption=True)), + Thing(owner=o, category=c, owner_option=Option()))) session.flush() @@ -129,7 +136,7 @@ class EagerTest(_base.MappedTest): @testing.resolve_artifact_names def test_withouteagerload(self): s = create_session() - l = (s.query(Test). + l = (s.query(Thing). select_from(tests.outerjoin(options, sa.and_(tests.c.id == options.c.test_id, tests.c.owner_id == @@ -150,7 +157,7 @@ class EagerTest(_base.MappedTest): """ s = create_session() - q=s.query(Test).options(sa.orm.eagerload('category')) + q=s.query(Thing).options(sa.orm.eagerload('category')) l=(q.select_from(tests.outerjoin(options, sa.and_(tests.c.id == @@ -168,7 +175,7 @@ class EagerTest(_base.MappedTest): def test_dslish(self): """test the same as witheagerload except using generative""" s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) + q = s.query(Thing).options(sa.orm.eagerload('category')) l = q.filter ( sa.and_(tests.c.owner_id == 1, sa.or_(options.c.someoption == None, @@ -182,7 +189,7 @@ class EagerTest(_base.MappedTest): @testing.resolve_artifact_names def test_without_outerjoin_literal(self): s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) + q = s.query(Thing).options(sa.orm.eagerload('category')) l = (q.filter( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false)). @@ -194,7 +201,7 @@ class EagerTest(_base.MappedTest): @testing.resolve_artifact_names def test_withoutouterjoin(self): s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) + q = s.query(Thing).options(sa.orm.eagerload('category')) l = q.filter( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) @@ -205,7 +212,8 @@ class EagerTest(_base.MappedTest): class EagerTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('left', metadata, Column('id', Integer, ForeignKey('middle.id'), primary_key=True), Column('data', String(50), primary_key=True)) @@ -218,7 +226,8 @@ class EagerTest2(_base.MappedTest): Column('id', Integer, ForeignKey('middle.id'), primary_key=True), Column('data', String(50), primary_key=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Left(_base.BasicEntity): def __init__(self, data): self.data = data @@ -231,8 +240,9 @@ class EagerTest2(_base.MappedTest): def __init__(self, data): self.data = data + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): # set up bi-directional eager loads mapper(Left, left) mapper(Right, right) @@ -267,7 +277,8 @@ class EagerTest2(_base.MappedTest): class EagerTest3(_base.MappedTest): """Eager loading combined with nested SELECT statements, functions, and aggregates.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('datas', metadata, Column('id', Integer, primary_key=True, nullable=False), Column('a', Integer, nullable=False)) @@ -283,7 +294,8 @@ class EagerTest3(_base.MappedTest): Column('data_id', Integer, ForeignKey('datas.id')), Column('somedata', Integer, nullable=False )) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Data(_base.BasicEntity): pass @@ -349,7 +361,8 @@ class EagerTest3(_base.MappedTest): class EagerTest4(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('departments', metadata, Column('department_id', Integer, primary_key=True), Column('name', String(50))) @@ -360,7 +373,8 @@ class EagerTest4(_base.MappedTest): Column('department_id', Integer, ForeignKey('departments.department_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Department(_base.BasicEntity): pass @@ -401,7 +415,8 @@ class EagerTest4(_base.MappedTest): class EagerTest5(_base.MappedTest): """Construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('base', metadata, Column('uid', String(30), primary_key=True), Column('x', String(30))) @@ -421,7 +436,8 @@ class EagerTest5(_base.MappedTest): Column('uid', String(30), ForeignKey('base.uid')), Column('comment', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Base(_base.BasicEntity): def __init__(self, uid, x): self.uid = uid @@ -486,7 +502,8 @@ class EagerTest5(_base.MappedTest): class EagerTest6(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('design_types', metadata, Column('design_type_id', Integer, primary_key=True)) @@ -506,7 +523,8 @@ class EagerTest6(_base.MappedTest): Column('part_id', Integer, ForeignKey('parts.part_id')), Column('design_id', Integer, ForeignKey('design.design_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Part(_base.BasicEntity): pass @@ -552,7 +570,8 @@ class EagerTest6(_base.MappedTest): class EagerTest7(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('companies', metadata, Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -584,7 +603,8 @@ class EagerTest7(_base.MappedTest): Column('code', String(20)), Column('qty', Integer)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Company(_base.ComparableEntity): pass @@ -699,7 +719,8 @@ class EagerTest7(_base.MappedTest): class EagerTest8(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('prj', metadata, Column('id', Integer, primary_key=True), Column('created', sa.DateTime ), @@ -731,8 +752,9 @@ class EagerTest8(_base.MappedTest): Column('name', sa.Unicode(20)), Column('display_name', sa.Unicode(20))) + @classmethod @testing.resolve_artifact_names - def fixtures(self): + def fixtures(cls): return dict( prj=(('id',), (1,)), @@ -746,7 +768,8 @@ class EagerTest8(_base.MappedTest): task=(('title', 'task_type_id', 'status_id', 'prj_id'), (u'task 1', 1, 1, 1))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Task_Type(_base.BasicEntity): pass @@ -788,7 +811,8 @@ class EagerTest9(_base.MappedTest): throughout the query setup/mapper instances process. """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('accounts', metadata, Column('account_id', Integer, primary_key=True), Column('name', String(40))) @@ -805,7 +829,8 @@ class EagerTest9(_base.MappedTest): Column('transaction_id', Integer, ForeignKey('transactions.transaction_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Account(_base.BasicEntity): pass @@ -815,8 +840,9 @@ class EagerTest9(_base.MappedTest): class Entry(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Account, accounts) mapper(Transaction, transactions) @@ -874,5 +900,3 @@ class EagerTest9(_base.MappedTest): -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/attributes.py b/test/orm/test_attributes.py index 7c116fcf7..3b1b42dad 100644 --- a/test/orm/attributes.py +++ b/test/orm/test_attributes.py @@ -1,12 +1,11 @@ -import testenv; testenv.configure_for_tests() import pickle import sqlalchemy.orm.attributes as attributes from sqlalchemy.orm.collections import collection from sqlalchemy.orm.interfaces import AttributeExtension from sqlalchemy import exc as sa_exc -from testlib import * -from testlib.testing import eq_ -from orm import _base +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ +from test.orm import _base import gc # global for pickling tests @@ -15,12 +14,12 @@ MyTest2 = None class AttributesTest(_base.ORMTest): - def setUp(self): + def setup(self): global MyTest, MyTest2 class MyTest(object): pass class MyTest2(object): pass - def tearDown(self): + def teardown(self): global MyTest, MyTest2 MyTest, MyTest2 = None, None @@ -588,7 +587,7 @@ class BackrefTest(_base.ORMTest): self.assert_(p.jack is None) class PendingBackrefTest(_base.ORMTest): - def setUp(self): + def setup(self): global Post, Blog, called, lazy_load class Post(object): @@ -1327,5 +1326,3 @@ class ListenerTest(_base.ORMTest): assert f1.barset.pop().data == "some bar appended" -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/bind.py b/test/orm/test_bind.py index 33d028d22..9b1c20b60 100644 --- a/test/orm/bind.py +++ b/test/orm/test_bind.py @@ -1,23 +1,29 @@ -import testenv; testenv.configure_for_tests() -from testlib.sa import MetaData, Table, Column, Integer -from testlib.sa.orm import mapper, create_session -from testlib import sa, testing -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy import MetaData, Integer +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +import sqlalchemy as sa +from sqlalchemy.test import testing +from test.orm import _base class BindTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('test_table', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', Integer)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): meta = MetaData() test_table.tometadata(meta) @@ -44,12 +50,10 @@ class BindTest(_base.MappedTest): def test_session_unbound(self): sess = create_session() sess.add(Foo()) - self.assertRaisesMessage( + assert_raises_message( sa.exc.UnboundExecutionError, ('Could not locate a bind configured on Mapper|Foo|test_table ' 'or this Session'), sess.flush) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/cascade.py b/test/orm/test_cascade.py index c827a85ce..d0a7b9ded 100644 --- a/test/orm/cascade.py +++ b/test/orm/test_cascade.py @@ -1,18 +1,21 @@ -import testenv; testenv.configure_for_tests() -from testlib.sa import Table, Column, Integer, String, ForeignKey, Sequence, exc as sa_exc -from testlib.sa.orm import mapper, relation, create_session, class_mapper, backref -from testlib.sa.orm import attributes, exc as orm_exc -from testlib import testing -from testlib.testing import eq_ -from orm import _base, _fixtures +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref +from sqlalchemy.orm import attributes, exc as orm_exc +from sqlalchemy.test import testing +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures class O2MCascadeTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Address, addresses) mapper(User, users, properties = dict( addresses = relation(Address, cascade="all, delete-orphan", backref="user"), @@ -188,8 +191,9 @@ class O2MCascadeTest(_fixtures.FixtureTest): class O2OCascadeTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Address, addresses) mapper(User, users, properties = { 'address':relation(Address, backref=backref("user", single_parent=True), uselist=False) @@ -200,7 +204,7 @@ class O2OCascadeTest(_fixtures.FixtureTest): a1 = Address(email_address='some address') u1 = User(name='u1', address=a1) - self.assertRaises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1) + assert_raises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1) a2 = Address(email_address='asd') u1.address = a2 @@ -212,8 +216,9 @@ class O2OCascadeTest(_fixtures.FixtureTest): class O2MBackrefTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties = dict( orders = relation( mapper(Order, orders), cascade="all, delete-orphan", backref="user") @@ -316,8 +321,9 @@ class NoSaveCascadeTest(_fixtures.FixtureTest): class O2MCascadeNoOrphanTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties = dict( orders = relation( mapper(Order, orders), cascade="all") @@ -342,7 +348,8 @@ class O2MCascadeNoOrphanTest(_fixtures.FixtureTest): class M2OCascadeTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("extra", metadata, Column("id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True), @@ -359,7 +366,8 @@ class M2OCascadeTest(_base.MappedTest): Column('name', String(40)), Column('pref_id', Integer, ForeignKey('prefs.id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_fixtures.Base): pass class Pref(_fixtures.Base): @@ -367,8 +375,9 @@ class M2OCascadeTest(_base.MappedTest): class Extra(_fixtures.Base): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Extra, extra) mapper(Pref, prefs, properties=dict( extra = relation(Extra, cascade="all, delete") @@ -377,8 +386,9 @@ class M2OCascadeTest(_base.MappedTest): pref = relation(Pref, lazy=False, cascade="all, delete-orphan", single_parent=True ) )) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()])) u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()])) u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()])) @@ -447,7 +457,8 @@ class M2OCascadeTest(_base.MappedTest): [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) class M2OCascadeDeleteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), @@ -460,7 +471,8 @@ class M2OCascadeDeleteTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('data', String(50))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_fixtures.Base): pass class T2(_fixtures.Base): @@ -468,8 +480,9 @@ class M2OCascadeDeleteTest(_base.MappedTest): class T3(_fixtures.Base): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(T1, t1, properties={'t2': relation(T2, cascade="all")}) mapper(T2, t2, properties={'t3': relation(T3, cascade="all")}) mapper(T3, t3) @@ -565,7 +578,8 @@ class M2OCascadeDeleteTest(_base.MappedTest): class M2OCascadeDeleteOrphanTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), @@ -578,7 +592,8 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('data', String(50))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_fixtures.Base): pass class T2(_fixtures.Base): @@ -586,8 +601,9 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): class T3(_fixtures.Base): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(T1, t1, properties=dict( t2=relation(T2, cascade="all, delete-orphan", single_parent=True))) mapper(T2, t2, properties=dict( @@ -655,7 +671,7 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): y = T2(data='T2a') x = T1(data='T1a', t2=y) - self.assertRaises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y) + assert_raises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y) @testing.resolve_artifact_names def test_single_parent_backref(self): @@ -666,7 +682,7 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): x = T2(data='T2a', t3=y) # cant attach the T3 to another T2 - self.assertRaises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y) + assert_raises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y) # set via backref tho is OK, unsets from previous parent # first @@ -677,7 +693,8 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): assert x.t3 is None class M2MCascadeTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -703,7 +720,8 @@ class M2MCascadeTest(_base.MappedTest): ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_fixtures.Base): pass class B(_fixtures.Base): @@ -784,7 +802,7 @@ class M2MCascadeTest(_base.MappedTest): b1 =B(data='b1') a1 = A(data='a1', bs=[b1]) - self.assertRaises(sa_exc.InvalidRequestError, + assert_raises(sa_exc.InvalidRequestError, A, data='a2', bs=[b1] ) @@ -804,7 +822,7 @@ class M2MCascadeTest(_base.MappedTest): b1 =B(data='b1') a1 = A(data='a1', bs=[b1]) - self.assertRaises( + assert_raises( sa_exc.InvalidRequestError, A, data='a2', bs=[b1] ) @@ -817,7 +835,8 @@ class M2MCascadeTest(_base.MappedTest): class UnsavedOrphansTest(_base.MappedTest): """Pending entities that are orphans""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('user_id', Integer, Sequence('user_id_seq', optional=True), @@ -831,7 +850,8 @@ class UnsavedOrphansTest(_base.MappedTest): Column('user_id', Integer, ForeignKey('users.user_id')), Column('email_address', String(40))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_fixtures.Base): pass class Address(_fixtures.Base): @@ -900,7 +920,8 @@ class UnsavedOrphansTest(_base.MappedTest): class UnsavedOrphansTest2(_base.MappedTest): """same test as UnsavedOrphans only three levels deep""" - def define_tables(self, meta): + @classmethod + def define_tables(cls, meta): Table('orders', meta, Column('id', Integer, Sequence('order_id_seq'), primary_key=True), @@ -958,7 +979,8 @@ class UnsavedOrphansTest2(_base.MappedTest): class UnsavedOrphansTest3(_base.MappedTest): """test not expunging double parents""" - def define_tables(self, meta): + @classmethod + def define_tables(cls, meta): Table('sales_reps', meta, Column('sales_rep_id', Integer, Sequence('sales_rep_id_seq'), @@ -1062,7 +1084,8 @@ class UnsavedOrphansTest3(_base.MappedTest): class DoubleParentOrphanTest(_base.MappedTest): """test orphan detection for an entity with two parent relations""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('addresses', metadata, Column('address_id', Integer, primary_key=True), Column('street', String(30)), @@ -1133,7 +1156,8 @@ class DoubleParentOrphanTest(_base.MappedTest): assert True class CollectionAssignmentOrphanTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table_a', metadata, Column('id', Integer, primary_key=True), Column('name', String(30))) @@ -1181,7 +1205,8 @@ class PartialFlushTest(_base.MappedTest): """test cascade behavior as it relates to object lists passed to flush(). """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("base", metadata, Column("id", Integer, primary_key=True), Column("descr", String(50)) @@ -1288,5 +1313,3 @@ class PartialFlushTest(_base.MappedTest): assert c1 not in sess.new assert c2 in sess.new -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/collection.py b/test/orm/test_collection.py index 23f643597..12ff25c46 100644 --- a/test/orm/collection.py +++ b/test/orm/test_collection.py @@ -1,17 +1,19 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import sys from operator import and_ import sqlalchemy.orm.collections as collections from sqlalchemy.orm.collections import collection -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa import util, exc as sa_exc -from testlib.sa.orm import create_session, mapper, relation, \ - attributes -from orm import _base -from testlib.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy import util, exc as sa_exc +from sqlalchemy.orm import create_session, mapper, relation, attributes +from test.orm import _base +from sqlalchemy.test.testing import eq_ class Canary(sa.orm.interfaces.AttributeExtension): def __init__(self): @@ -45,12 +47,14 @@ class CollectionsTest(_base.ORMTest): def __repr__(self): return str((id(self), self.a, self.b, self.c)) - def setUpAll(self): - attributes.register_class(self.Entity) + @classmethod + def setup_class(cls): + attributes.register_class(cls.Entity) - def tearDownAll(self): - attributes.unregister_class(self.Entity) - _base.ORMTest.tearDownAll(self) + @classmethod + def teardown_class(cls): + attributes.unregister_class(cls.Entity) + super(CollectionsTest, cls).teardown_class() _entity_id = 1 @@ -937,7 +941,7 @@ class CollectionsTest(_base.ORMTest): pass self.assert_(obj.attr is not real_dict) self.assert_('badkey' not in obj.attr) - self.assertEquals(set(collections.collection_adapter(obj.attr)), + eq_(set(collections.collection_adapter(obj.attr)), set([e2])) self.assert_(e3 not in canary.added) else: @@ -945,13 +949,13 @@ class CollectionsTest(_base.ORMTest): obj.attr = real_dict self.assert_(obj.attr is not real_dict) self.assert_('keyignored1' not in obj.attr) - self.assertEquals(set(collections.collection_adapter(obj.attr)), + eq_(set(collections.collection_adapter(obj.attr)), set([e3])) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) obj.attr = typecallable() - self.assertEquals(list(collections.collection_adapter(obj.attr)), []) + eq_(list(collections.collection_adapter(obj.attr)), []) e4 = creator() try: @@ -1336,7 +1340,8 @@ class CollectionsTest(_base.ORMTest): class DictHelpersTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('parents', metadata, Column('id', Integer, primary_key=True), Column('label', String(128))) @@ -1348,7 +1353,8 @@ class DictHelpersTest(_base.MappedTest): Column('b', String(128)), Column('c', String(128))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Parent(_base.BasicEntity): def __init__(self, label=None): self.label = label @@ -1378,7 +1384,7 @@ class DictHelpersTest(_base.MappedTest): p = session.query(Parent).get(pid) - self.assertEquals(set(p.children.keys()), set(['foo', 'bar'])) + eq_(set(p.children.keys()), set(['foo', 'bar'])) cid = p.children['foo'].id collections.collection_adapter(p.children).append_with_event( @@ -1519,7 +1525,8 @@ class DictHelpersTest(_base.MappedTest): # remove if so class CustomCollectionsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('sometable', metadata, Column('col1',Integer, primary_key=True), Column('data', String(30))) @@ -1830,5 +1837,3 @@ class InstrumentationTest(_base.ORMTest): instrumented = collections._instrument_class(Touchy) assert True -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/compile.py b/test/orm/test_compile.py index 7c9bed4ec..7a5b63615 100644 --- a/test/orm/compile.py +++ b/test/orm/test_compile.py @@ -1,15 +1,14 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import exc as sa_exc from sqlalchemy.orm import * -from testlib import * -from orm import _base +from sqlalchemy.test import * +from test.orm import _base class CompileTest(_base.ORMTest): """test various mapper compilation scenarios""" - def tearDown(self): + def teardown(self): clear_mappers() def testone(self): @@ -182,5 +181,3 @@ class CompileTest(_base.ORMTest): assert str(e).index("Error creating backref") > -1 -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/cycles.py b/test/orm/test_cycles.py index 3e3636085..fe77b3601 100644 --- a/test/orm/cycles.py +++ b/test/orm/test_cycles.py @@ -5,19 +5,21 @@ T1<->T2, with o2m or m2o between them, and a third T3 with o2m/m2o to one/both T1/T2. """ -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, backref, create_session -from testlib.testing import eq_ -from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf -from orm import _base +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf +from test.orm import _base class SelfReferentialTest(_base.MappedTest): """A self-referential mapper with an additional list of child objects.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), @@ -29,7 +31,8 @@ class SelfReferentialTest(_base.MappedTest): Column('c1id', Integer, ForeignKey('t1.c1')), Column('data', String(20))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class C1(_base.BasicEntity): def __init__(self, data=None): self.data = data @@ -132,20 +135,23 @@ class SelfReferentialTest(_base.MappedTest): class SelfReferentialNoPKTest(_base.MappedTest): """A self-referential relationship that joins on a column other than the primary key column""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('item', metadata, Column('id', Integer, primary_key=True), Column('uuid', String(32), unique=True, nullable=False), Column('parent_uuid', String(32), ForeignKey('item.uuid'), nullable=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class TT(_base.BasicEntity): def __init__(self): self.uuid = hex(id(self)) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(TT, item, properties={ 'children': relation( TT, @@ -181,7 +187,8 @@ class SelfReferentialNoPKTest(_base.MappedTest): class InheritTestOne(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("parent", metadata, Column("id", Integer, primary_key=True), Column("parent_data", String(50)), @@ -199,7 +206,8 @@ class InheritTestOne(_base.MappedTest): nullable=False), Column("child2_data", String(50))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Parent(_base.BasicEntity): pass @@ -209,8 +217,9 @@ class InheritTestOne(_base.MappedTest): class Child2(Parent): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Parent, parent) mapper(Child1, child1, inherits=Parent) mapper(Child2, child2, inherits=Parent, properties=dict( @@ -250,7 +259,8 @@ class InheritTestTwo(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -266,7 +276,8 @@ class InheritTestTwo(_base.MappedTest): Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo"))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.BasicEntity): pass @@ -297,7 +308,8 @@ class InheritTestTwo(_base.MappedTest): class BiDirectionalManyToOneTest(_base.MappedTest): run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -313,7 +325,8 @@ class BiDirectionalManyToOneTest(_base.MappedTest): Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), Column('t2id', Integer, ForeignKey('t2.id'), nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_base.BasicEntity): pass class T2(_base.BasicEntity): @@ -321,8 +334,9 @@ class BiDirectionalManyToOneTest(_base.MappedTest): class T3(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(T1, t1, properties={ 't2':relation(T2, primaryjoin=t1.c.t2id == t2.c.id)}) mapper(T2, t2, properties={ @@ -385,7 +399,8 @@ class BiDirectionalOneToManyTest(_base.MappedTest): run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), @@ -397,7 +412,8 @@ class BiDirectionalOneToManyTest(_base.MappedTest): Column('c2', Integer, ForeignKey('t1.c1', use_alter=True, name='t1c1_fk'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class C1(_base.BasicEntity): pass @@ -434,7 +450,8 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True), Column('c2', Integer, ForeignKey('t2.c1')), @@ -452,7 +469,8 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): Column('data', String(20)), test_needs_autoincrement=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class C1(_base.BasicEntity): pass @@ -462,8 +480,9 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): class C1Data(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(C2, t2, properties={ 'c1s': relation(C1, primaryjoin=t2.c.c1 == t1.c.c2, @@ -508,7 +527,8 @@ class OneToManyManyToOneTest(_base.MappedTest): """ run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('ball', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -522,7 +542,8 @@ class OneToManyManyToOneTest(_base.MappedTest): Column('favorite_ball_id', Integer, ForeignKey('ball.id')), Column('data', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Person(_base.BasicEntity): pass @@ -709,7 +730,8 @@ class OneToManyManyToOneTest(_base.MappedTest): class SelfReferentialPostUpdateTest(_base.MappedTest): """Post_update on a single self-referential mapper""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('node', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -721,7 +743,8 @@ class SelfReferentialPostUpdateTest(_base.MappedTest): Column('next_sibling_id', Integer, ForeignKey('node.id'), nullable=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Node(_base.BasicEntity): def __init__(self, path=''): self.path = path @@ -815,13 +838,15 @@ class SelfReferentialPostUpdateTest(_base.MappedTest): class SelfReferentialPostUpdateTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("a_table", metadata, Column("id", Integer(), primary_key=True), Column("fui", String(128)), Column("b", Integer(), ForeignKey("a_table.id"))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.BasicEntity): pass @@ -858,5 +883,3 @@ class SelfReferentialPostUpdateTest2(_base.MappedTest): assert f2.foo is f1 -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/defaults.py b/test/orm/test_defaults.py index 8dc192519..b063780ac 100644 --- a/test/orm/defaults.py +++ b/test/orm/test_defaults.py @@ -1,16 +1,19 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base -from testlib.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base +from sqlalchemy.test.testing import eq_ class TriggerDefaultsTest(_base.MappedTest): __requires__ = ('row_triggers',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): dt = Table('dt', metadata, Column('id', Integer, primary_key=True), Column('col1', String(20)), @@ -63,12 +66,14 @@ class TriggerDefaultsTest(_base.MappedTest): sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Default(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Default, dt) @testing.resolve_artifact_names @@ -107,7 +112,8 @@ class TriggerDefaultsTest(_base.MappedTest): eq_(d1.col4, 'up') class ExcludedDefaultsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): dt = Table('dt', metadata, Column('id', Integer, primary_key=True), Column('col1', String(20), default="hello"), @@ -125,5 +131,3 @@ class ExcludedDefaultsTest(_base.MappedTest): sess.flush() eq_(dt.select().execute().fetchall(), [(1, "hello")]) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/deprecations.py b/test/orm/test_deprecations.py index 483e8f556..00d64119e 100644 --- a/test/orm/deprecations.py +++ b/test/orm/test_deprecations.py @@ -5,11 +5,12 @@ modern (i.e. not deprecated) alternative to them. The tests snippets here can be migrated directly to the wiki, docs, etc. """ -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, func -from testlib.sa.orm import mapper, relation, create_session, sessionmaker -from orm import _base +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, sessionmaker +from test.orm import _base class QueryAlternativesTest(_base.MappedTest): @@ -44,7 +45,8 @@ class QueryAlternativesTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users_table', metadata, Column('id', Integer, primary_key=True), Column('name', String(64))) @@ -56,21 +58,24 @@ class QueryAlternativesTest(_base.MappedTest): Column('purpose', String(16)), Column('bounces', Integer, default=0)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.BasicEntity): pass class Address(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users_table, properties=dict( addresses=relation(Address, backref='user'), )) mapper(Address, addresses_table) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( users_table=( ('id', 'name'), @@ -479,5 +484,3 @@ class QueryAlternativesTest(_base.MappedTest): assert len(users) == 1 and users[0].name == 'ed' -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/dynamic.py b/test/orm/test_dynamic.py index 3bd94b7c0..f2089a435 100644 --- a/test/orm/dynamic.py +++ b/test/orm/test_dynamic.py @@ -1,13 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import operator from sqlalchemy.orm import dynamic_loader, backref -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func -from testlib.sa.orm import mapper, relation, create_session, Query, attributes +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, desc, select, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, Query, attributes from sqlalchemy.orm.dynamic import AppenderMixin -from testlib.testing import eq_ -from testlib.compat import _function_named -from orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures class DynamicTest(_fixtures.FixtureTest): @@ -281,7 +283,7 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() from sqlalchemy.orm import attributes - self.assertEquals(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], [])) + eq_(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], [])) sess.expunge_all() @@ -452,7 +454,7 @@ class SessionTest(_fixtures.FixtureTest): sess.close() -def create_backref_test(autoflush, saveuser): +def _create_backref_test(autoflush, saveuser): @testing.resolve_artifact_names def test_backref(self): @@ -487,17 +489,18 @@ def create_backref_test(autoflush, saveuser): sess.flush() self.assert_(list(u.addresses) == []) - test_backref = _function_named( + test_backref = function_named( test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""), (saveuser and "_saveuser" or "_savead"))) setattr(SessionTest, test_backref.__name__, test_backref) for autoflush in (False, True): for saveuser in (False, True): - create_backref_test(autoflush, saveuser) + _create_backref_test(autoflush, saveuser) class DontDereferenceTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(40)), @@ -509,8 +512,9 @@ class DontDereferenceTest(_base.MappedTest): Column('email_address', String(100), nullable=False), Column('user_id', Integer, ForeignKey('users.id'))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class User(_base.ComparableEntity): pass @@ -555,5 +559,3 @@ class DontDereferenceTest(_base.MappedTest): eq_(query3(), [Address(email_address='joe@joesdomain.example')]) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/eager_relations.py b/test/orm/test_eager_relations.py index 87c2442cc..384e0472f 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -1,13 +1,16 @@ """basic tests of eager loaded attributes""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing +from sqlalchemy.test.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing from sqlalchemy.orm import eagerload, deferred, undefer -from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func -from testlib.sa.orm import mapper, relation, create_session, lazyload, aliased -from testlib.testing import eq_ -from testlib.assertsql import CompiledSQL -from orm import _base, _fixtures +from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import CompiledSQL +from test.orm import _base, _fixtures import datetime class EagerTest(_fixtures.FixtureTest): @@ -23,7 +26,7 @@ class EagerTest(_fixtures.FixtureTest): q = sess.query(User) assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() - self.assertEquals(self.static.user_address_result, q.order_by(User.id).all()) + eq_(self.static.user_address_result, q.order_by(User.id).all()) @testing.resolve_artifact_names def test_late_compile(self): @@ -287,7 +290,7 @@ class EagerTest(_fixtures.FixtureTest): assert sa.orm.class_mapper(Address).get_property('user').lazy is False sess = create_session() - self.assertEquals(self.static.user_address_result, sess.query(User).order_by(User.id).all()) + eq_(self.static.user_address_result, sess.query(User).order_by(User.id).all()) @testing.resolve_artifact_names def test_double(self): @@ -615,13 +618,13 @@ class EagerTest(_fixtures.FixtureTest): def go(): o1 = sess.query(Order).options(lazyload('address')).filter(Order.id==5).one() - self.assertEquals(o1.address, None) + eq_(o1.address, None) self.assert_sql_count(testing.db, go, 2) sess.expunge_all() def go(): o1 = sess.query(Order).filter(Order.id==5).one() - self.assertEquals(o1.address, None) + eq_(o1.address, None) self.assert_sql_count(testing.db, go, 1) @testing.resolve_artifact_names @@ -817,7 +820,8 @@ class AddEntityTest(_fixtures.FixtureTest): self.assert_sql_count(testing.db, go, 1) class OrderBySecondaryTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('m2m', metadata, Column('id', Integer, primary_key=True), Column('aid', Integer, ForeignKey('a.id')), @@ -830,7 +834,8 @@ class OrderBySecondaryTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('data', String(50))) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( a=(('id', 'data'), (1, 'a1'), @@ -865,7 +870,8 @@ class OrderBySecondaryTest(_base.MappedTest): class SelfReferentialEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('nodes', metadata, Column('id', Integer, sa.Sequence('node_id_seq', optional=True), primary_key=True), @@ -980,7 +986,7 @@ class SelfReferentialEagerTest(_base.MappedTest): sess.expunge_all() def go(): - self.assertEquals( + eq_( Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), sess.query(Node).order_by(Node.id).first(), ) @@ -1079,7 +1085,8 @@ class SelfReferentialEagerTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 3) class MixedSelfReferentialEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a_table', metadata, Column('id', Integer, primary_key=True) ) @@ -1091,8 +1098,9 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class A(_base.ComparableEntity): pass class B(_base.ComparableEntity): @@ -1113,8 +1121,9 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): ) }); + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3)) b_table.insert().execute( dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None), @@ -1149,7 +1158,8 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) class SelfReferentialM2MEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('widget', metadata, Column('id', Integer, primary_key=True), Column('name', sa.Unicode(40), nullable=False, unique=True), @@ -1189,8 +1199,9 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): run_inserts = 'once' run_deletes = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user'), 'orders':relation(Order, backref='user'), # o2m, m2o @@ -1323,7 +1334,8 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): class CyclicalInheritingEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True), Column('c2', String(30)), @@ -1361,7 +1373,8 @@ class CyclicalInheritingEagerTest(_base.MappedTest): create_session().query(SubT).all() class SubqueryTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users_table', metadata, Column('id', Integer, primary_key=True), Column('name', String(16)) @@ -1449,7 +1462,8 @@ class CorrelatedSubqueryTest(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): users = Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(50)) @@ -1460,8 +1474,9 @@ class CorrelatedSubqueryTest(_base.MappedTest): Column('date', Date), Column('user_id', Integer, ForeignKey('users.id'))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): users.insert().execute( {'id':1, 'name':'user1'}, {'id':2, 'name':'user2'}, @@ -1591,5 +1606,3 @@ class CorrelatedSubqueryTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/evaluator.py b/test/orm/test_evaluator.py index 3527c93d7..af6a3f89e 100644 --- a/test/orm/evaluator.py +++ b/test/orm/test_evaluator.py @@ -1,10 +1,12 @@ """Evluating SQL expressions on ORM objects""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, String, Integer, select -from testlib.sa.orm import mapper, create_session -from testlib.testing import eq_ -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import String, Integer, select +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base from sqlalchemy import and_, or_, not_ from sqlalchemy.orm import evaluator @@ -20,17 +22,20 @@ def eval_eq(clause, testcases=None): return testeval class EvaluateTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(64))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users) @testing.resolve_artifact_names @@ -90,5 +95,3 @@ class EvaluateTest(_base.MappedTest): (User(id=None, name=None), None), ]) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/expire.py b/test/orm/test_expire.py index c11fb69df..659349897 100644 --- a/test/orm/expire.py +++ b/test/orm/test_expire.py @@ -1,11 +1,14 @@ """Attribute/instance expiration, deferral of attributes, etc.""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import gc -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, exc as sa_exc -from testlib.sa.orm import mapper, relation, create_session, attributes, deferred -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, attributes, deferred +from test.orm import _base, _fixtures class ExpireTest(_fixtures.FixtureTest): @@ -56,7 +59,7 @@ class ExpireTest(_fixtures.FixtureTest): u = s.query(User).get(7) s.expunge_all() - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u) + assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u) @testing.resolve_artifact_names def test_get_refreshes(self): @@ -69,7 +72,7 @@ class ExpireTest(_fixtures.FixtureTest): u = s.query(User).get(10) # get() refreshes self.assert_sql_count(testing.db, go, 1) def go(): - self.assertEquals(u.name, 'chuck') # attributes unexpired + eq_(u.name, 'chuck') # attributes unexpired self.assert_sql_count(testing.db, go, 0) def go(): u = s.query(User).get(10) # expire flag reset, so not expired @@ -86,7 +89,7 @@ class ExpireTest(_fixtures.FixtureTest): # add it back s.add(u) # nope, raises ObjectDeletedError - self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name') + assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name') # do a get()/remove u from session again assert s.query(User).get(10) is None @@ -97,7 +100,7 @@ class ExpireTest(_fixtures.FixtureTest): assert u in s # but now its back, rollback has occured, the _remove_newly_deleted # is reverted - self.assertEquals(u.name, 'chuck') + eq_(u.name, 'chuck') @testing.resolve_artifact_names def test_deferred(self): @@ -122,7 +125,7 @@ class ExpireTest(_fixtures.FixtureTest): s = create_session(autoflush=True, autocommit=False) u = s.query(User).get(8) adlist = u.addresses - self.assertEquals(adlist, [ + eq_(adlist, [ Address(email_address='ed@bettyboop.com'), Address(email_address='ed@lala.com'), Address(email_address='ed@wood.com'), @@ -130,7 +133,7 @@ class ExpireTest(_fixtures.FixtureTest): a1 = u.addresses[2] a1.email_address = 'aaaaa' s.expire(u, ['addresses']) - self.assertEquals(u.addresses, [ + eq_(u.addresses, [ Address(email_address='aaaaa'), Address(email_address='ed@bettyboop.com'), Address(email_address='ed@lala.com'), @@ -146,10 +149,10 @@ class ExpireTest(_fixtures.FixtureTest): mapper(Address, addresses) s = create_session(autoflush=True, autocommit=False) u = s.query(User).get(8) - self.assertRaisesMessage(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses']) + assert_raises_message(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses']) # in contrast to a regular query with no columns - self.assertRaisesMessage(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all) + assert_raises_message(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all) @testing.resolve_artifact_names def test_refresh_cancels_expire(self): @@ -161,7 +164,7 @@ class ExpireTest(_fixtures.FixtureTest): def go(): u = s.query(User).get(7) - self.assertEquals(u.name, 'jack') + eq_(u.name, 'jack') self.assert_sql_count(testing.db, go, 0) @testing.resolve_artifact_names @@ -187,7 +190,7 @@ class ExpireTest(_fixtures.FixtureTest): sess.expire(u, attribute_names=['name']) sess.expunge(u) - self.assertRaises(sa.exc.UnboundExecutionError, getattr, u, 'name') + assert_raises(sa.exc.UnboundExecutionError, getattr, u, 'name') @testing.resolve_artifact_names def test_pending_raises(self): @@ -197,7 +200,7 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() u = User(id=15) sess.add(u) - self.assertRaises(sa.exc.InvalidRequestError, sess.expire, u, ['name']) + assert_raises(sa.exc.InvalidRequestError, sess.expire, u, ['name']) @testing.resolve_artifact_names def test_no_instance_key(self): @@ -668,14 +671,15 @@ class ExpireTest(_fixtures.FixtureTest): userlist = sess.query(User).order_by(User.id).all() u = userlist[1] - self.assertEquals(self.static.user_address_result, userlist) + eq_(self.static.user_address_result, userlist) assert len(list(sess)) == 9 class PolymorphicExpireTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, Person, Engineer people = Table('people', metadata, @@ -690,14 +694,16 @@ class PolymorphicExpireTest(_base.MappedTest): Column('status', String(30)), ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Person(_base.ComparableEntity): pass class Engineer(Person): pass + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): people.insert().execute( {'person_id':1, 'name':'person1', 'type':'person'}, {'person_id':2, 'name':'engineer1', 'type':'engineer'}, @@ -745,7 +751,7 @@ class PolymorphicExpireTest(_base.MappedTest): assert e1.status == 'new engineer' assert e2.status == 'old engineer' self.assert_sql_count(testing.db, go, 2) - self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1'])) + eq_(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1'])) class ExpiredPendingTest(_fixtures.FixtureTest): run_define_tables = 'once' @@ -837,7 +843,7 @@ class RefreshTest(_fixtures.FixtureTest): s = create_session() u = s.query(User).get(7) s.expunge_all() - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u)) + assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u)) @testing.resolve_artifact_names def test_refresh_expired(self): @@ -908,5 +914,3 @@ class RefreshTest(_fixtures.FixtureTest): s.refresh(u) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/extendedattr.py b/test/orm/test_extendedattr.py index aec6c181f..e0c64bf64 100644 --- a/test/orm/extendedattr.py +++ b/test/orm/test_extendedattr.py @@ -1,4 +1,4 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import pickle from sqlalchemy import util import sqlalchemy.orm.attributes as attributes @@ -6,8 +6,8 @@ from sqlalchemy.orm.collections import collection from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import InstrumentationManager -from testlib import * -from orm import _base +from sqlalchemy.test import * +from test.orm import _base class MyTypesManager(InstrumentationManager): @@ -100,7 +100,8 @@ class MyClass(object): del self._goofy_dict[key] class UserDefinedExtensionTest(_base.ORMTest): - def tearDownAll(self): + @classmethod + def teardown_class(cls): clear_mappers() attributes._install_lookup_strategy(util.symbol('native')) @@ -161,30 +162,30 @@ class UserDefinedExtensionTest(_base.ORMTest): assert Foo in attributes.instrumentation_registry._state_finders f = Foo() attributes.instance_state(f).expire_attributes(None) - self.assertEquals(f.a, "this is a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is a") + eq_(f.b, 12) f.a = "this is some new a" attributes.instance_state(f).expire_attributes(None) - self.assertEquals(f.a, "this is a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is a") + eq_(f.b, 12) attributes.instance_state(f).expire_attributes(None) f.a = "this is another new a" - self.assertEquals(f.a, "this is another new a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is another new a") + eq_(f.b, 12) attributes.instance_state(f).expire_attributes(None) - self.assertEquals(f.a, "this is a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is a") + eq_(f.b, 12) del f.a - self.assertEquals(f.a, None) - self.assertEquals(f.b, 12) + eq_(f.a, None) + eq_(f.b, 12) attributes.instance_state(f).commit_all(attributes.instance_dict(f)) - self.assertEquals(f.a, None) - self.assertEquals(f.b, 12) + eq_(f.a, None) + eq_(f.b, 12) def test_inheritance(self): """tests that attributes are polymorphic""" @@ -265,27 +266,27 @@ class UserDefinedExtensionTest(_base.ORMTest): f1 = Foo() f1.name = 'f1' - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ())) + eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ())) b1 = Bar() b1.name = 'b1' f1.bars.append(b1) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) attributes.instance_state(f1).commit_all(attributes.instance_dict(f1)) attributes.instance_state(b1).commit_all(attributes.instance_dict(b1)) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ())) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ())) + eq_(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ())) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ())) f1.name = 'f1mod' b2 = Bar() b2.name = 'b2' f1.bars.append(b2) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1'])) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], [])) + eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1'])) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], [])) f1.bars.remove(b1) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1])) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1])) def test_null_instrumentation(self): class Foo(MyBaseClass): @@ -311,9 +312,9 @@ class UserDefinedExtensionTest(_base.ORMTest): assert attributes.manager_of_class(None) is None assert attributes.instance_state(k) is not None - self.assertRaises((AttributeError, KeyError), + assert_raises((AttributeError, KeyError), attributes.instance_state, u) - self.assertRaises((AttributeError, KeyError), + assert_raises((AttributeError, KeyError), attributes.instance_state, None) diff --git a/test/orm/generative.py b/test/orm/test_generative.py index 995236741..0efc1814e 100644 --- a/test/orm/generative.py +++ b/test/orm/test_generative.py @@ -1,28 +1,34 @@ -import testenv; testenv.configure_for_tests() -from testlib import testing, sa -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func +from sqlalchemy.test.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, MetaData, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column from sqlalchemy.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures class GenerativeQueryTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foo', metadata, Column('id', Integer, sa.Sequence('foo_id_seq'), primary_key=True), Column('bar', Integer), Column('range', Integer)) - def fixtures(self): + @classmethod + def fixtures(cls): rows = tuple([(i, i % 10) for i in range(100)]) foo_data = (('bar', 'range'),) + rows return dict(foo=foo_data) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Foo(_base.BasicEntity): pass @@ -131,7 +137,8 @@ class GenerativeQueryTest(_base.MappedTest): class GenerativeTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('Table1', metadata, Column('id', Integer, primary_key=True)) Table('Table2', metadata, @@ -139,8 +146,9 @@ class GenerativeTest2(_base.MappedTest): primary_key=True), Column('num', Integer, primary_key=True)) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Obj1(_base.BasicEntity): pass class Obj2(_base.BasicEntity): @@ -149,7 +157,8 @@ class GenerativeTest2(_base.MappedTest): mapper(Obj1, Table1) mapper(Obj2, Table2) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( Table1=(('id',), (1,), @@ -182,8 +191,9 @@ class RelationsTest(_fixtures.FixtureTest): run_inserts = 'once' run_deletes = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties={ 'orders':relation(mapper(Order, orders, properties={ 'addresses':relation(mapper(Address, addresses))}))}) @@ -232,7 +242,8 @@ class RelationsTest(_fixtures.FixtureTest): class CaseSensitiveTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('Table1', metadata, Column('ID', Integer, primary_key=True)) Table('Table2', metadata, @@ -240,8 +251,9 @@ class CaseSensitiveTest(_base.MappedTest): primary_key=True), Column('NUM', Integer, primary_key=True)) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Obj1(_base.BasicEntity): pass class Obj2(_base.BasicEntity): @@ -250,7 +262,8 @@ class CaseSensitiveTest(_base.MappedTest): mapper(Obj1, Table1) mapper(Obj2, Table2) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( Table1=(('ID',), (1,), @@ -272,8 +285,6 @@ class CaseSensitiveTest(_base.MappedTest): res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)) assert res.count() == 3 res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)).distinct() - self.assertEqual(res.count(), 1) + eq_(res.count(), 1) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/instrumentation.py b/test/orm/test_instrumentation.py index fd15420d0..b4c8f8601 100644 --- a/test/orm/instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -1,11 +1,13 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa -from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util -from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers -from testlib.testing import eq_, ne_ -from testlib.compat import _function_named -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy import MetaData, Integer, ForeignKey, util +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers +from sqlalchemy.test.testing import eq_, ne_ +from sqlalchemy.util import function_named +from test.orm import _base def modifies_instrumentation_finders(fn): @@ -16,7 +18,7 @@ def modifies_instrumentation_finders(fn): finally: del attributes.instrumentation_finders[:] attributes.instrumentation_finders.extend(pristine) - return _function_named(decorated, fn.func_name) + return function_named(decorated, fn.func_name) def with_lookup_strategy(strategy): def decorate(fn): @@ -26,7 +28,7 @@ def with_lookup_strategy(strategy): return fn(*args, **kw) finally: attributes._install_lookup_strategy(sa.util.symbol('native')) - return _function_named(wrapped, fn.func_name) + return function_named(wrapped, fn.func_name) return decorate @@ -459,10 +461,10 @@ class MapperInitTest(_base.ORMTest): m = mapper(A, self.fixture()) # B is not mapped in the current implementation - self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B) + assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, B) # C is not mapped in the current implementation - self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C) + assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, C) class InstrumentationCollisionTest(_base.ORMTest): def test_none(self): @@ -486,7 +488,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - self.assertRaises(TypeError, attributes.register_class, B) + assert_raises(TypeError, attributes.register_class, B) def test_single_up(self): @@ -497,7 +499,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) attributes.register_class(B) - self.assertRaises(TypeError, attributes.register_class, A) + assert_raises(TypeError, attributes.register_class, A) def test_diamond_b1(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -508,7 +510,7 @@ class InstrumentationCollisionTest(_base.ORMTest): __sa_instrumentation_manager__ = mgr_factory class C(object): pass - self.assertRaises(TypeError, attributes.register_class, B1) + assert_raises(TypeError, attributes.register_class, B1) def test_diamond_b2(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -519,7 +521,7 @@ class InstrumentationCollisionTest(_base.ORMTest): __sa_instrumentation_manager__ = mgr_factory class C(object): pass - self.assertRaises(TypeError, attributes.register_class, B2) + assert_raises(TypeError, attributes.register_class, B2) def test_diamond_c_b(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -531,7 +533,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class C(object): pass attributes.register_class(C) - self.assertRaises(TypeError, attributes.register_class, B1) + assert_raises(TypeError, attributes.register_class, B1) class OnLoadTest(_base.ORMTest): @@ -557,7 +559,8 @@ class OnLoadTest(_base.ORMTest): finally: del A - def tearDownAll(self): + @classmethod + def teardown_class(cls): clear_mappers() attributes._install_lookup_strategy(util.symbol('native')) @@ -593,7 +596,7 @@ class NativeInstrumentationTest(_base.ORMTest): sa = attributes.ClassManager.STATE_ATTR ma = attributes.ClassManager.MANAGER_ATTR - fails = lambda method, attr: self.assertRaises( + fails = lambda method, attr: assert_raises( KeyError, getattr(manager, method), attr, property()) fails('install_member', sa) @@ -609,7 +612,7 @@ class NativeInstrumentationTest(_base.ORMTest): class T(object): pass - self.assertRaises(KeyError, mapper, T, t) + assert_raises(KeyError, mapper, T, t) @with_lookup_strategy(sa.util.symbol('native')) def test_mapped_managerattr(self): @@ -618,7 +621,7 @@ class NativeInstrumentationTest(_base.ORMTest): Column(attributes.ClassManager.MANAGER_ATTR, Integer)) class T(object): pass - self.assertRaises(KeyError, mapper, T, t) + assert_raises(KeyError, mapper, T, t) class MiscTest(_base.ORMTest): @@ -761,5 +764,3 @@ class FinderTest(_base.ORMTest): eq_(type(attributes.manager_of_class(A)), attributes.ClassManager) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/lazy_relations.py b/test/orm/test_lazy_relations.py index b5c3b3669..819f29911 100644 --- a/test/orm/lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -1,14 +1,17 @@ """basic tests of lazy loaded attributes""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message import datetime from sqlalchemy import exc as sa_exc from sqlalchemy.orm import attributes -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures class LazyTest(_fixtures.FixtureTest): @@ -35,7 +38,7 @@ class LazyTest(_fixtures.FixtureTest): q = sess.query(User) u = q.filter(users.c.id == 7).first() sess.expunge(u) - self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses') + assert_raises(sa_exc.InvalidRequestError, getattr, u, 'addresses') @testing.resolve_artifact_names def test_orderby(self): @@ -363,6 +366,7 @@ class M2OGetTest(_fixtures.FixtureTest): class CorrelatedTest(_base.MappedTest): + @classmethod def define_tables(self, meta): Table('user_t', meta, Column('id', Integer, primary_key=True), @@ -373,8 +377,9 @@ class CorrelatedTest(_base.MappedTest): Column('date', sa.Date), Column('user_id', Integer, ForeignKey('user_t.id'))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): user_t.insert().execute( {'id':1, 'name':'user1'}, {'id':2, 'name':'user2'}, @@ -412,5 +417,3 @@ class CorrelatedTest(_base.MappedTest): ]) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/lazytest1.py b/test/orm/test_lazytest1.py index 5ebb8feeb..f76cb3203 100644 --- a/test/orm/lazytest1.py +++ b/test/orm/test_lazytest1.py @@ -1,12 +1,15 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base class LazyTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('infos', metadata, Column('pk', Integer, primary_key=True), Column('info', String(128))) @@ -25,8 +28,9 @@ class LazyTest(_base.MappedTest): Column('start', Integer), Column('finish', Integer)) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): infos.insert().execute( {'pk':1, 'info':'pk_1_info'}, {'pk':2, 'info':'pk_2_info'}, @@ -86,5 +90,3 @@ class LazyTest(_base.MappedTest): assert len(info.rels[0].datas) == 3 -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/manytomany.py b/test/orm/test_manytomany.py index 23af3bd1f..dcd547f80 100644 --- a/test/orm/manytomany.py +++ b/test/orm/test_manytomany.py @@ -1,12 +1,16 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base class M2MTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('place', metadata, Column('place_id', Integer, sa.Sequence('pid_seq', optional=True), primary_key=True), @@ -40,7 +44,8 @@ class M2MTest(_base.MappedTest): Column('pl1_id', Integer, ForeignKey('place.place_id')), Column('pl2_id', Integer, ForeignKey('place.place_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Place(_base.BasicEntity): def __init__(self, name=None): self.name = name @@ -70,7 +75,7 @@ class M2MTest(_base.MappedTest): mapper(Transition, transition, properties={ 'places':relation(Place, secondary=place_input, backref='transitions') }) - self.assertRaisesMessage(sa.exc.ArgumentError, "Error creating backref", + assert_raises_message(sa.exc.ArgumentError, "Error creating backref", sa.orm.compile_mappers) @testing.resolve_artifact_names @@ -187,7 +192,8 @@ class M2MTest(_base.MappedTest): class M2MTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('student', metadata, Column('name', String(20), primary_key=True)) @@ -200,7 +206,8 @@ class M2MTest2(_base.MappedTest): Column('course_id', String(20), ForeignKey('course.name'), primary_key=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Student(_base.BasicEntity): def __init__(self, name=''): self.name = name @@ -248,7 +255,7 @@ class M2MTest2(_base.MappedTest): s1.courses.append(c1) s1.courses.append(c1) sess.add(s1) - self.assertRaises(sa.exc.DBAPIError, sess.flush) + assert_raises(sa.exc.DBAPIError, sess.flush) @testing.resolve_artifact_names def test_delete(self): @@ -274,7 +281,8 @@ class M2MTest2(_base.MappedTest): assert enroll.count().scalar() == 0 class M2MTest3(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('c', metadata, Column('c1', Integer, primary_key = True), Column('c2', String(20))) @@ -320,5 +328,3 @@ class M2MTest3(_base.MappedTest): # how about some data/inserts/queries/assertions for this one -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/mapper.py b/test/orm/test_mapper.py index 13e02a38a..025b96424 100644 --- a/test/orm/mapper.py +++ b/test/orm/test_mapper.py @@ -1,14 +1,16 @@ """General mapper operations with an emphasis on selecting/loading.""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, func -from testlib.sa.engine import default -from testlib.sa.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased -from testlib.sa.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property -from testlib.testing import eq_, AssertsCompiledSQL -import pickleable -from orm import _base, _fixtures +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing, pickleable +from sqlalchemy import MetaData, Integer, String, ForeignKey, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.engine import default +from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased +from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property +from sqlalchemy.test.testing import eq_, AssertsCompiledSQL +from test.orm import _base, _fixtures class MapperTest(_fixtures.FixtureTest): @@ -22,7 +24,7 @@ class MapperTest(_fixtures.FixtureTest): properties={ 'addresses':relation(Address, backref='email_address') }) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) @testing.resolve_artifact_names def test_update_attr_keys(self): @@ -74,14 +76,14 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_prop_accessor(self): mapper(User, users) - self.assertRaises(NotImplementedError, + assert_raises(NotImplementedError, getattr, sa.orm.class_mapper(User), 'properties') @testing.resolve_artifact_names def test_bad_cascade(self): mapper(Address, addresses) - self.assertRaises(sa.exc.ArgumentError, + assert_raises(sa.exc.ArgumentError, relation, Address, cascade="fake, all, delete-orphan") @testing.resolve_artifact_names @@ -93,7 +95,7 @@ class MapperTest(_fixtures.FixtureTest): }) hasattr(Address.user, 'property') - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) + assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) @testing.resolve_artifact_names def test_column_prefix(self): @@ -111,7 +113,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_no_pks_1(self): s = sa.select([users.c.name]).alias('foo') - self.assertRaises(sa.exc.ArgumentError, mapper, User, s) + assert_raises(sa.exc.ArgumentError, mapper, User, s) @testing.emits_warning( 'mapper Mapper|User|Select object creating an alias for ' @@ -119,7 +121,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_no_pks_2(self): s = sa.select([users.c.name]) - self.assertRaises(sa.exc.ArgumentError, mapper, User, s) + assert_raises(sa.exc.ArgumentError, mapper, User, s) @testing.resolve_artifact_names def test_recompile_on_other_mapper(self): @@ -167,9 +169,9 @@ class MapperTest(_fixtures.FixtureTest): create_session).extension) sess = create_session() - self.assertRaises(TypeError, Foo, 'one', _sa_session=sess) + assert_raises(TypeError, Foo, 'one', _sa_session=sess) eq_(len(list(sess)), 0) - self.assertRaises(TypeError, Foo, 'one') + assert_raises(TypeError, Foo, 'one') Foo('one', 'two', _sa_session=sess) eq_(len(list(sess)), 1) @@ -197,7 +199,7 @@ class MapperTest(_fixtures.FixtureTest): raise Exception("this exception should be stated as a warning") sess.expunge = bad_expunge - self.assertRaises(sa.exc.SAWarning, Foo, _sa_session=sess) + assert_raises(sa.exc.SAWarning, Foo, _sa_session=sess) @testing.resolve_artifact_names def test_constructor_exc_2(self): @@ -211,8 +213,8 @@ class MapperTest(_fixtures.FixtureTest): mapper(Foo, users) mapper(Bar, addresses) - self.assertRaises(TypeError, Foo, x=5) - self.assertRaises(TypeError, Bar, x=5) + assert_raises(TypeError, Foo, x=5) + assert_raises(TypeError, Bar, x=5) @testing.resolve_artifact_names def test_props(self): @@ -499,7 +501,7 @@ class MapperTest(_fixtures.FixtureTest): # excluding the discriminator column is currently not allowed class Foo(Person): pass - self.assertRaises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) ) + assert_raises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) ) @testing.resolve_artifact_names def test_mapping_to_join(self): @@ -643,7 +645,7 @@ class MapperTest(_fixtures.FixtureTest): properties=dict( name=relation(mapper(Address, addresses)))) - self.assertRaises(sa.exc.ArgumentError, go) + assert_raises(sa.exc.ArgumentError, go) @testing.resolve_artifact_names def test_override_2(self): @@ -739,7 +741,7 @@ class MapperTest(_fixtures.FixtureTest): mapper(User, users, properties={ 'not_name':synonym('_name', map_column=True)}) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, ("Can't compile synonym '_name': no column on table " "'users' named 'not_name'"), @@ -834,7 +836,7 @@ class MapperTest(_fixtures.FixtureTest): eq_(User.uc_name.method1(), "method1") eq_(User.uc_name.method2('x'), "method2") - self.assertRaisesMessage( + assert_raises_message( AttributeError, "Neither 'extendedproperty' object nor 'UCComparator' object has an attribute 'nonexistent'", getattr, User.uc_name, 'nonexistent') @@ -879,7 +881,7 @@ class MapperTest(_fixtures.FixtureTest): 'name':sa.orm.column_property(users.c.name, comparator_factory=MyComparator) }) - self.assertRaisesMessage( + assert_raises_message( AttributeError, "Neither 'InstrumentedAttribute' object nor 'MyComparator' object has an attribute 'nonexistent'", getattr, User.name, "nonexistent") @@ -966,7 +968,7 @@ class MapperTest(_fixtures.FixtureTest): 'addresses':relation(Address) }) - self.assertRaises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers) + assert_raises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers) @testing.resolve_artifact_names def test_oldstyle_mixin(self): @@ -1148,8 +1150,9 @@ class OptionsTest(_fixtures.FixtureTest): class DeepOptionsTest(_fixtures.FixtureTest): + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Keyword, keywords) mapper(Item, items, properties=dict( @@ -1204,7 +1207,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): def test_deep_options_4(self): sess = create_session() - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, r"Can't find entity Mapper\|Order\|orders in Query. " r"Current list: \['Mapper\|User\|users'\]", @@ -1232,7 +1235,7 @@ class ValidatorTest(_fixtures.FixtureTest): sess = create_session() u1 = User(name='ed') eq_(u1.name, 'ed modified') - self.assertRaises(AssertionError, setattr, u1, "name", "fred") + assert_raises(AssertionError, setattr, u1, "name", "fred") eq_(u1.name, 'ed modified') sess.add(u1) sess.flush() @@ -1252,7 +1255,7 @@ class ValidatorTest(_fixtures.FixtureTest): mapper(Address, addresses) sess = create_session() u1 = User(name='edward') - self.assertRaises(AssertionError, u1.addresses.append, Address(email_address='noemail')) + assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail')) u1.addresses.append(Address(id=15, email_address='foo@bar.com')) sess.add(u1) sess.flush() @@ -1629,7 +1632,8 @@ class DeferredTest(_fixtures.FixtureTest): eq_(item.description, 'item 4') class DeferredPopulationTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("thing", metadata, Column("id", Integer, primary_key=True), Column("name", String(20))) @@ -1639,16 +1643,18 @@ class DeferredPopulationTest(_base.MappedTest): Column("thing_id", Integer, ForeignKey("thing.id")), Column("name", String(20))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Human(_base.BasicEntity): pass class Thing(_base.BasicEntity): pass mapper(Human, human, properties={"thing": relation(Thing)}) mapper(Thing, thing, properties={"name": deferred(thing.c.name)}) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): thing.insert().execute([ {"id": 1, "name": "Chair"}, ]) @@ -1714,7 +1720,8 @@ class DeferredPopulationTest(_base.MappedTest): class CompositeTypesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('graphs', metadata, Column('id', Integer, primary_key=True), Column('version_id', Integer, primary_key=True, nullable=True), @@ -2246,7 +2253,8 @@ class MapperExtensionTest(_fixtures.FixtureTest): class RequirementsTest(_base.MappedTest): """Tests the contract for user classes.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('ht1', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -2280,9 +2288,9 @@ class RequirementsTest(_base.MappedTest): class OldStyle: pass - self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1) + assert_raises(sa.exc.ArgumentError, mapper, OldStyle, ht1) - self.assertRaises(sa.exc.ArgumentError, mapper, 123) + assert_raises(sa.exc.ArgumentError, mapper, 123) class NoWeakrefSupport(str): pass @@ -2388,7 +2396,8 @@ class RequirementsTest(_base.MappedTest): class MagicNamesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('cartographers', metadata, Column('id', Integer, primary_key=True), Column('name', String(50)), @@ -2401,7 +2410,8 @@ class MagicNamesTest(_base.MappedTest): Column('state', String(2)), Column('data', sa.Text)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Cartographer(_base.BasicEntity): pass @@ -2441,7 +2451,7 @@ class MagicNamesTest(_base.MappedTest): class T(object): pass - self.assertRaisesMessage( + assert_raises_message( KeyError, ('%r: requested attribute name conflicts with ' 'instrumentation attribute of the same name.' % reserved), @@ -2454,7 +2464,7 @@ class MagicNamesTest(_base.MappedTest): class M(object): pass - self.assertRaisesMessage( + assert_raises_message( KeyError, ('requested attribute name conflicts with ' 'instrumentation attribute of the same name'), @@ -2463,5 +2473,3 @@ class MagicNamesTest(_base.MappedTest): -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/merge.py b/test/orm/test_merge.py index fd553f2bf..70097cbee 100644 --- a/test/orm/merge.py +++ b/test/orm/test_merge.py @@ -1,9 +1,10 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa.util import OrderedSet -from testlib.sa.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property -from testlib.testing import eq_, ne_ -from orm import _base, _fixtures +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy.util import OrderedSet +from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property +from sqlalchemy.test.testing import eq_, ne_ +from test.orm import _base, _fixtures class MergeTest(_fixtures.FixtureTest): @@ -447,7 +448,7 @@ class MergeTest(_fixtures.FixtureTest): sess = create_session() u = User() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) + assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) @testing.resolve_artifact_names @@ -732,5 +733,3 @@ class MergeTest(_fixtures.FixtureTest): -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/naturalpks.py b/test/orm/test_naturalpks.py index 8efce660c..1376c402e 100644 --- a/test/orm/naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -2,16 +2,20 @@ Primary key changing capabilities and passive/non-passive cascading updates. """ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base class NaturalPKTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): users = Table('users', metadata, Column('username', String(50), primary_key=True), Column('fullname', String(100)), @@ -32,7 +36,8 @@ class NaturalPKTest(_base.MappedTest): Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Address(_base.ComparableEntity): @@ -62,7 +67,7 @@ class NaturalPKTest(_base.MappedTest): sess.expunge_all() u1 = sess.query(User).get('ed') - self.assertEquals(User(username='ed', fullname='jack'), u1) + eq_(User(username='ed', fullname='jack'), u1) @testing.resolve_artifact_names def test_load_after_expire(self): @@ -81,7 +86,7 @@ class NaturalPKTest(_base.MappedTest): # in this case so theres no way to look it up. criterion- # based session invalidation could solve this [ticket:911] sess.expire(u1) - self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username') + assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username') sess.expunge_all() assert sess.query(User).get('jack') is None @@ -132,7 +137,7 @@ class NaturalPKTest(_base.MappedTest): assert u1.addresses[0].username == 'ed' sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) u1 = sess.query(User).get('ed') u1.username = 'jack' @@ -152,7 +157,7 @@ class NaturalPKTest(_base.MappedTest): sess.expunge_all() assert sess.query(Address).get('jack1').username is None u1 = sess.query(User).get('fred') - self.assertEquals(User(username='fred', fullname='jack'), u1) + eq_(User(username='fred', fullname='jack'), u1) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') @@ -195,7 +200,7 @@ class NaturalPKTest(_base.MappedTest): assert a1.username == a2.username == 'ed' sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_onetoone_passive(self): @@ -236,7 +241,7 @@ class NaturalPKTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 0) sess.expunge_all() - self.assertEquals([Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_bidirectional_passive(self): @@ -265,7 +270,7 @@ class NaturalPKTest(_base.MappedTest): u1.username = 'ed' (ad1, ad2) = sess.query(Address).all() - self.assertEquals([Address(username='jack'), Address(username='jack')], [ad1, ad2]) + eq_([Address(username='jack'), Address(username='jack')], [ad1, ad2]) def go(): sess.flush() if passive_updates: @@ -273,9 +278,9 @@ class NaturalPKTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) else: self.assert_sql_count(testing.db, go, 3) - self.assertEquals([Address(username='ed'), Address(username='ed')], [ad1, ad2]) + eq_([Address(username='ed'), Address(username='ed')], [ad1, ad2]) sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) u1 = sess.query(User).get('ed') assert len(u1.addresses) == 2 # load addresses @@ -289,7 +294,7 @@ class NaturalPKTest(_base.MappedTest): else: self.assert_sql_count(testing.db, go, 3) sess.expunge_all() - self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) + eq_([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') @@ -323,10 +328,10 @@ class NaturalPKTest(_base.MappedTest): r = sess.query(Item).all() # ComparableEntity can't handle a comparison with the backrefs # involved.... - self.assertEquals(Item(itemname='item1'), r[0]) - self.assertEquals(['jack'], [u.username for u in r[0].users]) - self.assertEquals(Item(itemname='item2'), r[1]) - self.assertEquals(['jack', 'fred'], [u.username for u in r[1].users]) + eq_(Item(itemname='item1'), r[0]) + eq_(['jack'], [u.username for u in r[0].users]) + eq_(Item(itemname='item2'), r[1]) + eq_(['jack', 'fred'], [u.username for u in r[1].users]) u2.username='ed' def go(): @@ -338,29 +343,31 @@ class NaturalPKTest(_base.MappedTest): sess.expunge_all() r = sess.query(Item).all() - self.assertEquals(Item(itemname='item1'), r[0]) - self.assertEquals(['jack'], [u.username for u in r[0].users]) - self.assertEquals(Item(itemname='item2'), r[1]) - self.assertEquals(['ed', 'jack'], sorted([u.username for u in r[1].users])) + eq_(Item(itemname='item1'), r[0]) + eq_(['jack'], [u.username for u in r[0].users]) + eq_(Item(itemname='item2'), r[1]) + eq_(['ed', 'jack'], sorted([u.username for u in r[1].users])) sess.expunge_all() u2 = sess.query(User).get(u2.username) u2.username='wendy' sess.flush() r = sess.query(Item).with_parent(u2).all() - self.assertEquals(Item(itemname='item2'), r[0]) + eq_(Item(itemname='item2'), r[0]) class SelfRefTest(_base.MappedTest): __unsupported_on__ = 'mssql' # mssql doesn't allow ON UPDATE on self-referential keys - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('nodes', metadata, Column('name', String(50), primary_key=True), Column('parent', String(50), ForeignKey('nodes.name', onupdate='cascade'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Node(_base.ComparableEntity): pass @@ -391,7 +398,8 @@ class SelfRefTest(_base.MappedTest): class NonPKCascadeTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('username', String(50), unique=True), @@ -406,7 +414,8 @@ class NonPKCascadeTest(_base.MappedTest): test_needs_fk=True ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Address(_base.ComparableEntity): @@ -433,17 +442,17 @@ class NonPKCascadeTest(_base.MappedTest): sess.flush() a1 = u1.addresses[0] - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) + eq_(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) assert sess.query(Address).get(a1.id) is u1.addresses[0] u1.username = 'ed' sess.flush() assert u1.addresses[0].username == 'ed' - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) + eq_(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) u1 = sess.query(User).get(u1.id) u1.username = 'jack' @@ -463,13 +472,11 @@ class NonPKCascadeTest(_base.MappedTest): sess.flush() sess.expunge_all() a1 = sess.query(Address).get(a1.id) - self.assertEquals(a1.username, None) + eq_(a1.username, None) - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) + eq_(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) u1 = sess.query(User).get(u1.id) - self.assertEquals(User(username='fred', fullname='jack'), u1) + eq_(User(username='fred', fullname='jack'), u1) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/onetoone.py b/test/orm/test_onetoone.py index be0375e48..0d66915ea 100644 --- a/test/orm/onetoone.py +++ b/test/orm/test_onetoone.py @@ -1,12 +1,15 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base class O2OTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('jack', metadata, Column('id', Integer, primary_key=True), Column('number', String(50)), @@ -19,8 +22,9 @@ class O2OTest(_base.MappedTest): Column('description', String(100)), Column('jack_id', Integer, ForeignKey("jack.id"))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Jack(_base.BasicEntity): pass class Port(_base.BasicEntity): @@ -70,5 +74,3 @@ class O2OTest(_base.MappedTest): session.delete(j) session.flush() -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/pickled.py b/test/orm/test_pickled.py index 878fe931e..5343cc15b 100644 --- a/test/orm/pickled.py +++ b/test/orm/test_pickled.py @@ -1,9 +1,12 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import pickle -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session, attributes -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, attributes +from test.orm import _base, _fixtures User, EmailUser = None, None @@ -28,7 +31,7 @@ class PickleTest(_fixtures.FixtureTest): sess.expunge_all() - self.assertEquals(u1, sess.query(User).get(u2.id)) + eq_(u1, sess.query(User).get(u2.id)) @testing.resolve_artifact_names def test_class_deferred_cols(self): @@ -52,14 +55,14 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() sess2.add(u2) - self.assertEquals(u2.name, 'ed') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(u2.name, 'ed') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() u2 = sess2.merge(u2, dont_load=True) - self.assertEquals(u2.name, 'ed') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(u2.name, 'ed') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) @testing.resolve_artifact_names def test_instance_deferred_cols(self): @@ -82,22 +85,22 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() sess2.add(u2) - self.assertEquals(u2.name, 'ed') + eq_(u2.name, 'ed') assert 'addresses' not in u2.__dict__ ad = u2.addresses[0] assert 'email_address' not in ad.__dict__ - self.assertEquals(ad.email_address, 'ed@bar.com') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(ad.email_address, 'ed@bar.com') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() u2 = sess2.merge(u2, dont_load=True) - self.assertEquals(u2.name, 'ed') + eq_(u2.name, 'ed') assert 'addresses' not in u2.__dict__ ad = u2.addresses[0] assert 'email_address' in ad.__dict__ # mapper options dont transmit over merge() right now - self.assertEquals(ad.email_address, 'ed@bar.com') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(ad.email_address, 'ed@bar.com') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) @testing.resolve_artifact_names def test_options_with_descriptors(self): @@ -122,7 +125,7 @@ class PickleTest(_fixtures.FixtureTest): sa.orm.eagerload(["addresses", User.addresses]), ]: opt2 = pickle.loads(pickle.dumps(opt)) - self.assertEquals(opt.key, opt2.key) + eq_(opt.key, opt2.key) u1 = sess.query(User).options(opt).first() @@ -130,7 +133,8 @@ class PickleTest(_fixtures.FixtureTest): class PolymorphicDeferredTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(30)), @@ -139,7 +143,8 @@ class PolymorphicDeferredTest(_base.MappedTest): Column('id', Integer, ForeignKey('users.id'), primary_key=True), Column('email_address', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): global User, EmailUser class User(_base.BasicEntity): pass @@ -147,10 +152,11 @@ class PolymorphicDeferredTest(_base.MappedTest): class EmailUser(User): pass - def tearDownAll(self): + @classmethod + def teardown_class(cls): global User, EmailUser User, EmailUser = None, None - _base.MappedTest.tearDownAll(self) + super(PolymorphicDeferredTest, cls).teardown_class() @testing.resolve_artifact_names def test_polymorphic_deferred(self): @@ -168,9 +174,9 @@ class PolymorphicDeferredTest(_base.MappedTest): sess2 = create_session() sess2.add(eu2) assert 'email_address' not in eu2.__dict__ - self.assertEquals(eu2.email_address, 'foo@bar.com') + eq_(eu2.email_address, 'foo@bar.com') -class CustomSetupTeardowntest(_fixtures.FixtureTest): +class CustomSetupTeardownTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_rebuild_state(self): """not much of a 'test', but illustrate how to @@ -186,5 +192,3 @@ class CustomSetupTeardowntest(_fixtures.FixtureTest): attributes.manager_of_class(User).setup_instance(u2) assert attributes.instance_state(u2) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/query.py b/test/orm/test_query.py index 33c3e39d7..66c219b10 100644 --- a/test/orm/query.py +++ b/test/orm/test_query.py @@ -1,4 +1,4 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import operator from sqlalchemy import * from sqlalchemy import exc as sa_exc, util @@ -7,17 +7,18 @@ from sqlalchemy.engine import default from sqlalchemy.orm import * from sqlalchemy.orm import attributes -from testlib.testing import eq_ +from sqlalchemy.test.testing import eq_ -from testlib import sa, testing, AssertsCompiledSQL, Column, engines +import sqlalchemy as sa +from sqlalchemy.test import testing, AssertsCompiledSQL, Column, engines -from orm import _fixtures -from orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \ +from test.orm import _fixtures +from test.orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \ Dingaling, item_keywords, dingalings, User, items,\ orders, Address, users, nodes, \ order_items, Item, Order, Node -from orm import _base +from test.orm import _base from sqlalchemy.orm.util import join, outerjoin, with_parent @@ -27,7 +28,8 @@ class QueryTest(_fixtures.FixtureTest): run_deletes = None - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', order_by=addresses.c.id), 'orders':relation(Order, backref='user', order_by=orders.c.id), # o2m, m2o @@ -82,8 +84,8 @@ class GetTest(QueryTest): s = create_session() q = s.query(User).join('addresses').filter(Address.user_id==8) - self.assertRaises(sa_exc.InvalidRequestError, q.get, 7) - self.assertRaises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19) + assert_raises(sa_exc.InvalidRequestError, q.get, 7) + assert_raises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19) # order_by()/get() doesn't raise s.query(User).order_by(User.id).get(8) @@ -142,7 +144,7 @@ class GetTest(QueryTest): class LocalFoo(Base): pass mapper(LocalFoo, table) - self.assertEquals(create_session().query(LocalFoo).get(ustring), + eq_(create_session().query(LocalFoo).get(ustring), LocalFoo(id=ustring, data=ustring)) finally: metadata.drop_all() @@ -183,7 +185,7 @@ class GetTest(QueryTest): def test_query_str(self): s = create_session() q = s.query(User).filter(User.id==1) - self.assertEquals( + eq_( str(q).replace('\n',''), 'SELECT users.id AS users_id, users.name AS users_name FROM users WHERE users.id = ?' ) @@ -197,29 +199,29 @@ class InvalidGenerationsTest(QueryTest): s.query(User).offset(2), s.query(User).limit(2).offset(2) ): - self.assertRaises(sa_exc.InvalidRequestError, q.join, "addresses") + assert_raises(sa_exc.InvalidRequestError, q.join, "addresses") - self.assertRaises(sa_exc.InvalidRequestError, q.filter, User.name=='ed') + assert_raises(sa_exc.InvalidRequestError, q.filter, User.name=='ed') - self.assertRaises(sa_exc.InvalidRequestError, q.filter_by, name='ed') + assert_raises(sa_exc.InvalidRequestError, q.filter_by, name='ed') - self.assertRaises(sa_exc.InvalidRequestError, q.order_by, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.order_by, 'foo') - self.assertRaises(sa_exc.InvalidRequestError, q.group_by, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.group_by, 'foo') - self.assertRaises(sa_exc.InvalidRequestError, q.having, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.having, 'foo') def test_no_from(self): s = create_session() q = s.query(User).select_from(users) - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) q = s.query(User).join('addresses') - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) q = s.query(User).order_by(User.id) - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) # this is fine, however q.from_self() @@ -227,43 +229,43 @@ class InvalidGenerationsTest(QueryTest): def test_invalid_select_from(self): s = create_session() q = s.query(User) - self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id==5) - self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id) + assert_raises(sa_exc.ArgumentError, q.select_from, User.id==5) + assert_raises(sa_exc.ArgumentError, q.select_from, User.id) def test_invalid_from_statement(self): s = create_session() q = s.query(User) - self.assertRaises(sa_exc.ArgumentError, q.from_statement, User.id==5) - self.assertRaises(sa_exc.ArgumentError, q.from_statement, users.join(addresses)) + assert_raises(sa_exc.ArgumentError, q.from_statement, User.id==5) + assert_raises(sa_exc.ArgumentError, q.from_statement, users.join(addresses)) def test_invalid_column(self): s = create_session() q = s.query(User) - self.assertRaises(sa_exc.InvalidRequestError, q.add_column, object()) + assert_raises(sa_exc.InvalidRequestError, q.add_column, object()) def test_mapper_zero(self): s = create_session() q = s.query(User, Address) - self.assertRaises(sa_exc.InvalidRequestError, q.get, 5) + assert_raises(sa_exc.InvalidRequestError, q.get, 5) def test_from_statement(self): s = create_session() q = s.query(User).filter(User.id==5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).filter_by(id=5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).limit(5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).group_by(User.name) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).order_by(User.name) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") class OperatorTest(QueryTest, AssertsCompiledSQL): """test sql.Comparator implementation for MapperProperties""" @@ -431,7 +433,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "users.id IN (:id_1, :id_2)") def test_in_on_relation_not_supported(self): - self.assertRaises(NotImplementedError, Address.user.in_, [User(id=5)]) + assert_raises(NotImplementedError, Address.user.in_, [User(id=5)]) def test_between(self): self._test(User.id.between('a', 'b'), @@ -705,8 +707,8 @@ class FilterTest(QueryTest): assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all() # m2m - self.assertEquals(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) - self.assertEquals(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) + eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) + eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) def test_filter_by(self): sess = create_session() @@ -723,16 +725,16 @@ class FilterTest(QueryTest): sess = create_session() # o2o - self.assertEquals([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) - self.assertEquals([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) + eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) + eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) # m2o - self.assertEquals([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) - self.assertEquals([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all()) + eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) + eq_([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all()) # o2m - self.assertEquals([User(id=10)], sess.query(User).filter(User.addresses==None).all()) - self.assertEquals([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all()) + eq_([User(id=10)], sess.query(User).filter(User.addresses==None).all()) + eq_([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all()) class FromSelfTest(QueryTest, AssertsCompiledSQL): @@ -818,7 +820,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_multiple_entities(self): sess = create_session() - self.assertEquals( + eq_( sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(), [ (User(id=8), Address(id=2)), @@ -826,7 +828,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): ] ) - self.assertEquals( + eq_( sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(), # order_by(User.id, Address.id).first(), @@ -842,11 +844,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): ed = s.query(User).filter(User.name=='ed') jack = s.query(User).filter(User.name=='jack') - self.assertEquals(fred.union(ed).order_by(User.name).all(), + eq_(fred.union(ed).order_by(User.name).all(), [User(name='ed'), User(name='fred')] ) - self.assertEquals(fred.union(ed, jack).order_by(User.name).all(), + eq_(fred.union(ed, jack).order_by(User.name).all(), [User(name='ed'), User(name='fred'), User(name='jack')] ) @@ -857,11 +859,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): fred = s.query(User).filter(User.name=='fred') ed = s.query(User).filter(User.name=='ed') jack = s.query(User).filter(User.name=='jack') - self.assertEquals(fred.intersect(ed, jack).all(), + eq_(fred.intersect(ed, jack).all(), [] ) - self.assertEquals(fred.union(ed).intersect(ed.union(jack)).all(), + eq_(fred.union(ed).intersect(ed.union(jack)).all(), [User(name='ed')] ) @@ -873,7 +875,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): jack = s.query(User).filter(User.name=='jack') def go(): - self.assertEquals( + eq_( fred.union(ed).order_by(User.name).options(eagerload(User.addresses)).all(), [ User(name='ed', addresses=[Address(), Address(), Address()]), @@ -888,8 +890,8 @@ class AggregateTest(QueryTest): def test_sum(self): sess = create_session() orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) - self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,)) - self.assertEquals(orders.value(func.sum(Order.user_id * Order.address_id)), 79) + eq_(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,)) + eq_(orders.value(func.sum(Order.user_id * Order.address_id)), 79) def test_apply(self): sess = create_session() @@ -987,13 +989,13 @@ class YieldTest(QueryTest): q = iter(sess.query(User).yield_per(1).from_statement("select * from users")) ret = [] - self.assertEquals(len(sess.identity_map), 0) + eq_(len(sess.identity_map), 0) ret.append(q.next()) ret.append(q.next()) - self.assertEquals(len(sess.identity_map), 2) + eq_(len(sess.identity_map), 2) ret.append(q.next()) ret.append(q.next()) - self.assertEquals(len(sess.identity_map), 4) + eq_(len(sess.identity_map), 4) try: q.next() assert False @@ -1019,7 +1021,7 @@ class TextTest(QueryTest): def test_as_column(self): s = create_session() - self.assertRaises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name")) + assert_raises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name")) eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')]) @@ -1091,24 +1093,24 @@ class JoinTest(QueryTest): sess = create_session() for oalias,ialias in [(True, True), (False, False), (True, False), (False, True)]: - self.assertEquals( + eq_( sess.query(User).join('orders', aliased=oalias).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description == 'item 4').all(), [User(name='jack')] ) # use middle criterion - self.assertEquals( + eq_( sess.query(User).join('orders', aliased=oalias).filter(Order.user_id==9).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description=='item 4').all(), [] ) orderalias = aliased(Order) itemalias = aliased(Item) - self.assertEquals( + eq_( sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(itemalias.description == 'item 4').all(), [User(name='jack')] ) - self.assertEquals( + eq_( sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(orderalias.user_id==9).filter(itemalias.description=='item 4').all(), [] ) @@ -1119,28 +1121,28 @@ class JoinTest(QueryTest): sess = create_session() - self.assertEquals( + eq_( sess.query(User).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), [User(id=8,name=u'ed')] ) # its actually not so controversial if you view it in terms # of multiple entities. - self.assertEquals( + eq_( sess.query(User, Address).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), [(User(id=8,name=u'ed'), Address(email_address='ed@wood.com'))] ) # this was the controversial part. now, raise an error if the feature is abused. # before the error raise was added, this would silently work..... - self.assertRaises( + assert_raises( sa_exc.InvalidRequestError, sess.query(User).join, (Address, Address.user), ) # but this one would silently fail adalias = aliased(Address) - self.assertRaises( + assert_raises( sa_exc.InvalidRequestError, sess.query(User).join, (adalias, Address.user), ) @@ -1153,7 +1155,7 @@ class JoinTest(QueryTest): oalias2 = aliased(Order) result = sess.query(ualias).join((oalias1, ualias.orders), (oalias2, ualias.orders)).\ filter(or_(oalias1.user_id==9, oalias2.user_id==7)).all() - self.assertEquals(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')]) + eq_(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')]) def test_orderby_arg_bug(self): sess = create_session() @@ -1163,17 +1165,17 @@ class JoinTest(QueryTest): def test_no_onclause(self): sess = create_session() - self.assertEquals( + eq_( sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), [User(name='jack')] ) - self.assertEquals( + eq_( sess.query(User.name).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), [('jack',)] ) - self.assertEquals( + eq_( sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(), [User(name='jack')] ) @@ -1181,7 +1183,7 @@ class JoinTest(QueryTest): def test_clause_onclause(self): sess = create_session() - self.assertEquals( + eq_( sess.query(User).join( (Order, User.id==Order.user_id), (order_items, Order.id==order_items.c.order_id), @@ -1190,7 +1192,7 @@ class JoinTest(QueryTest): [User(name='jack')] ) - self.assertEquals( + eq_( sess.query(User.name).join( (Order, User.id==Order.user_id), (order_items, Order.id==order_items.c.order_id), @@ -1200,7 +1202,7 @@ class JoinTest(QueryTest): ) ualias = aliased(User) - self.assertEquals( + eq_( sess.query(ualias.name).join( (Order, ualias.id==Order.user_id), (order_items, Order.id==order_items.c.order_id), @@ -1212,7 +1214,7 @@ class JoinTest(QueryTest): # explicit onclause with from_self(), means # the onclause must be aliased against the query's custom # FROM object - self.assertEquals( + eq_( sess.query(User).order_by(User.id).offset(2).from_self().join( (Order, User.id==Order.user_id) ).all(), @@ -1220,7 +1222,7 @@ class JoinTest(QueryTest): ) # same with an explicit select_from() - self.assertEquals( + eq_( sess.query(User).select_from(select([users]).order_by(User.id).offset(2).alias()).join( (Order, User.id==Order.user_id) ).all(), @@ -1244,34 +1246,34 @@ class JoinTest(QueryTest): AdAlias = aliased(Address) q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias)) l = q.order_by(User.id, AdAlias.id).all() - self.assertEquals(l, expected) + eq_(l, expected) sess.expunge_all() q = sess.query(User).add_entity(AdAlias) l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) + eq_(l, [(user8, address3)]) l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) + eq_(l, [(user8, address3)]) l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) + eq_(l, [(user8, address3)]) # this is the first test where we are joining "backwards" - from AdAlias to User even though # the query is against User q = sess.query(User, AdAlias) l = q.join(AdAlias.user).filter(User.name=='ed') - self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) + eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed') - self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) + eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) def test_implicit_joins_from_aliases(self): sess = create_session() OrderAlias = aliased(Order) - self.assertEquals( + eq_( sess.query(OrderAlias).join('items').filter_by(description='item 3').\ order_by(OrderAlias.id).all(), [ @@ -1281,7 +1283,7 @@ class JoinTest(QueryTest): ] ) - self.assertEquals( + eq_( sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').\ order_by(User.id, OrderAlias.id).all(), [ @@ -1314,12 +1316,12 @@ class JoinTest(QueryTest): q = sess.query(Order) q = q.add_entity(Item).select_from(join(Order, Item, 'items')).order_by(Order.id, Item.id) l = q.all() - self.assertEquals(l, expected) + eq_(l, expected) IAlias = aliased(Item) q = sess.query(Order, IAlias).select_from(join(Order, IAlias, 'items')).filter(IAlias.description=='item 3') l = q.all() - self.assertEquals(l, + eq_(l, [ (order1, item3), (order2, item3), @@ -1385,7 +1387,7 @@ class JoinTest(QueryTest): sess = create_session() ualias = aliased(User) - self.assertEquals( + eq_( sess.query(User, ualias).filter(User.id > ualias.id).order_by(desc(ualias.id), User.name).all(), [ (User(id=10,name=u'chuck'), User(id=9,name=u'fred')), @@ -1401,14 +1403,15 @@ class JoinTest(QueryTest): sess = create_session() - self.assertEquals( + eq_( sess.query(User.name).join((addresses, User.id==addresses.c.user_id)).order_by(User.id).all(), [(u'jack',), (u'ed',), (u'ed',), (u'ed',), (u'fred',)] ) class MultiplePathTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2, t1t2_1, t1t2_2 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -1440,7 +1443,7 @@ class MultiplePathTest(_base.MappedTest): mapper(T2, t2) q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint() - self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.", + assert_raises_message(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.", q.join, 't2s_2' ) @@ -1449,7 +1452,8 @@ class MultiplePathTest(_base.MappedTest): class SynonymTest(QueryTest): - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(User, users, properties={ 'name_syn':synonym('name'), 'addresses':relation(Address), @@ -1547,7 +1551,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): adalias = addresses.alias() q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias)) def go(): - self.assertEquals(self.static.user_address_result, q.order_by(User.id).all()) + eq_(self.static.user_address_result, q.order_by(User.id).all()) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() @@ -1675,35 +1679,35 @@ class MixedEntitiesTest(QueryTest): sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User) q2 = q.select_from(sel).values(User.name) - self.assertEquals(list(q2), [(u'jack',), (u'ed',)]) + eq_(list(q2), [(u'jack',), (u'ed',)]) q = sess.query(User) q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String)) - self.assertEquals(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) + eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) adalias = aliased(Address) q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) q2 = q.values(func.count(User.name)) assert q2.next() == (4,) q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name) - self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')]) + eq_(list(q2), [(u'ed', u'ed', u'ed')]) # using User.xxx is alised against "sel", so this query returns nothing q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name) - self.assertEquals(list(q2), []) + eq_(list(q2), []) # whereas this uses users.c.xxx, is not aliased and creates a new join q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name) - self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')]) + eq_(list(q2), [(u'ed', u'jack', u'jack')]) @testing.fails_on('mssql', 'FIXME: unknown') def test_values_specific_order_by(self): @@ -1715,7 +1719,7 @@ class MixedEntitiesTest(QueryTest): q = sess.query(User) u2 = aliased(User) q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name) - self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) + eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) @testing.fails_on('mssql', 'FIXME: unknown') def test_values_with_boolean_selects(self): @@ -1724,7 +1728,7 @@ class MixedEntitiesTest(QueryTest): q = sess.query(User) q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%'))) - self.assertEquals(list(q2), [(True, 1), (False, 3)]) + eq_(list(q2), [(True, 1), (False, 3)]) def test_correlated_subquery(self): """test that a subquery constructed from ORM attributes doesn't leak out @@ -1739,7 +1743,7 @@ class MixedEntitiesTest(QueryTest): label('count') # we don't want Address to be outside of the subquery here - self.assertEquals( + eq_( list(sess.query(User, subq)[0:3]), [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] ) @@ -1751,7 +1755,7 @@ class MixedEntitiesTest(QueryTest): label('count') # we don't want Address to be outside of the subquery here - self.assertEquals( + eq_( list(sess.query(User, subq)[0:3]), [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] ) @@ -1759,71 +1763,71 @@ class MixedEntitiesTest(QueryTest): def test_tuple_labeling(self): sess = create_session() for row in sess.query(User, Address).join(User.addresses).all(): - self.assertEquals(set(row.keys()), set(['User', 'Address'])) - self.assertEquals(row.User, row[0]) - self.assertEquals(row.Address, row[1]) + eq_(set(row.keys()), set(['User', 'Address'])) + eq_(row.User, row[0]) + eq_(row.Address, row[1]) for row in sess.query(User.name, User.id.label('foobar')): - self.assertEquals(set(row.keys()), set(['name', 'foobar'])) - self.assertEquals(row.name, row[0]) - self.assertEquals(row.foobar, row[1]) + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) for row in sess.query(User).values(User.name, User.id.label('foobar')): - self.assertEquals(set(row.keys()), set(['name', 'foobar'])) - self.assertEquals(row.name, row[0]) - self.assertEquals(row.foobar, row[1]) + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) oalias = aliased(Order) for row in sess.query(User, oalias).join(User.orders).all(): - self.assertEquals(set(row.keys()), set(['User'])) - self.assertEquals(row.User, row[0]) + eq_(set(row.keys()), set(['User'])) + eq_(row.User, row[0]) oalias = aliased(Order, name='orders') for row in sess.query(User, oalias).join(User.orders).all(): - self.assertEquals(set(row.keys()), set(['User', 'orders'])) - self.assertEquals(row.User, row[0]) - self.assertEquals(row.orders, row[1]) + eq_(set(row.keys()), set(['User', 'orders'])) + eq_(row.User, row[0]) + eq_(row.orders, row[1]) def test_column_queries(self): sess = create_session() - self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)]) + eq_(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)]) sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User.name) q2 = q.select_from(sel).all() - self.assertEquals(list(q2), [(u'jack',), (u'ed',)]) + eq_(list(q2), [(u'jack',), (u'ed',)]) - self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [ + eq_(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [ (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com') ]) - self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), + eq_(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)] ) - self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), + eq_(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] ) - self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), + eq_(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))] ) adalias = aliased(Address) - self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), + eq_(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] ) - self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(), + eq_(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(), [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))] ) # select from aliasing + explicit aliasing - self.assertEquals( + eq_( sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(), [ (User(name=u'jack',id=7), u'jack@bean.com'), @@ -1836,7 +1840,7 @@ class MixedEntitiesTest(QueryTest): ) # anon + select from aliasing - self.assertEquals( + eq_( sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(), [ User(name=u'ed',id=8), @@ -1849,7 +1853,7 @@ class MixedEntitiesTest(QueryTest): sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), ]: - self.assertEquals( + eq_( q.all(), [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'), @@ -1875,7 +1879,7 @@ class MixedEntitiesTest(QueryTest): def go(): results = sess.query(User).limit(1).options(eagerload('addresses')).add_column(User.name).all() - self.assertEquals(results, [(User(name='jack'), 'jack')]) + eq_(results, [(User(name='jack'), 'jack')]) self.assert_sql_count(testing.db, go, 1) def test_self_referential(self): @@ -1898,7 +1902,7 @@ class MixedEntitiesTest(QueryTest): ]: - self.assertEquals( + eq_( q.all(), [ (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), @@ -1924,25 +1928,25 @@ class MixedEntitiesTest(QueryTest): sess = create_session() selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id]) - self.assertEquals(list(sess.query(User, Address).instances(selectquery.execute())), expected) + eq_(list(sess.query(User, Address).instances(selectquery.execute())), expected) sess.expunge_all() for address_entity in (Address, aliased(Address)): q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id) - self.assertEquals(q.all(), expected) + eq_(q.all(), expected) sess.expunge_all() q = sess.query(User).add_entity(address_entity) q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(q.all(), [(user8, address3)]) + eq_(q.all(), [(user8, address3)]) sess.expunge_all() q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(q.all(), [(user8, address3)]) + eq_(q.all(), [(user8, address3)]) sess.expunge_all() q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)]) + eq_(list(util.OrderedSet(q.all())), [(user8, address3)]) sess.expunge_all() def test_aliased_multi_mappers(self): @@ -1979,7 +1983,7 @@ class MixedEntitiesTest(QueryTest): assert sess.query(User).add_column(add_col).all() == expected sess.expunge_all() - self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object()) + assert_raises(sa_exc.InvalidRequestError, sess.query(User).add_column, object()) def test_add_multi_columns(self): """test that add_column accepts a FROM clause.""" @@ -2004,13 +2008,13 @@ class MixedEntitiesTest(QueryTest): q = sess.query(User) q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count')) - self.assertEquals(q.all(), expected) + eq_(q.all(), expected) sess.expunge_all() adalias = aliased(Address) q = sess.query(User) q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count')) - self.assertEquals(q.all(), expected) + eq_(q.all(), expected) sess.expunge_all() s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id) @@ -2069,8 +2073,9 @@ class ImmediateTest(_fixtures.FixtureTest): run_inserts = 'once' run_deletes = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Address, addresses) mapper(User, users, properties=dict( @@ -2080,25 +2085,25 @@ class ImmediateTest(_fixtures.FixtureTest): def test_one(self): sess = create_session() - self.assertRaises(sa.orm.exc.NoResultFound, + assert_raises(sa.orm.exc.NoResultFound, sess.query(User).filter(User.id == 99).one) eq_(sess.query(User).filter(User.id == 7).one().id, 7) - self.assertRaises(sa.orm.exc.MultipleResultsFound, + assert_raises(sa.orm.exc.MultipleResultsFound, sess.query(User).one) - self.assertRaises( + assert_raises( sa.orm.exc.NoResultFound, sess.query(User.id, User.name).filter(User.id == 99).one) eq_(sess.query(User.id, User.name).filter(User.id == 7).one(), (7, 'jack')) - self.assertRaises(sa.orm.exc.MultipleResultsFound, + assert_raises(sa.orm.exc.MultipleResultsFound, sess.query(User.id, User.name).one) - self.assertRaises(sa.orm.exc.NoResultFound, + assert_raises(sa.orm.exc.NoResultFound, (sess.query(User, Address). join(User.addresses). filter(Address.id == 99)).one) @@ -2108,7 +2113,7 @@ class ImmediateTest(_fixtures.FixtureTest): filter(Address.id == 4)).one(), (User(id=8), Address(id=4))) - self.assertRaises(sa.orm.exc.MultipleResultsFound, + assert_raises(sa.orm.exc.MultipleResultsFound, sess.query(User, Address).join(User.addresses).one) @testing.future @@ -2133,7 +2138,7 @@ class ImmediateTest(_fixtures.FixtureTest): eq_(sess.query(User.id, User.name).filter_by(id=7).value(User.id), 7) eq_(sess.query(User).filter_by(id=0).value(User.id), None) - sess.bind = sa.testing.db + sess.bind = testing.db eq_(sess.query().value(sa.literal_column('1').label('x')), 1) @@ -2149,19 +2154,19 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])).alias() sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)]) + eq_(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)]) - self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)]) + eq_(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)]) - self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [ + eq_(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [ User(name='jack',id=7), User(name='ed',id=8) ]) - self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [ + eq_(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [ User(name='ed',id=8), User(name='jack',id=7) ]) - self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), + eq_(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), User(name='jack', addresses=[Address(id=1)]) ) @@ -2173,7 +2178,7 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])) sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).all(), + eq_(sess.query(User).select_from(sel).all(), [ User(name='jack',id=7), User(name='ed',id=8) ] @@ -2185,7 +2190,7 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])) sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).all(), + eq_(sess.query(User).select_from(sel).all(), [ User(name='jack',id=7), User(name='ed',id=8) ] @@ -2200,7 +2205,7 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])) sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(), + eq_(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(), [ (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), @@ -2210,7 +2215,7 @@ class SelectFromTest(QueryTest): ) adalias = aliased(Address) - self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(), + eq_(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(), [ (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), @@ -2238,16 +2243,16 @@ class SelectFromTest(QueryTest): # TODO: remove sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all() - self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ User(name=u'jack',id=7) ]) - self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ User(name=u'jack',id=7) ]) def go(): - self.assertEquals(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + eq_(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ User(name=u'jack',orders=[ Order(description=u'order 1',items=[ Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]), @@ -2265,11 +2270,11 @@ class SelectFromTest(QueryTest): sess.expunge_all() sel2 = orders.select(orders.c.id.in_([1,2,3])) - self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [ + eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [ Order(description=u'order 1',id=1), Order(description=u'order 2',id=2), ]) - self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [ + eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [ Order(description=u'order 1',id=1), Order(description=u'order 2',id=2), ]) @@ -2285,7 +2290,7 @@ class SelectFromTest(QueryTest): sess = create_session() def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(), + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(), [ User(id=7, addresses=[Address(id=1)]), User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]) @@ -2295,14 +2300,14 @@ class SelectFromTest(QueryTest): sess.expunge_all() def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(), + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(), [User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])] ) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) self.assert_sql_count(testing.db, go, 1) class CustomJoinTest(QueryTest): @@ -2329,14 +2334,16 @@ class SelfReferentialTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global nodes nodes = Table('nodes', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) - def insert_data(self): + @classmethod + def insert_data(cls): global Node class Node(Base): @@ -2399,7 +2406,7 @@ class SelfReferentialTest(_base.MappedTest): filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first() assert node.data == 'n122' - self.assertEquals( + eq_( list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)), [('n122', 'n12', 'n1')]) @@ -2410,13 +2417,13 @@ class SelfReferentialTest(_base.MappedTest): n1 = aliased(Node) # using 'n1.parent' implicitly joins to unaliased Node - self.assertEquals( + eq_( sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(), [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] ) # explicit (new syntax) - self.assertEquals( + eq_( sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(), [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] ) @@ -2426,7 +2433,7 @@ class SelfReferentialTest(_base.MappedTest): parent = aliased(Node) grandparent = aliased(Node) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2434,7 +2441,7 @@ class SelfReferentialTest(_base.MappedTest): (Node(data='n122'), Node(data='n12'), Node(data='n1')) ) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2443,7 +2450,7 @@ class SelfReferentialTest(_base.MappedTest): ) # same, change order around - self.assertEquals( + eq_( sess.query(parent, grandparent, Node).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2451,7 +2458,7 @@ class SelfReferentialTest(_base.MappedTest): (Node(data='n12'), Node(data='n1'), Node(data='n122')) ) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2460,7 +2467,7 @@ class SelfReferentialTest(_base.MappedTest): (Node(data='n122'), Node(data='n12'), Node(data='n1')) ) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2472,40 +2479,41 @@ class SelfReferentialTest(_base.MappedTest): def test_any(self): sess = create_session() - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) - self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) + eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) def test_has(self): sess = create_session() - self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) - self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) - self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) + eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) def test_contains(self): sess = create_session() n122 = sess.query(Node).filter(Node.data=='n122').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) + eq_(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) n13 = sess.query(Node).filter(Node.data=='n13').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')]) + eq_(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')]) def test_eq_ne(self): sess = create_session() n12 = sess.query(Node).filter(Node.data=='n12').one() - self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + eq_(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) - self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) + eq_(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) class SelfReferentialM2MTest(_base.MappedTest): run_setup_mappers = 'once' run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global nodes, node_to_nodes nodes = Table('nodes', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -2516,7 +2524,8 @@ class SelfReferentialM2MTest(_base.MappedTest): Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), ) - def insert_data(self): + @classmethod + def insert_data(cls): global Node class Node(Base): @@ -2550,20 +2559,20 @@ class SelfReferentialM2MTest(_base.MappedTest): def test_any(self): sess = create_session() - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) def test_contains(self): sess = create_session() n4 = sess.query(Node).filter_by(data='n4').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) - self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) + eq_(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) + eq_(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) def test_explicit_join(self): sess = create_session() n1 = aliased(Node) - self.assertEquals( + eq_( sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data.in_(['n3', 'n7'])).order_by(Node.id).all(), [Node(data='n1'), Node(data='n2')] ) @@ -2575,7 +2584,7 @@ class ExternalColumnsTest(QueryTest): def test_external_columns_bad(self): - self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={ + assert_raises_message(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={ 'concat': (users.c.id * 2), }) clear_mappers() @@ -2596,7 +2605,7 @@ class ExternalColumnsTest(QueryTest): sess.query(Address).options(eagerload('user')).all() - self.assertEquals(sess.query(User).all(), + eq_(sess.query(User).all(), [ User(id=7, concat=14, count=1), User(id=8, concat=16, count=3), @@ -2612,22 +2621,22 @@ class ExternalColumnsTest(QueryTest): Address(id=4, user=User(id=8, concat=16, count=3)), Address(id=5, user=User(id=9, concat=18, count=1)) ] - self.assertEquals(sess.query(Address).all(), address_result) + eq_(sess.query(Address).all(), address_result) # run the eager version twice to test caching of aliased clauses for x in range(2): sess.expunge_all() def go(): - self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result) + eq_(sess.query(Address).options(eagerload('user')).all(), address_result) self.assert_sql_count(testing.db, go, 1) ualias = aliased(User) - self.assertEquals( + eq_( sess.query(Address, ualias).join(('user', ualias)).all(), [(address, address.user) for address in address_result] ) - self.assertEquals( + eq_( sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), [ (Address(id=1), 1), @@ -2638,7 +2647,7 @@ class ExternalColumnsTest(QueryTest): ] ) - self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), + eq_(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), [ (Address(id=1), 14, 1), (Address(id=2), 16, 3), @@ -2649,7 +2658,7 @@ class ExternalColumnsTest(QueryTest): ) ua = aliased(User) - self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), + eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), [ (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1), (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3), @@ -2659,11 +2668,11 @@ class ExternalColumnsTest(QueryTest): ] ) - self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), + eq_(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] ) - self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), + eq_(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] ) @@ -2686,17 +2695,18 @@ class ExternalColumnsTest(QueryTest): sess = create_session() def go(): o1 = sess.query(Order).options(eagerload_all('address.user')).get(1) - self.assertEquals(o1.address.user.count, 1) + eq_(o1.address.user.count, 1) self.assert_sql_count(testing.db, go, 1) sess = create_session() def go(): o1 = sess.query(Order).options(eagerload_all('address.user')).first() - self.assertEquals(o1.address.user.count, 1) + eq_(o1.address.user.count, 1) self.assert_sql_count(testing.db, go, 1) class TestOverlyEagerEquivalentCols(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, sub1, sub2 base = Table('base', metadata, Column('id', Integer, primary_key=True), @@ -2747,14 +2757,15 @@ class TestOverlyEagerEquivalentCols(_base.MappedTest): q = sess.query(Base).outerjoin('sub2', aliased=True) assert sub1.c.id not in q._filter_aliases.equivalents - self.assertEquals( + eq_( sess.query(Base).join('sub1').outerjoin('sub2', aliased=True).\ filter(Sub1.id==1).one(), b1 ) class UpdateDeleteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(32)), @@ -2765,15 +2776,17 @@ class UpdateDeleteTest(_base.MappedTest): Column('user_id', None, ForeignKey('users.id')), Column('title', String(32))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Document(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): users.insert().execute([ dict(id=1, name='john', age=25), dict(id=2, name='jack', age=47), @@ -2789,8 +2802,9 @@ class UpdateDeleteTest(_base.MappedTest): dict(id=3, user_id=2, title='baz'), ]) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users) mapper(Document, documents, properties={ 'user': relation(User, lazy=False, backref=backref('documents', lazy=True)) @@ -2964,17 +2978,17 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0}) - self.assertEquals(rowcount, 2) + eq_(rowcount, 2) rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10}) - self.assertEquals(rowcount, 2) + eq_(rowcount, 2) @testing.resolve_artifact_names def test_delete_returns_rowcount(self): sess = create_session(bind=testing.db, autocommit=False) rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False) - self.assertEquals(rowcount, 3) + eq_(rowcount, 3) @testing.resolve_artifact_names def test_update_with_eager_relations(self): @@ -3008,5 +3022,3 @@ class UpdateDeleteTest(_base.MappedTest): eq_(sess.query(Document.title).all(), zip(['baz'])) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/relationships.py b/test/orm/test_relationships.py index a0a8900b2..1bc074c31 100644 --- a/test/orm/relationships.py +++ b/test/orm/test_relationships.py @@ -1,10 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message import datetime -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_ -from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers -from testlib.testing import eq_, startswith_ -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, MetaData, and_ +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers +from sqlalchemy.test.testing import eq_, startswith_ +from test.orm import _base, _fixtures class RelationTest(_base.MappedTest): @@ -26,7 +29,8 @@ class RelationTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("tbl_a", metadata, Column("id", Integer, primary_key=True), Column("name", String(128))) @@ -43,7 +47,8 @@ class RelationTest(_base.MappedTest): Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), Column("name", String(128))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.Entity): pass class B(_base.Entity): @@ -53,8 +58,9 @@ class RelationTest(_base.MappedTest): class D(_base.Entity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(A, tbl_a, properties=dict( c_rows=relation(C, cascade="all, delete-orphan", backref="a_row"))) mapper(B, tbl_b) @@ -63,8 +69,9 @@ class RelationTest(_base.MappedTest): mapper(D, tbl_d, properties=dict( b_row=relation(B))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): session = create_session() a = A(name='a1') b = B(name='b1') @@ -102,7 +109,8 @@ class RelationTest2(_base.MappedTest): key where one column in the foreign key is 'joined to itself'. """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('company_t', metadata, Column('company_id', Integer, primary_key=True), Column('name', sa.Unicode(30))) @@ -218,7 +226,8 @@ class RelationTest2(_base.MappedTest): assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' class RelationTest3(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("jobs", metadata, Column("jobno", sa.Unicode(15), primary_key=True), Column("created", sa.DateTime, nullable=False, @@ -257,8 +266,9 @@ class RelationTest3(_base.MappedTest): ["jobno", "pagename"], ["pages.jobno", "pages.pagename"])) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Job(_base.Entity): def create_page(self, pagename): return Page(job=self, pagename=pagename) @@ -360,7 +370,8 @@ class RelationTest3(_base.MappedTest): class RelationTest4(_base.MappedTest): """Syncrules on foreign keys that are also primary""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("tableA", metadata, Column("id",Integer,primary_key=True), Column("foo",Integer,), @@ -369,7 +380,8 @@ class RelationTest4(_base.MappedTest): Column("id",Integer,ForeignKey("tableA.id"),primary_key=True), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.Entity): pass @@ -537,7 +549,8 @@ class RelationTest4(_base.MappedTest): class RelationTest5(_base.MappedTest): """Test a map to a select that relates to a map to the table.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('items', metadata, Column('item_policy_num', String(10), primary_key=True, key='policyNum'), @@ -605,7 +618,8 @@ class RelationTest6(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('tags', metadata, Column("id", Integer, primary_key=True), Column("data", String(50)), ) @@ -659,7 +673,8 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): subscriber_table = Table('subscriber', metadata, Column('id', Integer, primary_key=True), Column('dummy', String(10)) # to appease older sqlite version @@ -671,8 +686,9 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): Column('type', String(1), primary_key=True), ) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): subscriber_and_address = subscriber.join(address, and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C']))) @@ -762,7 +778,7 @@ class ManualBackrefTest(_fixtures.FixtureTest): 'user':relation(User, back_populates='addresses') }) - self.assertRaises(sa.exc.InvalidRequestError, compile_mappers) + assert_raises(sa.exc.InvalidRequestError, compile_mappers) @testing.resolve_artifact_names def test_invalid_target(self): @@ -775,7 +791,7 @@ class ManualBackrefTest(_fixtures.FixtureTest): 'dingaling':relation(Dingaling) }) - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, r"reverse_property 'dingaling' on relation User.addresses references " "relation Address.dingaling, which does not reference mapper Mapper\|User\|users", compile_mappers) @@ -794,7 +810,7 @@ class JoinConditionErrorTest(testing.TestBase): c1id = Column('c1id', Integer, ForeignKey('c1.id')) c2 = relation(C1, primaryjoin=C1.id) - self.assertRaises(sa.exc.ArgumentError, compile_mappers) + assert_raises(sa.exc.ArgumentError, compile_mappers) def test_clauseelement_pj_false(self): from sqlalchemy.ext.declarative import declarative_base @@ -808,7 +824,7 @@ class JoinConditionErrorTest(testing.TestBase): c1id = Column('c1id', Integer, ForeignKey('c1.id')) c2 = relation(C1, primaryjoin="x"=="y") - self.assertRaises(sa.exc.ArgumentError, compile_mappers) + assert_raises(sa.exc.ArgumentError, compile_mappers) def test_fk_error_raised(self): @@ -834,7 +850,7 @@ class JoinConditionErrorTest(testing.TestBase): mapper(C1, t1, properties={'c2':relation(C2)}) mapper(C2, t3) - self.assertRaises(sa.exc.NoReferencedColumnError, compile_mappers) + assert_raises(sa.exc.NoReferencedColumnError, compile_mappers) def test_join_error_raised(self): m = MetaData() @@ -858,15 +874,16 @@ class JoinConditionErrorTest(testing.TestBase): mapper(C1, t1, properties={'c2':relation(C2)}) mapper(C2, t3) - self.assertRaises(sa.exc.ArgumentError, compile_mappers) + assert_raises(sa.exc.ArgumentError, compile_mappers) - def tearDown(self): + def teardown(self): clear_mappers() class TypeMatchTest(_base.MappedTest): """test errors raised when trying to add items whose type is not handled by a relation""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("a", metadata, Column('aid', Integer, primary_key=True), Column('data', String(30))) @@ -924,7 +941,7 @@ class TypeMatchTest(_base.MappedTest): sess.add(a1) sess.add(b1) sess.add(c1) - self.assertRaisesMessage(sa.orm.exc.FlushError, + assert_raises_message(sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush) @testing.resolve_artifact_names @@ -945,7 +962,7 @@ class TypeMatchTest(_base.MappedTest): sess.add(a1) sess.add(b1) sess.add(c1) - self.assertRaisesMessage(sa.orm.exc.FlushError, + assert_raises_message(sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush) @testing.resolve_artifact_names @@ -962,7 +979,7 @@ class TypeMatchTest(_base.MappedTest): sess = create_session() sess.add(b1) sess.add(d1) - self.assertRaisesMessage(sa.orm.exc.FlushError, + assert_raises_message(sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush) @testing.resolve_artifact_names @@ -977,12 +994,13 @@ class TypeMatchTest(_base.MappedTest): d1 = D() d1.a = b1 sess = create_session() - self.assertRaisesMessage(AssertionError, + assert_raises_message(AssertionError, "doesn't handle objects of type", sess.add, d1) class TypedAssociationTable(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): class MySpecialType(sa.types.TypeDecorator): impl = String def process_bind_param(self, value, dialect): @@ -1033,7 +1051,8 @@ class TypedAssociationTable(_base.MappedTest): class ViewOnlyOverlappingNames(_base.MappedTest): """'viewonly' mappings with overlapping PK column names.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("t1", metadata, Column('id', Integer, primary_key=True), Column('data', String(40))) @@ -1092,7 +1111,8 @@ class ViewOnlyOverlappingNames(_base.MappedTest): class ViewOnlyUniqueNames(_base.MappedTest): """'viewonly' mappings with unique PK column names.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("t1", metadata, Column('t1id', Integer, primary_key=True), Column('data', String(40))) @@ -1182,7 +1202,8 @@ class ViewOnlyLocalRemoteM2M(testing.TestBase): class ViewOnlyNonEquijoin(_base.MappedTest): """'viewonly' mappings based on non-equijoins.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True)) Table('bars', metadata, @@ -1223,7 +1244,8 @@ class ViewOnlyNonEquijoin(_base.MappedTest): class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): """'viewonly' mappings that contain the same 'remote' column twice""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True), Column('bid1', Integer,ForeignKey('bars.id')), @@ -1270,7 +1292,8 @@ class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): class ViewOnlyRepeatedLocalColumn(_base.MappedTest): """'viewonly' mappings that contain the same 'local' column twice""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) @@ -1317,7 +1340,8 @@ class ViewOnlyRepeatedLocalColumn(_base.MappedTest): class ViewOnlyComplexJoin(_base.MappedTest): """'viewonly' mappings with a complex join condition.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) @@ -1332,7 +1356,8 @@ class ViewOnlyComplexJoin(_base.MappedTest): Column('t2id', Integer, ForeignKey('t2.id')), Column('t3id', Integer, ForeignKey('t3.id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_base.ComparableEntity): pass class T2(_base.ComparableEntity): @@ -1379,14 +1404,15 @@ class ViewOnlyComplexJoin(_base.MappedTest): 't1':relation(T1), 't3s':relation(T3, secondary=t2tot3)}) mapper(T3, t3) - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, "Specify remote_side argument", sa.orm.compile_mappers) class ExplicitLocalRemoteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', String(50), primary_key=True), Column('data', String(50))) @@ -1395,8 +1421,9 @@ class ExplicitLocalRemoteTest(_base.MappedTest): Column('data', String(50)), Column('t1id', String(50))) + @classmethod @testing.resolve_artifact_names - def setup_classes(self): + def setup_classes(cls): class T1(_base.ComparableEntity): pass class T2(_base.ComparableEntity): @@ -1508,7 +1535,7 @@ class ExplicitLocalRemoteTest(_base.MappedTest): foreign_keys=[t2.c.t1id], remote_side=[t2.c.t1id])}) mapper(T2, t2) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) @testing.resolve_artifact_names def test_escalation_2(self): @@ -1517,18 +1544,20 @@ class ExplicitLocalRemoteTest(_base.MappedTest): primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), _local_remote_pairs=[(t1.c.id, t2.c.t1id)])}) mapper(T2, t2) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) class InvalidRemoteSideTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), Column('t_id', Integer, ForeignKey('t1.id')) ) + @classmethod @testing.resolve_artifact_names - def setup_classes(self): + def setup_classes(cls): class T1(_base.ComparableEntity): pass @@ -1538,7 +1567,7 @@ class InvalidRemoteSideTest(_base.MappedTest): 't1s':relation(T1, backref='parent') }) - self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " + assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " "both of the same direction <symbol 'ONETOMANY>. Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) @@ -1548,7 +1577,7 @@ class InvalidRemoteSideTest(_base.MappedTest): 't1s':relation(T1, backref=backref('parent', remote_side=t1.c.id), remote_side=t1.c.id) }) - self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " + assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " "both of the same direction <symbol 'MANYTOONE>. Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) @@ -1560,7 +1589,7 @@ class InvalidRemoteSideTest(_base.MappedTest): }) # can't be sure of ordering here - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, "both of the same direction <symbol 'ONETOMANY>. Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) @@ -1572,14 +1601,15 @@ class InvalidRemoteSideTest(_base.MappedTest): }) # can't be sure of ordering here - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, "both of the same direction <symbol 'MANYTOONE>. Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) class InvalidRelationEscalationTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer)) @@ -1587,7 +1617,8 @@ class InvalidRelationEscalationTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('fid', Integer)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.Entity): pass class Bar(_base.Entity): @@ -1599,7 +1630,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): 'bars':relation(Bar)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child " "tables on relation", sa.orm.compile_mappers) @@ -1610,7 +1641,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): 'foos':relation(Foo)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child " "tables on relation", sa.orm.compile_mappers) @@ -1622,7 +1653,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id>bars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1635,7 +1666,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): foreign_keys=bars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not locate any equated, locally mapped column pairs " "for primaryjoin condition", sa.orm.compile_mappers) @@ -1648,7 +1679,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): foreign_keys=[foos.c.id, bars.c.fid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Do the columns in 'foreign_keys' represent only the " "'foreign' columns in this join condition ?", @@ -1665,7 +1696,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): )}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "could not determine any local/remote column pairs", sa.orm.compile_mappers) @@ -1681,7 +1712,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): )}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "could not determine any local/remote column pairs", sa.orm.compile_mappers) @@ -1694,7 +1725,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id>foos.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1707,7 +1738,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): foreign_keys=[foos.c.fid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not locate any equated, locally mapped column pairs " "for primaryjoin condition", sa.orm.compile_mappers) @@ -1720,7 +1751,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): viewonly=True)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1733,7 +1764,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): viewonly=True)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Specify the 'foreign_keys' argument to indicate which columns " "on the relation are foreign.", sa.orm.compile_mappers) @@ -1756,7 +1787,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id==bars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1767,7 +1798,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): 'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)}) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1779,7 +1810,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])}) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1787,7 +1818,8 @@ class InvalidRelationEscalationTest(_base.MappedTest): class InvalidRelationEscalationTestM2M(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True)) Table('foobars', metadata, @@ -1795,8 +1827,9 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): Table('bars', metadata, Column('id', Integer, primary_key=True)) + @classmethod @testing.resolve_artifact_names - def setup_classes(self): + def setup_classes(cls): class Foo(_base.Entity): pass class Bar(_base.Entity): @@ -1808,7 +1841,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): 'bars': relation(Bar, secondary=foobars)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child tables " "on relation", sa.orm.compile_mappers) @@ -1821,7 +1854,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): primaryjoin=foos.c.id > foobars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child tables " "on relation", @@ -1836,7 +1869,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): secondaryjoin=foobars.c.bid<=bars.c.id)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1851,7 +1884,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): foreign_keys=[foobars.c.fid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for secondaryjoin " "condition", sa.orm.compile_mappers) @@ -1866,11 +1899,9 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): foreign_keys=[foobars.c.fid, foobars.c.bid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for " "secondaryjoin condition", sa.orm.compile_mappers) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/scoping.py b/test/orm/test_scoping.py index bdfc5a9d5..2117e8dcc 100644 --- a/test/orm/scoping.py +++ b/test/orm/test_scoping.py @@ -1,10 +1,13 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing from sqlalchemy.orm import scoped_session -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, query -from testlib.testing import eq_ -from orm import _base +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, query +from sqlalchemy.test.testing import eq_ +from test.orm import _base class _ScopedTest(_base.MappedTest): @@ -15,18 +18,21 @@ class _ScopedTest(_base.MappedTest): _artifact_registries = ( _base.MappedTest._artifact_registries + ('scoping',)) - def setUpAll(self): - type(self).scoping = _base.adict() - _base.MappedTest.setUpAll(self) + @classmethod + def setup_class(cls): + cls.scoping = _base.adict() + super(_ScopedTest, cls).setup_class() - def tearDownAll(self): - self.scoping.clear() - _base.MappedTest.tearDownAll(self) + @classmethod + def teardown_class(cls): + cls.scoping.clear() + super(_ScopedTest, cls).teardown_class() class ScopedSessionTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30))) @@ -73,7 +79,8 @@ class ScopedSessionTest(_base.MappedTest): class ScopedMapperTest(_ScopedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30))) @@ -81,24 +88,27 @@ class ScopedMapperTest(_ScopedTest): Column('id', Integer, primary_key=True), Column('someid', None, ForeignKey('table1.id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class SomeObject(_base.ComparableEntity): pass class SomeOtherObject(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): Session = scoped_session(sa.orm.create_session) Session.mapper(SomeObject, table1, properties={ 'options':relation(SomeOtherObject) }) Session.mapper(SomeOtherObject, table2) - self.scoping['Session'] = Session + cls.scoping['Session'] = Session + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): s = SomeObject() s.id = 1 s.data = 'hello' @@ -145,7 +155,7 @@ class ScopedMapperTest(_ScopedTest): scope.mapper(B, table2) A(foo='bar') - self.assertRaises(TypeError, B, foo='bar') + assert_raises(TypeError, B, foo='bar') scope = scoped_session(sa.orm.sessionmaker()) @@ -158,7 +168,7 @@ class ScopedMapperTest(_ScopedTest): scope.mapper(C, table1) scope.mapper(D, table2) - self.assertRaises(TypeError, C, foo='bar') + assert_raises(TypeError, C, foo='bar') D(foo='bar') @testing.resolve_artifact_names @@ -170,7 +180,7 @@ class ScopedMapperTest(_ScopedTest): Session.mapper(ValidatedOtherObject, table2, validate=True) v1 = ValidatedOtherObject(someid=12) - self.assertRaises(sa.exc.ArgumentError, ValidatedOtherObject, + assert_raises(sa.exc.ArgumentError, ValidatedOtherObject, someid=12, bogus=345) @testing.resolve_artifact_names @@ -186,7 +196,8 @@ class ScopedMapperTest(_ScopedTest): class ScopedMapperTest2(_ScopedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -196,14 +207,16 @@ class ScopedMapperTest2(_ScopedTest): Column('someid', None, ForeignKey('table1.id')), Column('somedata', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class BaseClass(_base.ComparableEntity): pass class SubClass(BaseClass): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): Session = scoped_session(sa.orm.sessionmaker()) Session.mapper(BaseClass, table1, @@ -213,7 +226,7 @@ class ScopedMapperTest2(_ScopedTest): polymorphic_identity='sub', inherits=BaseClass) - self.scoping['Session'] = Session + cls.scoping['Session'] = Session @testing.resolve_artifact_names def test_inheritance(self): @@ -234,5 +247,3 @@ class ScopedMapperTest2(_ScopedTest): SubClass.query.all()) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/selectable.py b/test/orm/test_selectable.py index 74c41c852..0a2025360 100644 --- a/test/orm/selectable.py +++ b/test/orm/test_selectable.py @@ -1,22 +1,27 @@ """Generic mapping to Select statements""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, String, Integer, select -from testlib.sa.orm import mapper, create_session -from testlib.testing import eq_ -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import String, Integer, select +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base # TODO: more tests mapping to selects class SelectableNoFromsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('common', metadata, Column('id', Integer, primary_key=True), Column('data', Integer), Column('extra', String(45))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Subset(_base.ComparableEntity): pass @@ -24,7 +29,7 @@ class SelectableNoFromsTest(_base.MappedTest): def test_no_tables(self): selectable = select(["x", "y", "z"]) - self.assertRaisesMessage(sa.exc.InvalidRequestError, + assert_raises_message(sa.exc.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable) @@ -48,5 +53,3 @@ class SelectableNoFromsTest(_base.MappedTest): Subset(data=1)) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/session.py b/test/orm/test_session.py index 6cbd62a50..3020d66e9 100644 --- a/test/orm/session.py +++ b/test/orm/test_session.py @@ -1,14 +1,17 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import gc import inspect import pickle from sqlalchemy.orm import create_session, sessionmaker, attributes -from testlib import engines, sa, testing, config -from testlib.sa import Table, Column, Integer, String, Sequence -from testlib.sa.orm import mapper, relation, backref, eagerload -from testlib.testing import eq_ -from engine import _base as engine_base -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import engines, testing, config +from sqlalchemy import Integer, String, Sequence +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, eagerload +from sqlalchemy.test.testing import eq_ +from test.engine import _base as engine_base +from test.orm import _base, _fixtures class SessionTest(_fixtures.FixtureTest): @@ -508,7 +511,7 @@ class SessionTest(_fixtures.FixtureTest): sess.commit() - self.assertEquals(set(sess.query(User).all()), set([u2])) + eq_(set(sess.query(User).all()), set([u2])) sess.begin() sess.begin_nested() @@ -518,7 +521,7 @@ class SessionTest(_fixtures.FixtureTest): sess.commit() # commit the nested transaction sess.rollback() - self.assertEquals(set(sess.query(User).all()), set([u2])) + eq_(set(sess.query(User).all()), set([u2])) sess.close() @@ -541,7 +544,7 @@ class SessionTest(_fixtures.FixtureTest): sess.close() - self.assertEquals(len(sess.query(User).all()), 1) + eq_(len(sess.query(User).all()), 1) t1 = sess.begin() t2 = sess.begin_nested() @@ -572,7 +575,7 @@ class SessionTest(_fixtures.FixtureTest): sess.close() - self.assertEquals(len(sess.query(User).all()), 1) + eq_(len(sess.query(User).all()), 1) @testing.resolve_artifact_names def test_error_on_using_inactive_session(self): @@ -587,7 +590,7 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() sess.rollback() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True) + assert_raises_message(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True) sess.close() @testing.resolve_artifact_names @@ -612,7 +615,7 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c - self.assertRaisesMessage(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect()) + assert_raises_message(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect()) transaction.rollback() assert len(sess.query(User).all()) == 0 @@ -667,8 +670,8 @@ class SessionTest(_fixtures.FixtureTest): user = User(name='u1') - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.update, user) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.delete, user) + assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.update, user) + assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.delete, user) s.add(user) s.flush() @@ -694,13 +697,13 @@ class SessionTest(_fixtures.FixtureTest): assert user in s assert user not in s.dirty - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already persistent", s.save, user) + assert_raises_message(sa.exc.InvalidRequestError, "is already persistent", s.save, user) s2 = create_session() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user) + assert_raises_message(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user) u2 = s2.query(User).get(user.id) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2) + assert_raises_message(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2) s.expire(user) s.expunge(user) @@ -1029,7 +1032,7 @@ class SessionTest(_fixtures.FixtureTest): u = User(name='u1') sess.add(u) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), User(name='u1') @@ -1037,7 +1040,7 @@ class SessionTest(_fixtures.FixtureTest): ) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), User(name='u1') @@ -1046,7 +1049,7 @@ class SessionTest(_fixtures.FixtureTest): u.name='u2' sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), User(name='another u2'), @@ -1056,7 +1059,7 @@ class SessionTest(_fixtures.FixtureTest): sess.delete(u) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), ] @@ -1075,7 +1078,7 @@ class SessionTest(_fixtures.FixtureTest): u = User(name='u1') sess.add(u) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='u1') ] @@ -1084,7 +1087,7 @@ class SessionTest(_fixtures.FixtureTest): sess.add(User(name='u2')) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='u1 modified'), User(name='u2') @@ -1102,7 +1105,7 @@ class SessionTest(_fixtures.FixtureTest): sess = create_session(extension=MyExt()) sess.add(User(name='foo')) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush) + assert_raises_message(sa.exc.InvalidRequestError, "already flushing", sess.flush) @testing.resolve_artifact_names def test_pickled_update(self): @@ -1113,7 +1116,7 @@ class SessionTest(_fixtures.FixtureTest): u1 = User(name='u1') sess1.add(u1) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1) + assert_raises_message(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1) u2 = pickle.loads(pickle.dumps(u1)) @@ -1139,7 +1142,7 @@ class SessionTest(_fixtures.FixtureTest): assert u2 is not None and u2 is not u1 assert u2 in sess - self.assertRaises(Exception, lambda: sess.add(u1)) + assert_raises(Exception, lambda: sess.add(u1)) sess.expunge(u2) assert u2 not in sess @@ -1181,24 +1184,26 @@ class DisposedStates(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)) ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): global T class T(object): def __init__(self, data): self.data = data mapper(T, t1) - def tearDown(self): + def teardown(self): from sqlalchemy.orm.session import _sessions _sessions.clear() - super(DisposedStates, self).tearDown() + super(DisposedStates, self).teardown() def _set_imap_in_disposal(self, sess, *objs): """remove selected objects from the given session, as though they @@ -1291,7 +1296,7 @@ class SessionInterface(testing.TestBase): def x_raises_(obj, method, *args, **kw): watchdog.add(method) callable_ = getattr(obj, method) - self.assertRaises(sa.orm.exc.UnmappedInstanceError, + assert_raises(sa.orm.exc.UnmappedInstanceError, callable_, *args, **kw) def raises_(method, *args, **kw): @@ -1343,7 +1348,7 @@ class SessionInterface(testing.TestBase): def raises_(method, *args, **kw): watchdog.add(method) callable_ = getattr(create_session(), method) - self.assertRaises(sa.orm.exc.UnmappedClassError, + assert_raises(sa.orm.exc.UnmappedClassError, callable_, *args, **kw) raises_('connection', mapper=user_arg) @@ -1395,32 +1400,27 @@ class SessionInterface(testing.TestBase): class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): - def create_engine(self): + @classmethod + def create_engine(cls): return engines.testing_engine(options=dict(strategy='threadlocal')) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(20)), test_needs_acid=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users) - def setUpAll(self): - engine_base.AltEngineTest.setUpAll(self) - _base.MappedTest.setUpAll(self) - - - def tearDownAll(self): - _base.MappedTest.tearDownAll(self) - engine_base.AltEngineTest.tearDownAll(self) - @testing.exclude('mysql', '<', (5, 0, 3), 'FIXME: unknown') @testing.resolve_artifact_names def test_session_nesting(self): @@ -1432,5 +1432,3 @@ class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): self.engine.commit() -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/transaction.py b/test/orm/test_transaction.py index 0fcd55df3..5aa541cda 100644 --- a/test/orm/transaction.py +++ b/test/orm/test_transaction.py @@ -1,13 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import attributes from sqlalchemy import exc as sa_exc from sqlalchemy.orm import * -from testlib import testing -from orm import _base -from orm._fixtures import FixtureTest, User, Address, users, addresses +from sqlalchemy.test import testing +from test.orm import _base +from test.orm._fixtures import FixtureTest, User, Address, users, addresses import gc @@ -16,7 +16,8 @@ class TransactionTest(FixtureTest): run_inserts = None session = sessionmaker() - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', cascade="all, delete-orphan"), @@ -32,7 +33,7 @@ class FixtureDataTest(TransactionTest): u1 = sess.query(User).get(7) u1.name = 'ed' sess.rollback() - self.assertEquals(u1.name, 'jack') + eq_(u1.name, 'jack') def test_commit_persistent(self): sess = self.session() @@ -40,7 +41,7 @@ class FixtureDataTest(TransactionTest): u1.name = 'ed' sess.flush() sess.commit() - self.assertEquals(u1.name, 'ed') + eq_(u1.name, 'ed') def test_concurrent_commit_persistent(self): s1 = self.session() @@ -157,7 +158,7 @@ class AutoExpireTest(TransactionTest): u1.addresses.remove(a1) s.flush() - self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), []) + eq_(s.query(Address).filter(Address.email_address=='foo').all(), []) s.rollback() assert a1 not in s.deleted assert u1.addresses == [a1] @@ -168,7 +169,7 @@ class AutoExpireTest(TransactionTest): sess.add(u1) sess.flush() sess.commit() - self.assertEquals(u1.name, 'newuser') + eq_(u1.name, 'newuser') def test_concurrent_commit_pending(self): @@ -212,8 +213,8 @@ class RollbackRecoverTest(TransactionTest): u1.name = 'edward' a1.email_address = 'foober' s.add(u2) - self.assertRaises(sa_exc.FlushError, s.commit) - self.assertRaises(sa_exc.InvalidRequestError, s.commit) + assert_raises(sa_exc.FlushError, s.commit) + assert_raises(sa_exc.InvalidRequestError, s.commit) s.rollback() assert u2 not in s assert a2 not in s @@ -224,7 +225,7 @@ class RollbackRecoverTest(TransactionTest): u1.name = 'edward' a1.email_address = 'foober' s.commit() - self.assertEquals( + eq_( s.query(User).all(), [User(id=1, name='edward', addresses=[Address(email_address='foober')])] ) @@ -244,8 +245,8 @@ class RollbackRecoverTest(TransactionTest): a1.email_address = 'foober' s.begin_nested() s.add(u2) - self.assertRaises(sa_exc.FlushError, s.commit) - self.assertRaises(sa_exc.InvalidRequestError, s.commit) + assert_raises(sa_exc.FlushError, s.commit) + assert_raises(sa_exc.InvalidRequestError, s.commit) s.rollback() assert u2 not in s assert a2 not in s @@ -271,15 +272,15 @@ class SavepointTest(TransactionTest): u1.name = 'edward' u2.name = 'jackward' s.add_all([u3, u4]) - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) s.rollback() assert u1.name == 'ed' assert u2.name == 'jack' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) s.commit() assert u1.name == 'ed' assert u2.name == 'jack' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) @testing.requires.savepoints def test_savepoint_delete(self): @@ -287,11 +288,11 @@ class SavepointTest(TransactionTest): u1 = User(name='ed') s.add(u1) s.commit() - self.assertEquals(s.query(User).filter_by(name='ed').count(), 1) + eq_(s.query(User).filter_by(name='ed').count(), 1) s.begin_nested() s.delete(u1) s.commit() - self.assertEquals(s.query(User).filter_by(name='ed').count(), 0) + eq_(s.query(User).filter_by(name='ed').count(), 0) s.commit() @testing.requires.savepoints @@ -307,16 +308,16 @@ class SavepointTest(TransactionTest): u1.name = 'edward' u2.name = 'jackward' s.add_all([u3, u4]) - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) s.commit() def go(): assert u1.name == 'edward' assert u2.name == 'jackward' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) self.assert_sql_count(testing.db, go, 1) s.commit() - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) @testing.requires.savepoints def test_savepoint_rollback_collections(self): @@ -330,20 +331,20 @@ class SavepointTest(TransactionTest): s.begin_nested() u2 = User(name='jack', addresses=[Address(email_address='bat')]) s.add(u2) - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) ] ) s.rollback() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), ] ) s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), ] @@ -361,21 +362,21 @@ class SavepointTest(TransactionTest): s.begin_nested() u2 = User(name='jack', addresses=[Address(email_address='bat')]) s.add(u2) - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) ] ) s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) ] ) s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) @@ -476,7 +477,7 @@ class AccountingFlagsTest(TransactionTest): class AutoCommitTest(TransactionTest): def test_begin_nested_requires_trans(self): sess = create_session(autocommit=True) - self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested) + assert_raises(sa_exc.InvalidRequestError, sess.begin_nested) def test_begin_preflush(self): sess = create_session(autocommit=True) @@ -495,5 +496,3 @@ class AutoCommitTest(TransactionTest): -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/unitofwork.py b/test/orm/test_unitofwork.py index c5e3afd01..f95346902 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -1,19 +1,21 @@ # coding: utf-8 """Tests unitofwork operations.""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime import operator from sqlalchemy.orm import mapper as orm_mapper -from testlib import engines, sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, literal_column -from testlib.sa.orm import mapper, relation, create_session, column_property -from testlib.testing import eq_, ne_ -from orm import _base, _fixtures -from engine import _base as engine_base -import pickleable -from testlib.assertsql import AllOf, CompiledSQL +import sqlalchemy as sa +from sqlalchemy.test import engines, testing, pickleable +from sqlalchemy import Integer, String, ForeignKey, literal_column +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, column_property +from sqlalchemy.test.testing import eq_, ne_ +from test.orm import _base, _fixtures +from test.engine import _base as engine_base +from sqlalchemy.test.assertsql import AllOf, CompiledSQL import gc class UnitOfWorkTest(object): @@ -22,7 +24,8 @@ class UnitOfWorkTest(object): class HistoryTest(_fixtures.FixtureTest): run_inserts = None - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Address(_base.ComparableEntity): @@ -51,14 +54,16 @@ class HistoryTest(_fixtures.FixtureTest): class VersioningTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('version_table', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('version_id', Integer, nullable=False), Column('value', String(40), nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.ComparableEntity): pass @@ -86,7 +91,7 @@ class VersioningTest(_base.MappedTest): # Only dialects with a sane rowcount can detect the # ConcurrentModificationError if testing.db.dialect.supports_sane_rowcount: - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit) + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit) s1.rollback() else: s1.commit() @@ -102,7 +107,7 @@ class VersioningTest(_base.MappedTest): s1.delete(f2) if testing.db.dialect.supports_sane_multi_rowcount: - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit) + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit) else: s1.commit() @@ -124,7 +129,7 @@ class VersioningTest(_base.MappedTest): s2.commit() # load, version is wrong - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id) + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id) # reload it s1.query(Foo).populate_existing().get(f1s1.id) @@ -153,7 +158,8 @@ class VersioningTest(_base.MappedTest): class UnicodeTest(_base.MappedTest): __requires__ = ('unicode_connections',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('uni_t1', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -163,7 +169,8 @@ class UnicodeTest(_base.MappedTest): test_needs_autoincrement=True), Column('txt', sa.Unicode(50), ForeignKey('uni_t1'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Test(_base.BasicEntity): pass class Test2(_base.BasicEntity): @@ -205,10 +212,12 @@ class UnicodeTest(_base.MappedTest): class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): __requires__ = ('unicode_connections', 'unicode_ddl',) - def create_engine(self): + @classmethod + def create_engine(cls): return engines.utf8_engine() - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): t1 = Table('unitable1', metadata, Column(u'méil', Integer, primary_key=True, key='a', test_needs_autoincrement=True), Column(u'\u6e2c\u8a66', Integer, key='b'), @@ -223,16 +232,16 @@ class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): test_needs_fk=True, test_needs_autoincrement=True) - self.tables['t1'] = t1 - self.tables['t2'] = t2 + cls.tables['t1'] = t1 + cls.tables['t2'] = t2 - def setUpAll(self): - engine_base.AltEngineTest.setUpAll(self) - _base.MappedTest.setUpAll(self) + @classmethod + def setup_class(cls): + super(UnicodeSchemaTest, cls).setup_class() - def tearDownAll(self): - _base.MappedTest.tearDownAll(self) - engine_base.AltEngineTest.tearDownAll(self) + @classmethod + def teardown_class(cls): + super(UnicodeSchemaTest, cls).teardown_class() @testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.') @testing.resolve_artifact_names @@ -298,19 +307,22 @@ class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): class MutableTypesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mutable_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', sa.PickleType), Column('val', sa.Unicode(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Foo, mutable_t) @testing.resolve_artifact_names @@ -433,20 +445,23 @@ class MutableTypesTest(_base.MappedTest): self.sql_count_(0, session.commit) -class PickledDicts(_base.MappedTest): +class PickledDictsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mutable_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', sa.PickleType(comparator=operator.eq))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Foo, mutable_t) @testing.resolve_artifact_names @@ -519,7 +534,8 @@ class PickledDicts(_base.MappedTest): class PKTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('multipk1', metadata, Column('multi_id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -537,7 +553,8 @@ class PKTest(_base.MappedTest): Column('date_assigned', sa.Date, key='assigned', primary_key=True), Column('data', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Entry(_base.BasicEntity): pass @@ -587,7 +604,8 @@ class PKTest(_base.MappedTest): class ForeignPKTest(_base.MappedTest): """Detection of the relationship direction on PK joins.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("people", metadata, Column('person', String(10), primary_key=True), Column('firstname', String(10)), @@ -598,7 +616,8 @@ class ForeignPKTest(_base.MappedTest): primary_key=True), Column('site', String(10))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Person(_base.BasicEntity): pass class PersonSite(_base.BasicEntity): @@ -629,19 +648,22 @@ class ForeignPKTest(_base.MappedTest): class ClauseAttributesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('counter', Integer, default=1)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users_t) @testing.resolve_artifact_names @@ -697,7 +719,8 @@ class ClauseAttributesTest(_base.MappedTest): class PassiveDeletesTest(_base.MappedTest): __requires__ = ('foreign_keys',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mytable', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), @@ -712,7 +735,8 @@ class PassiveDeletesTest(_base.MappedTest): ondelete="CASCADE"), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class MyClass(_base.BasicEntity): pass class MyOtherClass(_base.BasicEntity): @@ -773,7 +797,8 @@ class PassiveDeletesTest(_base.MappedTest): class ExtraPassiveDeletesTest(_base.MappedTest): __requires__ = ('foreign_keys',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mytable', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), @@ -788,7 +813,8 @@ class ExtraPassiveDeletesTest(_base.MappedTest): ['mytable.id']), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class MyClass(_base.BasicEntity): pass class MyOtherClass(_base.BasicEntity): @@ -829,7 +855,7 @@ class ExtraPassiveDeletesTest(_base.MappedTest): assert myothertable.count().scalar() == 4 mc = session.query(MyClass).get(mc.id) session.delete(mc) - self.assertRaises(sa.exc.DBAPIError, session.flush) + assert_raises(sa.exc.DBAPIError, session.flush) @testing.resolve_artifact_names def test_extra_passive_2(self): @@ -851,7 +877,7 @@ class ExtraPassiveDeletesTest(_base.MappedTest): mc = session.query(MyClass).get(mc.id) session.delete(mc) mc.children[0].data = 'some new data' - self.assertRaises(sa.exc.DBAPIError, session.flush) + assert_raises(sa.exc.DBAPIError, session.flush) class DefaultTest(_base.MappedTest): @@ -864,7 +890,8 @@ class DefaultTest(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql') if use_string_defaults: @@ -876,8 +903,8 @@ class DefaultTest(_base.MappedTest): hohoval = 9 althohoval = 15 - self.other_artifacts['hohoval'] = hohoval - self.other_artifacts['althohoval'] = althohoval + cls.other_artifacts['hohoval'] = hohoval + cls.other_artifacts['althohoval'] = althohoval dt = Table('default_t', metadata, Column('id', Integer, primary_key=True, @@ -906,7 +933,8 @@ class DefaultTest(_base.MappedTest): st.append_column( Column('hoho', hohotype, ForeignKey('default_t.hoho'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Hoho(_base.ComparableEntity): pass class Secondary(_base.ComparableEntity): @@ -1037,7 +1065,8 @@ class DefaultTest(_base.MappedTest): Secondary(data='s2')])) class ColumnPropertyTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('data', metadata, Column('id', Integer, primary_key=True), Column('a', String(50)), @@ -1049,7 +1078,8 @@ class ColumnPropertyTest(_base.MappedTest): Column('c', String(50)), ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): class Data(_base.BasicEntity): pass @@ -1079,7 +1109,7 @@ class ColumnPropertyTest(_base.MappedTest): sd1 = SubData(a="hello", b="there", c="hi") sess.add(sd1) sess.flush() - self.assertEquals(sd1.aplusb, "hello there") + eq_(sd1.aplusb, "hello there") @testing.resolve_artifact_names def _test(self): @@ -1089,16 +1119,16 @@ class ColumnPropertyTest(_base.MappedTest): sess.add(d1) sess.flush() - self.assertEquals(d1.aplusb, "hello there") + eq_(d1.aplusb, "hello there") d1.b = "bye" sess.flush() - self.assertEquals(d1.aplusb, "hello bye") + eq_(d1.aplusb, "hello bye") d1.b = 'foobar' d1.aplusb = 'im setting this explicitly' sess.flush() - self.assertEquals(d1.aplusb, "im setting this explicitly") + eq_(d1.aplusb, "im setting this explicitly") class OneToManyTest(_fixtures.FixtureTest): run_inserts = None @@ -1596,7 +1626,7 @@ class SaveTest(_fixtures.FixtureTest): u1 = User(name='user1') u2 = User(name='user2') session.add_all((u1, u2)) - self.assertRaises(AssertionError, session.flush) + assert_raises(AssertionError, session.flush) class ManyToOneTest(_fixtures.FixtureTest): @@ -2029,7 +2059,8 @@ class SaveTest2(_fixtures.FixtureTest): ) class SaveTest3(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('items', metadata, Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -2045,7 +2076,8 @@ class SaveTest3(_base.MappedTest): Column('keyword_id', Integer, ForeignKey("keywords")), Column('foo', sa.Boolean, default=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Keyword(_base.BasicEntity): pass class Item(_base.BasicEntity): @@ -2076,7 +2108,8 @@ class SaveTest3(_base.MappedTest): assert assoc.count().scalar() == 0 class BooleanColTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), @@ -2118,7 +2151,8 @@ class BooleanColTest(_base.MappedTest): class RowSwitchTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): # parent Table('t5', metadata, Column('id', Integer, primary_key=True), @@ -2140,7 +2174,8 @@ class RowSwitchTest(_base.MappedTest): Column('t5id', Integer, ForeignKey('t5.id'),nullable=False), Column('t7id', Integer, ForeignKey('t7.id'),nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T5(_base.ComparableEntity): pass @@ -2240,7 +2275,8 @@ class RowSwitchTest(_base.MappedTest): assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)] class InheritingRowSwitchTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('parent', metadata, Column('id', Integer, primary_key=True), Column('pdata', String(30)) @@ -2251,7 +2287,8 @@ class InheritingRowSwitchTest(_base.MappedTest): Column('cdata', String(30)) ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class P(_base.ComparableEntity): pass @@ -2290,7 +2327,8 @@ class TransactionTest(_base.MappedTest): # be specified. it'll raise immediately post-INSERT, instead of at # COMMIT. either way, this test should pass. - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): t1 = Table('t1', metadata, Column('id', Integer, primary_key=True)) @@ -2299,15 +2337,17 @@ class TransactionTest(_base.MappedTest): Column('t1_id', Integer, ForeignKey('t1.id', deferrable=True, initially='deferred') )) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_base.ComparableEntity): pass class T2(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): orm_mapper(T1, t1) orm_mapper(T2, t2) @@ -2332,5 +2372,3 @@ class TransactionTest(_base.MappedTest): if testing.against('postgres'): t1.bind.engine.dispose() -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/utils.py b/test/orm/test_utils.py index 813121a44..06533a243 100644 --- a/test/orm/utils.py +++ b/test/orm/test_utils.py @@ -1,4 +1,4 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy.orm import interfaces, util from sqlalchemy import Column from sqlalchemy import Integer @@ -8,10 +8,10 @@ from sqlalchemy.orm import aliased from sqlalchemy.orm import mapper, create_session -from testlib import TestBase, testing +from sqlalchemy.test import TestBase, testing -from orm import _fixtures -from testlib.testing import eq_ +from test.orm import _fixtures +from sqlalchemy.test.testing import eq_ class ExtensionCarrierTest(TestBase): @@ -22,7 +22,7 @@ class ExtensionCarrierTest(TestBase): assert carrier.translate_row() is interfaces.EXT_CONTINUE assert 'translate_row' not in carrier - self.assertRaises(AttributeError, lambda: carrier.snickysnack) + assert_raises(AttributeError, lambda: carrier.snickysnack) class Partial(object): def __init__(self, marker): @@ -74,7 +74,7 @@ class AliasedClassTest(TestBase): table = self.point_map(Point) alias = aliased(Point) - self.assertRaises(TypeError, alias) + assert_raises(TypeError, alias) def test_instancemethods(self): class Point(object): @@ -236,6 +236,4 @@ class IdentityKeyTest(_fixtures.FixtureTest): key = util.identity_key(User, row=row) eq_(key, (User, (1,))) -if __name__ == '__main__': - testenv.main() |
