From 45cec095b4904ba71425d2fe18c143982dd08f43 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 10 Jun 2009 21:18:24 +0000 Subject: - unit tests have been migrated from unittest to nose. See README.unittests for information on how to run the tests. [ticket:970] --- test/orm/_base.py | 108 +- test/orm/_fixtures.py | 43 +- test/orm/alltests.py | 60 - test/orm/association.py | 145 -- test/orm/assorted_eager.py | 878 ------- test/orm/attributes.py | 1331 ----------- test/orm/bind.py | 55 - test/orm/cascade.py | 1292 ----------- test/orm/collection.py | 1834 --------------- test/orm/compile.py | 186 -- test/orm/cycles.py | 862 ------- test/orm/defaults.py | 129 -- test/orm/deprecations.py | 483 ---- test/orm/dynamic.py | 559 ----- test/orm/eager_relations.py | 1595 ------------- test/orm/evaluator.py | 94 - test/orm/expire.py | 912 -------- test/orm/extendedattr.py | 321 --- test/orm/generative.py | 279 --- test/orm/inheritance/abc_inheritance.py | 171 -- test/orm/inheritance/abc_polymorphic.py | 90 - test/orm/inheritance/alltests.py | 31 - test/orm/inheritance/basic.py | 1015 --------- test/orm/inheritance/concrete.py | 515 ----- test/orm/inheritance/magazine.py | 220 -- test/orm/inheritance/manytomany.py | 248 -- test/orm/inheritance/poly_linked_list.py | 199 -- test/orm/inheritance/polymorph.py | 304 --- test/orm/inheritance/polymorph2.py | 1107 --------- test/orm/inheritance/productspec.py | 320 --- test/orm/inheritance/query.py | 1105 --------- test/orm/inheritance/selects.py | 53 - test/orm/inheritance/single.py | 396 ---- test/orm/inheritance/test_abc_inheritance.py | 170 ++ test/orm/inheritance/test_abc_polymorphic.py | 88 + test/orm/inheritance/test_basic.py | 1027 +++++++++ test/orm/inheritance/test_concrete.py | 519 +++++ test/orm/inheritance/test_magazine.py | 219 ++ test/orm/inheritance/test_manytomany.py | 249 ++ test/orm/inheritance/test_poly_linked_list.py | 197 ++ test/orm/inheritance/test_polymorph.py | 304 +++ 
test/orm/inheritance/test_polymorph2.py | 1121 +++++++++ test/orm/inheritance/test_productspec.py | 318 +++ test/orm/inheritance/test_query.py | 1113 +++++++++ test/orm/inheritance/test_selects.py | 51 + test/orm/inheritance/test_single.py | 400 ++++ test/orm/instrumentation.py | 765 ------- test/orm/lazy_relations.py | 416 ---- test/orm/lazytest1.py | 90 - test/orm/manytomany.py | 324 --- test/orm/mapper.py | 2467 -------------------- test/orm/merge.py | 736 ------ test/orm/naturalpks.py | 475 ---- test/orm/onetoone.py | 74 - test/orm/pickled.py | 190 -- test/orm/query.py | 3012 ------------------------ test/orm/relationships.py | 1876 --------------- test/orm/scoping.py | 238 -- test/orm/selectable.py | 52 - test/orm/session.py | 1436 ------------ test/orm/sharding/alltests.py | 18 - test/orm/sharding/shard.py | 163 -- test/orm/sharding/test_shard.py | 164 ++ test/orm/test_association.py | 147 ++ test/orm/test_assorted_eager.py | 902 ++++++++ test/orm/test_attributes.py | 1328 +++++++++++ test/orm/test_bind.py | 59 + test/orm/test_cascade.py | 1315 +++++++++++ test/orm/test_collection.py | 1839 +++++++++++++++ test/orm/test_compile.py | 183 ++ test/orm/test_cycles.py | 885 ++++++++ test/orm/test_defaults.py | 133 ++ test/orm/test_deprecations.py | 486 ++++ test/orm/test_dynamic.py | 561 +++++ test/orm/test_eager_relations.py | 1608 +++++++++++++ test/orm/test_evaluator.py | 97 + test/orm/test_expire.py | 916 ++++++++ test/orm/test_extendedattr.py | 322 +++ test/orm/test_generative.py | 290 +++ test/orm/test_instrumentation.py | 766 +++++++ test/orm/test_lazy_relations.py | 419 ++++ test/orm/test_lazytest1.py | 92 + test/orm/test_manytomany.py | 330 +++ test/orm/test_mapper.py | 2475 ++++++++++++++++++++ test/orm/test_merge.py | 735 ++++++ test/orm/test_naturalpks.py | 482 ++++ test/orm/test_onetoone.py | 76 + test/orm/test_pickled.py | 194 ++ test/orm/test_query.py | 3024 +++++++++++++++++++++++++ test/orm/test_relationships.py | 1907 ++++++++++++++++ 
test/orm/test_scoping.py | 249 ++ test/orm/test_selectable.py | 55 + test/orm/test_session.py | 1434 ++++++++++++ test/orm/test_transaction.py | 498 ++++ test/orm/test_unitofwork.py | 2374 +++++++++++++++++++ test/orm/test_utils.py | 239 ++ test/orm/transaction.py | 499 ---- test/orm/unitofwork.py | 2336 ------------------- test/orm/utils.py | 241 -- 99 files changed, 32433 insertions(+), 32275 deletions(-) delete mode 100644 test/orm/alltests.py delete mode 100644 test/orm/association.py delete mode 100644 test/orm/assorted_eager.py delete mode 100644 test/orm/attributes.py delete mode 100644 test/orm/bind.py delete mode 100644 test/orm/cascade.py delete mode 100644 test/orm/collection.py delete mode 100644 test/orm/compile.py delete mode 100644 test/orm/cycles.py delete mode 100644 test/orm/defaults.py delete mode 100644 test/orm/deprecations.py delete mode 100644 test/orm/dynamic.py delete mode 100644 test/orm/eager_relations.py delete mode 100644 test/orm/evaluator.py delete mode 100644 test/orm/expire.py delete mode 100644 test/orm/extendedattr.py delete mode 100644 test/orm/generative.py delete mode 100644 test/orm/inheritance/abc_inheritance.py delete mode 100644 test/orm/inheritance/abc_polymorphic.py delete mode 100644 test/orm/inheritance/alltests.py delete mode 100644 test/orm/inheritance/basic.py delete mode 100644 test/orm/inheritance/concrete.py delete mode 100644 test/orm/inheritance/magazine.py delete mode 100644 test/orm/inheritance/manytomany.py delete mode 100644 test/orm/inheritance/poly_linked_list.py delete mode 100644 test/orm/inheritance/polymorph.py delete mode 100644 test/orm/inheritance/polymorph2.py delete mode 100644 test/orm/inheritance/productspec.py delete mode 100644 test/orm/inheritance/query.py delete mode 100644 test/orm/inheritance/selects.py delete mode 100644 test/orm/inheritance/single.py create mode 100644 test/orm/inheritance/test_abc_inheritance.py create mode 100644 test/orm/inheritance/test_abc_polymorphic.py create mode 
100644 test/orm/inheritance/test_basic.py create mode 100644 test/orm/inheritance/test_concrete.py create mode 100644 test/orm/inheritance/test_magazine.py create mode 100644 test/orm/inheritance/test_manytomany.py create mode 100644 test/orm/inheritance/test_poly_linked_list.py create mode 100644 test/orm/inheritance/test_polymorph.py create mode 100644 test/orm/inheritance/test_polymorph2.py create mode 100644 test/orm/inheritance/test_productspec.py create mode 100644 test/orm/inheritance/test_query.py create mode 100644 test/orm/inheritance/test_selects.py create mode 100644 test/orm/inheritance/test_single.py delete mode 100644 test/orm/instrumentation.py delete mode 100644 test/orm/lazy_relations.py delete mode 100644 test/orm/lazytest1.py delete mode 100644 test/orm/manytomany.py delete mode 100644 test/orm/mapper.py delete mode 100644 test/orm/merge.py delete mode 100644 test/orm/naturalpks.py delete mode 100644 test/orm/onetoone.py delete mode 100644 test/orm/pickled.py delete mode 100644 test/orm/query.py delete mode 100644 test/orm/relationships.py delete mode 100644 test/orm/scoping.py delete mode 100644 test/orm/selectable.py delete mode 100644 test/orm/session.py delete mode 100644 test/orm/sharding/alltests.py delete mode 100644 test/orm/sharding/shard.py create mode 100644 test/orm/sharding/test_shard.py create mode 100644 test/orm/test_association.py create mode 100644 test/orm/test_assorted_eager.py create mode 100644 test/orm/test_attributes.py create mode 100644 test/orm/test_bind.py create mode 100644 test/orm/test_cascade.py create mode 100644 test/orm/test_collection.py create mode 100644 test/orm/test_compile.py create mode 100644 test/orm/test_cycles.py create mode 100644 test/orm/test_defaults.py create mode 100644 test/orm/test_deprecations.py create mode 100644 test/orm/test_dynamic.py create mode 100644 test/orm/test_eager_relations.py create mode 100644 test/orm/test_evaluator.py create mode 100644 test/orm/test_expire.py create mode 
100644 test/orm/test_extendedattr.py create mode 100644 test/orm/test_generative.py create mode 100644 test/orm/test_instrumentation.py create mode 100644 test/orm/test_lazy_relations.py create mode 100644 test/orm/test_lazytest1.py create mode 100644 test/orm/test_manytomany.py create mode 100644 test/orm/test_mapper.py create mode 100644 test/orm/test_merge.py create mode 100644 test/orm/test_naturalpks.py create mode 100644 test/orm/test_onetoone.py create mode 100644 test/orm/test_pickled.py create mode 100644 test/orm/test_query.py create mode 100644 test/orm/test_relationships.py create mode 100644 test/orm/test_scoping.py create mode 100644 test/orm/test_selectable.py create mode 100644 test/orm/test_session.py create mode 100644 test/orm/test_transaction.py create mode 100644 test/orm/test_unitofwork.py create mode 100644 test/orm/test_utils.py delete mode 100644 test/orm/transaction.py delete mode 100644 test/orm/unitofwork.py delete mode 100644 test/orm/utils.py (limited to 'test/orm') diff --git a/test/orm/_base.py b/test/orm/_base.py index 9e599a6f1..8d695e912 100644 --- a/test/orm/_base.py +++ b/test/orm/_base.py @@ -2,9 +2,10 @@ import gc import inspect import sys import types -from testlib import config, sa, testing -from testlib.testing import resolve_artifact_names, adict -from testlib.compat import _function_named +import sqlalchemy as sa +from sqlalchemy.test import config, testing +from sqlalchemy.test.testing import resolve_artifact_names, adict +from sqlalchemy.util import function_named _repr_stack = set() @@ -95,7 +96,8 @@ class ComparableEntity(BasicEntity): class ORMTest(testing.TestBase, testing.AssertsExecutionResults): __requires__ = ('subqueries',) - def tearDownAll(self): + @classmethod + def teardown_class(cls): sa.orm.session.Session.close_all() sa.orm.clear_mappers() # TODO: ensure mapper registry is empty @@ -124,18 +126,18 @@ class MappedTest(ORMTest): classes = None other_artifacts = None - def setUpAll(self): - if 
self.run_setup_classes == 'each': - assert self.run_setup_mappers != 'once' + @classmethod + def setup_class(cls): + if cls.run_setup_classes == 'each': + assert cls.run_setup_mappers != 'once' - assert self.run_deletes in (None, 'each') - if self.run_inserts == 'once': - assert self.run_deletes is None + assert cls.run_deletes in (None, 'each') + if cls.run_inserts == 'once': + assert cls.run_deletes is None - assert not hasattr(self, 'keep_mappers') - assert not hasattr(self, 'keep_data') + assert not hasattr(cls, 'keep_mappers') + assert not hasattr(cls, 'keep_data') - cls = self.__class__ if cls.tables is None: cls.tables = adict() if cls.classes is None: @@ -143,35 +145,32 @@ class MappedTest(ORMTest): if cls.other_artifacts is None: cls.other_artifacts = adict() - if self.metadata is None: - setattr(type(self), 'metadata', sa.MetaData()) + if cls.metadata is None: + setattr(cls, 'metadata', sa.MetaData()) - if self.metadata.bind is None: - self.metadata.bind = getattr(self, 'engine', config.db) + if cls.metadata.bind is None: + cls.metadata.bind = getattr(cls, 'engine', config.db) - if self.run_define_tables: - self.define_tables(self.metadata) - self.metadata.create_all() - self.tables.update(self.metadata.tables) + if cls.run_define_tables == 'once': + cls.define_tables(cls.metadata) + cls.metadata.create_all() + cls.tables.update(cls.metadata.tables) - if self.run_setup_classes: + if cls.run_setup_classes == 'once': baseline = subclasses(BasicEntity) - self.setup_classes() - self._register_new_class_artifacts(baseline) + cls.setup_classes() + cls._register_new_class_artifacts(baseline) - if self.run_setup_mappers: + if cls.run_setup_mappers == 'once': baseline = subclasses(BasicEntity) - self.setup_mappers() - self._register_new_class_artifacts(baseline) + cls.setup_mappers() + cls._register_new_class_artifacts(baseline) - if self.run_inserts: - self._load_fixtures() - self.insert_data() - - def setUp(self): - if self._sa_first_test: - return + if 
cls.run_inserts == 'once': + cls._load_fixtures() + cls.insert_data() + def setup(self): if self.run_define_tables == 'each': self.tables.clear() self.metadata.drop_all() @@ -195,7 +194,7 @@ class MappedTest(ORMTest): self._load_fixtures() self.insert_data() - def tearDown(self): + def teardown(self): sa.orm.session.Session.close_all() # some tests create mappers in the test bodies @@ -213,26 +212,32 @@ class MappedTest(ORMTest): print >> sys.stderr, "Error emptying table %s: %r" % ( table, ex) - def tearDownAll(self): - for cls in self.classes.values(): - self.unregister_class(cls) - ORMTest.tearDownAll(self) - self.metadata.drop_all() - self.metadata.bind = None + @classmethod + def teardown_class(cls): + for cl in cls.classes.values(): + cls.unregister_class(cl) + ORMTest.teardown_class() + cls.metadata.drop_all() + cls.metadata.bind = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): raise NotImplementedError() - def setup_classes(self): + @classmethod + def setup_classes(cls): pass - def setup_mappers(self): + @classmethod + def setup_mappers(cls): pass - def fixtures(self): + @classmethod + def fixtures(cls): return {} - def insert_data(self): + @classmethod + def insert_data(cls): pass def sql_count_(self, count, fn): @@ -260,15 +265,16 @@ class MappedTest(ORMTest): if name[0].isupper: delattr(cls, name) del cls.classes[name] - - def _load_fixtures(self): + + @classmethod + def _load_fixtures(cls): headers, rows = {}, {} - for table, data in self.fixtures().iteritems(): + for table, data in cls.fixtures().iteritems(): if isinstance(table, basestring): - table = self.tables[table] + table = cls.tables[table] headers[table] = data[0] rows[table] = data[1:] - for table in self.metadata.sorted_tables: + for table in cls.metadata.sorted_tables: if table not in headers: continue table.bind.execute( diff --git a/test/orm/_fixtures.py b/test/orm/_fixtures.py index f036b92b2..14709ec43 100644 --- a/test/orm/_fixtures.py +++ 
b/test/orm/_fixtures.py @@ -1,7 +1,9 @@ -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import attributes -from testlib.testing import fixture -from orm import _base +from sqlalchemy import MetaData, Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import attributes +from sqlalchemy.test.testing import fixture +from test.orm import _base __all__ = () @@ -227,34 +229,21 @@ class FixtureTest(_base.MappedTest): Address=Address, Dingaling=Dingaling) - def setUpAll(self): - assert not hasattr(self, 'refresh_data') - assert not hasattr(self, 'only_tables') - #refresh_data = False - #only_tables = False - - #if type(self) is not FixtureTest: - # setattr(type(self), 'classes', _base.adict(self.classes)) - - #if self.run_setup_classes: - # for cls in self.classes.values(): - # self.register_class(cls) - super(FixtureTest, self).setUpAll() - - #if not self.only_tables and self.keep_data: - # _registry.load() - - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): pass - def setup_classes(self): - for cls in self.fixture_classes.values(): - self.register_class(cls) + @classmethod + def setup_classes(cls): + for cl in cls.fixture_classes.values(): + cls.register_class(cl) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): pass - def insert_data(self): + @classmethod + def insert_data(cls): _load_fixtures() diff --git a/test/orm/alltests.py b/test/orm/alltests.py deleted file mode 100644 index 9458ca523..000000000 --- a/test/orm/alltests.py +++ /dev/null @@ -1,60 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -import inheritance.alltests as inheritance -import sharding.alltests as sharding - -def suite(): - modules_to_test = ( - 'orm.attributes', - 'orm.bind', - 'orm.extendedattr', - 'orm.instrumentation', - 'orm.query', - 
'orm.lazy_relations', - 'orm.eager_relations', - 'orm.mapper', - 'orm.expire', - 'orm.selectable', - 'orm.collection', - 'orm.generative', - 'orm.lazytest1', - 'orm.assorted_eager', - - 'orm.naturalpks', - 'orm.defaults', - 'orm.unitofwork', - 'orm.session', - 'orm.transaction', - 'orm.scoping', - 'orm.cascade', - 'orm.relationships', - 'orm.association', - 'orm.merge', - 'orm.pickled', - 'orm.utils', - - 'orm.cycles', - - 'orm.compile', - 'orm.manytomany', - 'orm.onetoone', - 'orm.dynamic', - - 'orm.evaluator', - - 'orm.deprecations', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - alltests.addTest(inheritance.suite()) - alltests.addTest(sharding.suite()) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/association.py b/test/orm/association.py deleted file mode 100644 index d9265ffb1..000000000 --- a/test/orm/association.py +++ /dev/null @@ -1,145 +0,0 @@ -import testenv; testenv.configure_for_tests() - -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base -from testlib.testing import eq_ - - -class AssociationTest(_base.MappedTest): - run_setup_classes = 'once' - run_setup_mappers = 'once' - - def define_tables(self, metadata): - Table('items', metadata, - Column('item_id', Integer, primary_key=True), - Column('name', String(40))) - Table('item_keywords', metadata, - Column('item_id', Integer, ForeignKey('items.item_id')), - Column('keyword_id', Integer, ForeignKey('keywords.keyword_id')), - Column('data', String(40))) - Table('keywords', metadata, - Column('keyword_id', Integer, primary_key=True), - Column('name', String(40))) - - def setup_classes(self): - class Item(_base.BasicEntity): - def __init__(self, name): - 
self.name = name - def __repr__(self): - return "Item id=%d name=%s keywordassoc=%r" % ( - self.item_id, self.name, self.keywords) - - class Keyword(_base.BasicEntity): - def __init__(self, name): - self.name = name - def __repr__(self): - return "Keyword id=%d name=%s" % (self.keyword_id, self.name) - - class KeywordAssociation(_base.BasicEntity): - def __init__(self, keyword, data): - self.keyword = keyword - self.data = data - def __repr__(self): - return "KeywordAssociation itemid=%d keyword=%r data=%s" % ( - self.item_id, self.keyword, self.data) - - @testing.resolve_artifact_names - def setup_mappers(self): - items, item_keywords, keywords = self.tables.get_all( - 'items', 'item_keywords', 'keywords') - - mapper(Keyword, keywords) - mapper(KeywordAssociation, item_keywords, properties={ - 'keyword':relation(Keyword, lazy=False)}, - primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], - order_by=[item_keywords.c.data]) - - mapper(Item, items, properties={ - 'keywords' : relation(KeywordAssociation, - cascade="all, delete-orphan") - }) - - @testing.resolve_artifact_names - def test_insert(self): - sess = create_session() - item1 = Item('item1') - item2 = Item('item2') - item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) - item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) - sess.add_all((item1, item2)) - sess.flush() - saved = repr([item1, item2]) - sess.expunge_all() - l = sess.query(Item).all() - loaded = repr(l) - eq_(saved, loaded) - - @testing.resolve_artifact_names - def test_replace(self): - sess = create_session() - item1 = Item('item1') - item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) - sess.add(item1) - sess.flush() - - red_keyword = item1.keywords[1].keyword - del item1.keywords[1] - 
item1.keywords.append(KeywordAssociation(red_keyword, 'new_red_assoc')) - sess.flush() - saved = repr([item1]) - sess.expunge_all() - l = sess.query(Item).all() - loaded = repr(l) - eq_(saved, loaded) - - @testing.resolve_artifact_names - def test_modify(self): - sess = create_session() - item1 = Item('item1') - item2 = Item('item2') - item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) - item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) - sess.add_all((item1, item2)) - sess.flush() - - red_keyword = item1.keywords[1].keyword - del item1.keywords[0] - del item1.keywords[0] - purple_keyword = Keyword('purple') - item1.keywords.append(KeywordAssociation(red_keyword, 'new_red_assoc')) - item2.keywords.append(KeywordAssociation(purple_keyword, 'purple_item2_assoc')) - item1.keywords.append(KeywordAssociation(purple_keyword, 'purple_item1_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('yellow'), 'yellow_assoc')) - - sess.flush() - saved = repr([item1, item2]) - sess.expunge_all() - l = sess.query(Item).all() - loaded = repr(l) - eq_(saved, loaded) - - @testing.resolve_artifact_names - def test_delete(self): - sess = create_session() - item1 = Item('item1') - item2 = Item('item2') - item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) - item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) - item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) - sess.add_all((item1, item2)) - sess.flush() - eq_(self.tables.item_keywords.count().scalar(), 3) - - sess.delete(item1) - sess.delete(item2) - sess.flush() - eq_(self.tables.item_keywords.count().scalar(), 0) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py deleted file mode 100644 index 8dc95fa5b..000000000 --- a/test/orm/assorted_eager.py +++ /dev/null @@ -1,878 +0,0 @@ 
-"""Exercises for eager loading. - -Derived from mailing list-reported problems and trac tickets. - -""" -import testenv; testenv.configure_for_tests() -import datetime - -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, backref, create_session -from testlib.testing import eq_ -from orm import _base - - -class EagerTest(_base.MappedTest): - run_deletes = None - run_inserts = "once" - - def define_tables(self, metadata): - # determine a literal value for "false" based on the dialect - # FIXME: this DefaultClause setup is bogus. - - dialect = testing.db.dialect - bp = sa.Boolean().dialect_impl(dialect).bind_processor(dialect) - - if bp: - false = str(bp(False)) - elif testing.against('maxdb'): - false = text('FALSE') - else: - false = str(False) - self.other_artifacts['false'] = false - - Table('owners', metadata , - Column('id', Integer, primary_key=True, nullable=False), - Column('data', String(30))) - - Table('categories', metadata, - Column('id', Integer, primary_key=True, nullable=False), - Column('name', String(20))) - - Table('tests', metadata , - Column('id', Integer, primary_key=True, nullable=False ), - Column('owner_id', Integer, ForeignKey('owners.id'), - nullable=False), - Column('category_id', Integer, ForeignKey('categories.id'), - nullable=False)) - - Table('options', metadata , - Column('test_id', Integer, ForeignKey('tests.id'), - primary_key=True, nullable=False), - Column('owner_id', Integer, ForeignKey('owners.id'), - primary_key=True, nullable=False), - Column('someoption', sa.Boolean, server_default=false, - nullable=False)) - - def setup_classes(self): - class Owner(_base.BasicEntity): - pass - - class Category(_base.BasicEntity): - pass - - class Test(_base.BasicEntity): - pass - - class Option(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Owner, owners) - - mapper(Category, categories) - - 
mapper(Option, options, properties=dict( - owner=relation(Owner), - test=relation(Test))) - - mapper(Test, tests, properties=dict( - owner=relation(Owner, backref='tests'), - category=relation(Category), - owner_option=relation(Option, - primaryjoin=sa.and_(tests.c.id == options.c.test_id, - tests.c.owner_id == options.c.owner_id), - foreign_keys=[options.c.test_id, options.c.owner_id], - uselist=False))) - - @testing.resolve_artifact_names - def insert_data(self): - session = create_session() - - o = Owner() - c = Category(name='Some Category') - session.add_all(( - Test(owner=o, category=c), - Test(owner=o, category=c, owner_option=Option(someoption=True)), - Test(owner=o, category=c, owner_option=Option()))) - - session.flush() - - @testing.resolve_artifact_names - def test_noorm(self): - """test the control case""" - # I want to display a list of tests owned by owner 1 - # if someoption is false or he hasn't specified it yet (null) - # but not if he set it to true (example someoption is for hiding) - - # desired output for owner 1 - # test_id, cat_name - # 1 'Some Category' - # 3 " - - # not orm style correct query - print "Obtaining correct results without orm" - result = sa.select( - [tests.c.id,categories.c.name], - sa.and_(tests.c.owner_id == 1, - sa.or_(options.c.someoption==None, - options.c.someoption==False)), - order_by=[tests.c.id], - from_obj=[tests.join(categories).outerjoin(options, sa.and_( - tests.c.id == options.c.test_id, - tests.c.owner_id == options.c.owner_id))] - ).execute().fetchall() - eq_(result, [(1, u'Some Category'), (3, u'Some Category')]) - - @testing.resolve_artifact_names - def test_withouteagerload(self): - s = create_session() - l = (s.query(Test). - select_from(tests.outerjoin(options, - sa.and_(tests.c.id == options.c.test_id, - tests.c.owner_id == - options.c.owner_id))). 
- filter(sa.and_(tests.c.owner_id==1, - sa.or_(options.c.someoption==None, - options.c.someoption==False)))) - - result = ["%d %s" % ( t.id,t.category.name ) for t in l] - eq_(result, [u'1 Some Category', u'3 Some Category']) - - @testing.resolve_artifact_names - def test_witheagerload(self): - """ - Test that an eagerload locates the correct "from" clause with which to - attach to, when presented with a query that already has a complicated - from clause. - - """ - s = create_session() - q=s.query(Test).options(sa.orm.eagerload('category')) - - l=(q.select_from(tests.outerjoin(options, - sa.and_(tests.c.id == - options.c.test_id, - tests.c.owner_id == - options.c.owner_id))). - filter(sa.and_(tests.c.owner_id == 1, - sa.or_(options.c.someoption==None, - options.c.someoption==False)))) - - result = ["%d %s" % ( t.id,t.category.name ) for t in l] - eq_(result, [u'1 Some Category', u'3 Some Category']) - - @testing.resolve_artifact_names - def test_dslish(self): - """test the same as witheagerload except using generative""" - s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) - l = q.filter ( - sa.and_(tests.c.owner_id == 1, - sa.or_(options.c.someoption == None, - options.c.someoption == False)) - ).outerjoin('owner_option') - - result = ["%d %s" % ( t.id,t.category.name ) for t in l] - eq_(result, [u'1 Some Category', u'3 Some Category']) - - @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on') - @testing.resolve_artifact_names - def test_without_outerjoin_literal(self): - s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) - l = (q.filter( - (tests.c.owner_id==1) & - ('options.someoption is null or options.someoption=%s' % false)). 
- join('owner_option')) - - result = ["%d %s" % ( t.id,t.category.name ) for t in l] - eq_(result, [u'3 Some Category']) - - @testing.resolve_artifact_names - def test_withoutouterjoin(self): - s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) - l = q.filter( - (tests.c.owner_id==1) & - ((options.c.someoption==None) | (options.c.someoption==False)) - ).join('owner_option') - - result = ["%d %s" % ( t.id,t.category.name ) for t in l] - eq_(result, [u'3 Some Category']) - - -class EagerTest2(_base.MappedTest): - def define_tables(self, metadata): - Table('left', metadata, - Column('id', Integer, ForeignKey('middle.id'), primary_key=True), - Column('data', String(50), primary_key=True)) - - Table('middle', metadata, - Column('id', Integer, primary_key = True), - Column('data', String(50))) - - Table('right', metadata, - Column('id', Integer, ForeignKey('middle.id'), primary_key=True), - Column('data', String(50), primary_key=True)) - - def setup_classes(self): - class Left(_base.BasicEntity): - def __init__(self, data): - self.data = data - - class Middle(_base.BasicEntity): - def __init__(self, data): - self.data = data - - class Right(_base.BasicEntity): - def __init__(self, data): - self.data = data - - @testing.resolve_artifact_names - def setup_mappers(self): - # set up bi-directional eager loads - mapper(Left, left) - mapper(Right, right) - mapper(Middle, middle, properties=dict( - left=relation(Left, - lazy=False, - backref=backref('middle',lazy=False)), - right=relation(Right, - lazy=False, - backref=backref('middle', lazy=False)))), - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_eager_terminate(self): - """Eager query generation does not include the same mapper's table twice. - - Or, that bi-directional eager loads dont include each other in eager - query generation. 
- - """ - p = Middle('m1') - p.left.append(Left('l1')) - p.right.append(Right('r1')) - - session = create_session() - session.add(p) - session.flush() - session.expunge_all() - obj = session.query(Left).filter_by(data='l1').one() - - -class EagerTest3(_base.MappedTest): - """Eager loading combined with nested SELECT statements, functions, and aggregates.""" - - def define_tables(self, metadata): - Table('datas', metadata, - Column('id', Integer, primary_key=True, nullable=False), - Column('a', Integer, nullable=False)) - - Table('foo', metadata, - Column('data_id', Integer, - ForeignKey('datas.id'), - nullable=False, primary_key=True), - Column('bar', Integer)) - - Table('stats', metadata, - Column('id', Integer, primary_key=True, nullable=False ), - Column('data_id', Integer, ForeignKey('datas.id')), - Column('somedata', Integer, nullable=False )) - - def setup_classes(self): - class Data(_base.BasicEntity): - pass - - class Foo(_base.BasicEntity): - pass - - class Stat(_base.BasicEntity): - pass - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_nesting_with_functions(self): - mapper(Data, datas) - mapper(Foo, foo, properties={ - 'data': relation(Data,backref=backref('foo',uselist=False))}) - - mapper(Stat, stats, properties={ - 'data':relation(Data)}) - - session = create_session() - - data = [Data(a=x) for x in range(5)] - session.add_all(data) - - session.add_all(( - Stat(data=data[0], somedata=1), - Stat(data=data[1], somedata=2), - Stat(data=data[2], somedata=3), - Stat(data=data[3], somedata=4), - Stat(data=data[4], somedata=5), - Stat(data=data[0], somedata=6), - Stat(data=data[1], somedata=7), - Stat(data=data[2], somedata=8), - Stat(data=data[3], somedata=9), - Stat(data=data[4], somedata=10))) - session.flush() - - arb_data = sa.select( - [stats.c.data_id, sa.func.max(stats.c.somedata).label('max')], - stats.c.data_id <= 5, - group_by=[stats.c.data_id]).alias('arb') - - arb_result = 
arb_data.execute().fetchall() - - # order the result list descending based on 'max' - arb_result.sort(key = lambda a: a['max'], reverse=True) - - # extract just the "data_id" from it - arb_result = [row['data_id'] for row in arb_result] - - # now query for Data objects using that above select, adding the - # "order by max desc" separately - q = (session.query(Data). - options(sa.orm.eagerload('foo')). - select_from(datas.join(arb_data, arb_data.c.data_id == datas.c.id)). - order_by(sa.desc(arb_data.c.max)). - limit(10)) - - # extract "data_id" from the list of result objects - verify_result = [d.id for d in q] - - eq_(verify_result, arb_result) - -class EagerTest4(_base.MappedTest): - - def define_tables(self, metadata): - Table('departments', metadata, - Column('department_id', Integer, primary_key=True), - Column('name', String(50))) - - Table('employees', metadata, - Column('person_id', Integer, primary_key=True), - Column('name', String(50)), - Column('department_id', Integer, - ForeignKey('departments.department_id'))) - - def setup_classes(self): - class Department(_base.BasicEntity): - pass - - class Employee(_base.BasicEntity): - pass - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_basic(self): - mapper(Employee, employees) - mapper(Department, departments, properties=dict( - employees=relation(Employee, - lazy=False, - backref='department'))) - - d1 = Department(name='One') - for e in 'Jim', 'Jack', 'John', 'Susan': - d1.employees.append(Employee(name=e)) - - d2 = Department(name='Two') - for e in 'Joe', 'Bob', 'Mary', 'Wally': - d2.employees.append(Employee(name=e)) - - sess = create_session() - sess.add_all((d1, d2)) - sess.flush() - - q = (sess.query(Department). - join('employees'). - filter(Employee.name.startswith('J')). - distinct(). 
- order_by([sa.desc(Department.name)])) - - eq_(q.count(), 2) - assert q[0] is d2 - - -class EagerTest5(_base.MappedTest): - """Construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance.""" - - def define_tables(self, metadata): - Table('base', metadata, - Column('uid', String(30), primary_key=True), - Column('x', String(30))) - - Table('derived', metadata, - Column('uid', String(30), ForeignKey('base.uid'), - primary_key=True), - Column('y', String(30))) - - Table('derivedII', metadata, - Column('uid', String(30), ForeignKey('base.uid'), - primary_key=True), - Column('z', String(30))) - - Table('comments', metadata, - Column('id', Integer, primary_key=True), - Column('uid', String(30), ForeignKey('base.uid')), - Column('comment', String(30))) - - def setup_classes(self): - class Base(_base.BasicEntity): - def __init__(self, uid, x): - self.uid = uid - self.x = x - - class Derived(Base): - def __init__(self, uid, x, y): - self.uid = uid - self.x = x - self.y = y - - class DerivedII(Base): - def __init__(self, uid, x, z): - self.uid = uid - self.x = x - self.z = z - - class Comment(_base.BasicEntity): - def __init__(self, uid, comment): - self.uid = uid - self.comment = comment - - @testing.resolve_artifact_names - def test_basic(self): - commentMapper = mapper(Comment, comments) - - baseMapper = mapper(Base, base, properties=dict( - comments=relation(Comment, lazy=False, - cascade='all, delete-orphan'))) - - mapper(Derived, derived, inherits=baseMapper) - - mapper(DerivedII, derivedII, inherits=baseMapper) - - sess = create_session() - d = Derived('uid1', 'x', 'y') - d.comments = [Comment('uid1', 'comment')] - d2 = DerivedII('uid2', 'xx', 'z') - d2.comments = [Comment('uid2', 'comment')] - sess.add_all((d, d2)) - sess.flush() - sess.expunge_all() - - # this eager load sets up an AliasedClauses for the "comment" - # relationship, then stores it in clauses_by_lead_mapper[mapper for - # Derived] - d = 
sess.query(Derived).get('uid1') - sess.expunge_all() - assert len([c for c in d.comments]) == 1 - - # this eager load sets up an AliasedClauses for the "comment" - # relationship, and should store it in clauses_by_lead_mapper[mapper - # for DerivedII]. the bug was that the previous AliasedClause create - # prevented this population from occurring. - d2 = sess.query(DerivedII).get('uid2') - sess.expunge_all() - - # object is not in the session; therefore the lazy load cant trigger - # here, eager load had to succeed - assert len([c for c in d2.comments]) == 1 - - -class EagerTest6(_base.MappedTest): - - def define_tables(self, metadata): - Table('design_types', metadata, - Column('design_type_id', Integer, primary_key=True)) - - Table('design', metadata, - Column('design_id', Integer, primary_key=True), - Column('design_type_id', Integer, - ForeignKey('design_types.design_type_id'))) - - Table('parts', metadata, - Column('part_id', Integer, primary_key=True), - Column('design_id', Integer, ForeignKey('design.design_id')), - Column('design_type_id', Integer, - ForeignKey('design_types.design_type_id'))) - - Table('inherited_part', metadata, - Column('ip_id', Integer, primary_key=True), - Column('part_id', Integer, ForeignKey('parts.part_id')), - Column('design_id', Integer, ForeignKey('design.design_id'))) - - def setup_classes(self): - class Part(_base.BasicEntity): - pass - - class Design(_base.BasicEntity): - pass - - class DesignType(_base.BasicEntity): - pass - - class InheritedPart(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_one(self): - p_m = mapper(Part, parts) - - mapper(InheritedPart, inherited_part, properties=dict( - part=relation(Part, lazy=False))) - - d_m = mapper(Design, design, properties=dict( - inheritedParts=relation(InheritedPart, - cascade="all, delete-orphan", - backref="design"))) - - mapper(DesignType, design_types) - - d_m.add_property( - "type", relation(DesignType, lazy=False, backref="designs")) - - 
p_m.add_property( - "design", relation( - Design, lazy=False, - backref=backref("parts", cascade="all, delete-orphan"))) - - - d = Design() - sess = create_session() - sess.add(d) - sess.flush() - sess.expunge_all() - x = sess.query(Design).get(1) - x.inheritedParts - - -class EagerTest7(_base.MappedTest): - def define_tables(self, metadata): - Table('companies', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('company_name', String(40))) - - Table('addresses', metadata, - Column('address_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('company_id', Integer, ForeignKey("companies.company_id")), - Column('address', String(40))) - - Table('phone_numbers', metadata, - Column('phone_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('address_id', Integer, ForeignKey('addresses.address_id')), - Column('type', String(20)), - Column('number', String(10))) - - Table('invoices', metadata, - Column('invoice_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('company_id', Integer, ForeignKey("companies.company_id")), - Column('date', sa.DateTime)) - - Table('items', metadata, - Column('item_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')), - Column('code', String(20)), - Column('qty', Integer)) - - def setup_classes(self): - class Company(_base.ComparableEntity): - pass - - class Address(_base.ComparableEntity): - pass - - class Phone(_base.ComparableEntity): - pass - - class Item(_base.ComparableEntity): - pass - - class Invoice(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def testone(self): - """ - Tests eager load of a many-to-one attached to a one-to-many. 
this - testcase illustrated the bug, which is that when the single Company is - loaded, no further processing of the rows occurred in order to load - the Company's second Address object. - - """ - mapper(Address, addresses) - - mapper(Company, companies, properties={ - 'addresses' : relation(Address, lazy=False)}) - - mapper(Invoice, invoices, properties={ - 'company': relation(Company, lazy=False)}) - - a1 = Address(address='a1 address') - a2 = Address(address='a2 address') - c1 = Company(company_name='company 1', addresses=[a1, a2]) - i1 = Invoice(date=datetime.datetime.now(), company=c1) - - - session = create_session() - session.add(i1) - session.flush() - - company_id = c1.company_id - invoice_id = i1.invoice_id - - session.expunge_all() - c = session.query(Company).get(company_id) - - session.expunge_all() - i = session.query(Invoice).get(invoice_id) - - eq_(c, i.company) - - @testing.resolve_artifact_names - def testtwo(self): - """The original testcase that includes various complicating factors""" - - mapper(Phone, phone_numbers) - - mapper(Address, addresses, properties={ - 'phones': relation(Phone, lazy=False, backref='address', - order_by=phone_numbers.c.phone_id)}) - - mapper(Company, companies, properties={ - 'addresses': relation(Address, lazy=False, backref='company', - order_by=addresses.c.address_id)}) - - mapper(Item, items) - - mapper(Invoice, invoices, properties={ - 'items': relation(Item, lazy=False, backref='invoice', - order_by=items.c.item_id), - 'company': relation(Company, lazy=False, backref='invoices')}) - - c1 = Company(company_name='company 1', addresses=[ - Address(address='a1 address', - phones=[Phone(type='home', number='1111'), - Phone(type='work', number='22222')]), - Address(address='a2 address', - phones=[Phone(type='home', number='3333'), - Phone(type='work', number='44444')]) - ]) - - session = create_session() - session.add(c1) - session.flush() - - company_id = c1.company_id - - session.expunge_all() - - a = 
session.query(Company).get(company_id) - - # set up an invoice - i1 = Invoice(date=datetime.datetime.now(), company=a) - - item1 = Item(code='aaaa', qty=1, invoice=i1) - item2 = Item(code='bbbb', qty=2, invoice=i1) - item3 = Item(code='cccc', qty=3, invoice=i1) - - session.flush() - invoice_id = i1.invoice_id - - session.expunge_all() - c = session.query(Company).get(company_id) - - session.expunge_all() - i = session.query(Invoice).get(invoice_id) - - eq_(c, i.company) - - -class EagerTest8(_base.MappedTest): - - def define_tables(self, metadata): - Table('prj', metadata, - Column('id', Integer, primary_key=True), - Column('created', sa.DateTime ), - Column('title', sa.Unicode(100))) - - Table('task', metadata, - Column('id', Integer, primary_key=True), - Column('status_id', Integer, - ForeignKey('task_status.id'), nullable=False), - Column('title', sa.Unicode(100)), - Column('task_type_id', Integer , - ForeignKey('task_type.id'), nullable=False), - Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False)) - - Table('task_status', metadata, - Column('id', Integer, primary_key=True)) - - Table('task_type', metadata, - Column('id', Integer, primary_key=True)) - - Table('msg', metadata, - Column('id', Integer, primary_key=True), - Column('posted', sa.DateTime, index=True,), - Column('type_id', Integer, ForeignKey('msg_type.id')), - Column('task_id', Integer, ForeignKey('task.id'))) - - Table('msg_type', metadata, - Column('id', Integer, primary_key=True), - Column('name', sa.Unicode(20)), - Column('display_name', sa.Unicode(20))) - - @testing.resolve_artifact_names - def fixtures(self): - return dict( - prj=(('id',), - (1,)), - - task_status=(('id',), - (1,)), - - task_type=(('id',), - (1,),), - - task=(('title', 'task_type_id', 'status_id', 'prj_id'), - (u'task 1', 1, 1, 1))) - - def setup_classes(self): - class Task_Type(_base.BasicEntity): - pass - - class Joined(_base.ComparableEntity): - pass - - @testing.fails_on('maxdb', 'FIXME: unknown') - 
@testing.resolve_artifact_names - def test_nested_joins(self): - # this is testing some subtle column resolution stuff, - # concerning corresponding_column() being extremely accurate - # as well as how mapper sets up its column properties - - mapper(Task_Type, task_type) - - tsk_cnt_join = sa.outerjoin(prj, task, task.c.prj_id==prj.c.id) - - j = sa.outerjoin(task, msg, task.c.id==msg.c.task_id) - jj = sa.select([ task.c.id.label('task_id'), - sa.func.count(msg.c.id).label('props_cnt')], - from_obj=[j], - group_by=[task.c.id]).alias('prop_c_s') - jjj = sa.join(task, jj, task.c.id == jj.c.task_id) - - mapper(Joined, jjj, properties=dict( - type=relation(Task_Type, lazy=False))) - - session = create_session() - - eq_(session.query(Joined).limit(10).offset(0).one(), - Joined(id=1, title=u'task 1', props_cnt=0)) - - -class EagerTest9(_base.MappedTest): - """Test the usage of query options to eagerly load specific paths. - - This relies upon the 'path' construct used by PropertyOption to relate - LoaderStrategies to specific paths, as well as the path state maintained - throughout the query setup/mapper instances process. 
- - """ - def define_tables(self, metadata): - Table('accounts', metadata, - Column('account_id', Integer, primary_key=True), - Column('name', String(40))) - - Table('transactions', metadata, - Column('transaction_id', Integer, primary_key=True), - Column('name', String(40))) - - Table('entries', metadata, - Column('entry_id', Integer, primary_key=True), - Column('name', String(40)), - Column('account_id', Integer, - ForeignKey('accounts.account_id')), - Column('transaction_id', Integer, - ForeignKey('transactions.transaction_id'))) - - def setup_classes(self): - class Account(_base.BasicEntity): - pass - - class Transaction(_base.BasicEntity): - pass - - class Entry(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Account, accounts) - - mapper(Transaction, transactions) - - mapper(Entry, entries, properties=dict( - account=relation(Account, - uselist=False, - backref=backref('entries', lazy=True, - order_by=entries.c.entry_id)), - transaction=relation(Transaction, - uselist=False, - backref=backref('entries', lazy=False, - order_by=entries.c.entry_id)))) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_eagerload_on_path(self): - session = create_session() - - tx1 = Transaction(name='tx1') - tx2 = Transaction(name='tx2') - - acc1 = Account(name='acc1') - ent11 = Entry(name='ent11', account=acc1, transaction=tx1) - ent12 = Entry(name='ent12', account=acc1, transaction=tx2) - - acc2 = Account(name='acc2') - ent21 = Entry(name='ent21', account=acc2, transaction=tx1) - ent22 = Entry(name='ent22', account=acc2, transaction=tx2) - - session.add(acc1) - session.flush() - session.expunge_all() - - def go(): - # load just the first Account. eager loading will actually load - # all objects saved thus far, but will not eagerly load the - # "accounts" off the immediate "entries"; only the "accounts" off - # the entries->transaction->entries - acc = (session.query(Account). 
- options(sa.orm.eagerload_all('entries.transaction.entries.account')). - order_by(Account.account_id)).first() - - # no sql occurs - eq_(acc.name, 'acc1') - eq_(acc.entries[0].transaction.entries[0].account.name, 'acc1') - eq_(acc.entries[0].transaction.entries[1].account.name, 'acc2') - - # lazyload triggers but no sql occurs because many-to-one uses - # cached query.get() - for e in acc.entries: - assert e.account is acc - - self.assert_sql_count(testing.db, go, 1) - - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/attributes.py b/test/orm/attributes.py deleted file mode 100644 index 7c116fcf7..000000000 --- a/test/orm/attributes.py +++ /dev/null @@ -1,1331 +0,0 @@ -import testenv; testenv.configure_for_tests() -import pickle -import sqlalchemy.orm.attributes as attributes -from sqlalchemy.orm.collections import collection -from sqlalchemy.orm.interfaces import AttributeExtension -from sqlalchemy import exc as sa_exc -from testlib import * -from testlib.testing import eq_ -from orm import _base -import gc - -# global for pickling tests -MyTest = None -MyTest2 = None - - -class AttributesTest(_base.ORMTest): - def setUp(self): - global MyTest, MyTest2 - class MyTest(object): pass - class MyTest2(object): pass - - def tearDown(self): - global MyTest, MyTest2 - MyTest, MyTest2 = None, None - - def test_basic(self): - class User(object):pass - - attributes.register_class(User) - attributes.register_attribute(User, 'user_id', uselist=False, useobject=False) - attributes.register_attribute(User, 'user_name', uselist=False, useobject=False) - attributes.register_attribute(User, 'email_address', uselist=False, useobject=False) - - u = User() - u.user_id = 7 - u.user_name = 'john' - u.email_address = 'lala@123.com' - - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - attributes.instance_state(u).commit_all(attributes.instance_dict(u)) - self.assert_(u.user_id == 7 and u.user_name == 'john' and 
u.email_address == 'lala@123.com') - - u.user_name = 'heythere' - u.email_address = 'foo@bar.com' - self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com') - - def test_pickleness(self): - attributes.register_class(MyTest) - attributes.register_class(MyTest2) - attributes.register_attribute(MyTest, 'user_id', uselist=False, useobject=False) - attributes.register_attribute(MyTest, 'user_name', uselist=False, useobject=False) - attributes.register_attribute(MyTest, 'email_address', uselist=False, useobject=False) - attributes.register_attribute(MyTest, 'some_mutable_data', mutable_scalars=True, copy_function=list, compare_function=cmp, uselist=False, useobject=False) - attributes.register_attribute(MyTest2, 'a', uselist=False, useobject=False) - attributes.register_attribute(MyTest2, 'b', uselist=False, useobject=False) - # shouldnt be pickling callables at the class level - def somecallable(*args): - return None - attributes.register_attribute(MyTest, "mt2", uselist = True, trackparent=True, callable_=somecallable, useobject=True) - - o = MyTest() - o.mt2.append(MyTest2()) - o.user_id=7 - o.some_mutable_data = [1,2,3] - o.mt2[0].a = 'abcde' - pk_o = pickle.dumps(o) - - o2 = pickle.loads(pk_o) - pk_o2 = pickle.dumps(o2) - - # so... pickle is creating a new 'mt2' string after a roundtrip here, - # so we'll brute-force set it to be id-equal to the original string - if False: - o_mt2_str = [ k for k in o.__dict__ if k == 'mt2'][0] - o2_mt2_str = [ k for k in o2.__dict__ if k == 'mt2'][0] - self.assert_(o_mt2_str == o2_mt2_str) - self.assert_(o_mt2_str is not o2_mt2_str) - # change the id of o2.__dict__['mt2'] - former = o2.__dict__['mt2'] - del o2.__dict__['mt2'] - o2.__dict__[o_mt2_str] = former - - self.assert_(pk_o == pk_o2) - - # the above is kind of distrurbing, so let's do it again a little - # differently. the string-id in serialization thing is just an - # artifact of pickling that comes up in the first round-trip. 
- # a -> b differs in pickle memoization of 'mt2', but b -> c will - # serialize identically. - - o3 = pickle.loads(pk_o2) - pk_o3 = pickle.dumps(o3) - o4 = pickle.loads(pk_o3) - pk_o4 = pickle.dumps(o4) - - self.assert_(pk_o3 == pk_o4) - - # and lastly make sure we still have our data after all that. - # identical serialzation is great, *if* it's complete :) - self.assert_(o4.user_id == 7) - self.assert_(o4.user_name is None) - self.assert_(o4.email_address is None) - self.assert_(o4.some_mutable_data == [1,2,3]) - self.assert_(len(o4.mt2) == 1) - self.assert_(o4.mt2[0].a == 'abcde') - self.assert_(o4.mt2[0].b is None) - - def test_state_gc(self): - """test that InstanceState always has a dict, even after host object gc'ed.""" - - class Foo(object): - pass - - attributes.register_class(Foo) - f = Foo() - state = attributes.instance_state(f) - f.bar = "foo" - assert state.dict == {'bar':'foo', state.manager.STATE_ATTR:state} - del f - gc.collect() - assert state.obj() is None - assert state.dict == {} - - def test_deferred(self): - class Foo(object):pass - - data = {'a':'this is a', 'b':12} - def loader(state, keys): - for k in keys: - state.dict[k] = data[k] - return attributes.ATTR_WAS_SET - - attributes.register_class(Foo) - manager = attributes.manager_of_class(Foo) - manager.deferred_scalar_loader = loader - attributes.register_attribute(Foo, 'a', uselist=False, useobject=False) - attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) - - f = Foo() - attributes.instance_state(f).expire_attributes(None) - eq_(f.a, "this is a") - eq_(f.b, 12) - - f.a = "this is some new a" - attributes.instance_state(f).expire_attributes(None) - eq_(f.a, "this is a") - eq_(f.b, 12) - - attributes.instance_state(f).expire_attributes(None) - f.a = "this is another new a" - eq_(f.a, "this is another new a") - eq_(f.b, 12) - - attributes.instance_state(f).expire_attributes(None) - eq_(f.a, "this is a") - eq_(f.b, 12) - - del f.a - eq_(f.a, None) - eq_(f.b, 12) - - 
attributes.instance_state(f).commit_all(attributes.instance_dict(f)) - eq_(f.a, None) - eq_(f.b, 12) - - def test_deferred_pickleable(self): - data = {'a':'this is a', 'b':12} - def loader(state, keys): - for k in keys: - state.dict[k] = data[k] - return attributes.ATTR_WAS_SET - - attributes.register_class(MyTest) - manager = attributes.manager_of_class(MyTest) - manager.deferred_scalar_loader=loader - attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False) - attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False) - - m = MyTest() - attributes.instance_state(m).expire_attributes(None) - assert 'a' not in m.__dict__ - m2 = pickle.loads(pickle.dumps(m)) - assert 'a' not in m2.__dict__ - eq_(m2.a, "this is a") - eq_(m2.b, 12) - - def test_list(self): - class User(object):pass - class Address(object):pass - - attributes.register_class(User) - attributes.register_class(Address) - attributes.register_attribute(User, 'user_id', uselist=False, useobject=False) - attributes.register_attribute(User, 'user_name', uselist=False, useobject=False) - attributes.register_attribute(User, 'addresses', uselist = True, useobject=True) - attributes.register_attribute(Address, 'address_id', uselist=False, useobject=False) - attributes.register_attribute(Address, 'email_address', uselist=False, useobject=False) - - u = User() - u.user_id = 7 - u.user_name = 'john' - u.addresses = [] - a = Address() - a.address_id = 10 - a.email_address = 'lala@123.com' - u.addresses.append(a) - - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - u, attributes.instance_state(a).commit_all(attributes.instance_dict(a)) - self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - - u.user_name = 'heythere' - a = Address() - a.address_id = 11 - a.email_address = 'foo@bar.com' - u.addresses.append(a) - self.assert_(u.user_id == 7 and u.user_name == 'heythere' and 
u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') - - def test_scalar_listener(self): - # listeners on ScalarAttributeImpl and MutableScalarAttributeImpl aren't used normally. - # test that they work for the benefit of user extensions - class Foo(object): - pass - - results = [] - class ReceiveEvents(AttributeExtension): - def append(self, state, child, initiator): - assert False - - def remove(self, state, child, initiator): - results.append(("remove", state.obj(), child)) - - def set(self, state, child, oldchild, initiator): - results.append(("set", state.obj(), child, oldchild)) - return child - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents()) - attributes.register_attribute(Foo, 'y', uselist=False, mutable_scalars=True, useobject=False, copy_function=lambda x:x, extension=ReceiveEvents()) - - f = Foo() - f.x = 5 - f.x = 17 - del f.x - f.y = [1,2,3] - f.y = [4,5,6] - del f.y - - eq_(results, [ - ('set', f, 5, None), - ('set', f, 17, 5), - ('remove', f, 17), - ('set', f, [1,2,3], None), - ('set', f, [4,5,6], [1,2,3]), - ('remove', f, [4,5,6]) - ]) - - - def test_lazytrackparent(self): - """test that the "hasparent" flag works properly when lazy loaders and backrefs are used""" - - class Post(object):pass - class Blog(object):pass - attributes.register_class(Post) - attributes.register_class(Blog) - - # set up instrumented attributes with backrefs - attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) - - # create objects as if they'd been freshly loaded from the database (without history) - b = Blog() - p1 = Post() - attributes.instance_state(b).set_callable('posts', 
lambda:[p1]) - attributes.instance_state(p1).set_callable('blog', lambda:b) - p1, attributes.instance_state(b).commit_all(attributes.instance_dict(b)) - - # no orphans (called before the lazy loaders fire off) - assert attributes.has_parent(Blog, p1, 'posts', optimistic=True) - assert attributes.has_parent(Post, b, 'blog', optimistic=True) - - # assert connections - assert p1.blog is b - assert p1 in b.posts - - # manual connections - b2 = Blog() - p2 = Post() - b2.posts.append(p2) - assert attributes.has_parent(Blog, p2, 'posts') - assert attributes.has_parent(Post, b2, 'blog') - - def test_inheritance(self): - """tests that attributes are polymorphic""" - class Foo(object):pass - class Bar(Foo):pass - - - attributes.register_class(Foo) - attributes.register_class(Bar) - - def func1(): - print "func1" - return "this is the foo attr" - def func2(): - print "func2" - return "this is the bar attr" - def func3(): - print "func3" - return "this is the shared attr" - attributes.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1, useobject=True) - attributes.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3, useobject=True) - attributes.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2, useobject=True) - - x = Foo() - y = Bar() - assert x.element == 'this is the foo attr' - assert y.element == 'this is the bar attr' - assert x.element2 == 'this is the shared attr' - assert y.element2 == 'this is the shared attr' - - def test_no_double_state(self): - states = set() - class Foo(object): - def __init__(self): - states.add(attributes.instance_state(self)) - class Bar(Foo): - def __init__(self): - states.add(attributes.instance_state(self)) - Foo.__init__(self) - - - attributes.register_class(Foo) - attributes.register_class(Bar) - - b = Bar() - eq_(len(states), 1) - eq_(list(states)[0].obj(), b) - - - def test_inheritance2(self): - """test that the attribute manager can properly traverse the managed 
attributes of an object, - if the object is of a descendant class with managed attributes in the parent class""" - class Foo(object):pass - class Bar(Foo):pass - - class Element(object): - _state = True - - attributes.register_class(Foo) - attributes.register_class(Bar) - attributes.register_attribute(Foo, 'element', uselist=False, useobject=True) - el = Element() - x = Bar() - x.element = el - eq_(attributes.get_history(attributes.instance_state(x), 'element'), ([el], (), ())) - attributes.instance_state(x).commit_all(attributes.instance_dict(x)) - - (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element') - assert added == () - assert unchanged == [el] - - def test_lazyhistory(self): - """tests that history functions work with lazy-loading attributes""" - - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - pass - - attributes.register_class(Foo) - attributes.register_class(Bar) - - bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] - def func1(): - return "this is func 1" - def func2(): - return [bar1, bar2, bar3] - - attributes.register_attribute(Foo, 'col1', uselist=False, callable_=lambda o:func1, useobject=True) - attributes.register_attribute(Foo, 'col2', uselist=True, callable_=lambda o:func2, useobject=True) - attributes.register_attribute(Bar, 'id', uselist=False, useobject=True) - - x = Foo() - attributes.instance_state(x).commit_all(attributes.instance_dict(x)) - x.col2.append(bar4) - eq_(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], [])) - - def test_parenttrack(self): - class Foo(object):pass - class Bar(object):pass - - attributes.register_class(Foo) - attributes.register_class(Bar) - - attributes.register_attribute(Foo, 'element', uselist=False, trackparent=True, useobject=True) - attributes.register_attribute(Bar, 'element', uselist=False, trackparent=True, useobject=True) - - f1 = Foo() - f2 = Foo() - b1 = Bar() - b2 = Bar() - 
- f1.element = b1 - b2.element = f2 - - assert attributes.has_parent(Foo, b1, 'element') - assert not attributes.has_parent(Foo, b2, 'element') - assert not attributes.has_parent(Foo, f2, 'element') - assert attributes.has_parent(Bar, f2, 'element') - - b2.element = None - assert not attributes.has_parent(Bar, f2, 'element') - - # test that double assignment doesn't accidentally reset the 'parent' flag. - b3 = Bar() - f4 = Foo() - b3.element = f4 - assert attributes.has_parent(Bar, f4, 'element') - b3.element = f4 - assert attributes.has_parent(Bar, f4, 'element') - - def test_mutablescalars(self): - """test detection of changes on mutable scalar items""" - class Foo(object):pass - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False) - x = Foo() - x.element = ['one', 'two', 'three'] - attributes.instance_state(x).commit_all(attributes.instance_dict(x)) - x.element[1] = 'five' - assert attributes.instance_state(x).check_modified() - - attributes.unregister_class(Foo) - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'element', uselist=False, useobject=False) - x = Foo() - x.element = ['one', 'two', 'three'] - attributes.instance_state(x).commit_all(attributes.instance_dict(x)) - x.element[1] = 'five' - assert not attributes.instance_state(x).check_modified() - - def test_descriptorattributes(self): - """changeset: 1633 broke ability to use ORM to map classes with unusual - descriptor attributes (for example, classes that inherit from ones - implementing zope.interface.Interface). - This is a simple regression test to prevent that defect. 
- """ - class des(object): - def __get__(self, instance, owner): - raise AttributeError('fake attribute') - - class Foo(object): - A = des() - - attributes.register_class(Foo) - attributes.unregister_class(Foo) - - def test_collectionclasses(self): - - class Foo(object):pass - attributes.register_class(Foo) - - attributes.register_attribute(Foo, "collection", uselist=True, typecallable=set, useobject=True) - assert attributes.manager_of_class(Foo).is_instrumented("collection") - assert isinstance(Foo().collection, set) - - attributes.unregister_attribute(Foo, "collection") - assert not attributes.manager_of_class(Foo).is_instrumented("collection") - - try: - attributes.register_attribute(Foo, "collection", uselist=True, typecallable=dict, useobject=True) - assert False - except sa_exc.ArgumentError, e: - assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class" - - class MyDict(dict): - @collection.appender - def append(self, item): - self[item.foo] = item - @collection.remover - def remove(self, item): - del self[item.foo] - attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict, useobject=True) - assert isinstance(Foo().collection, MyDict) - - attributes.unregister_attribute(Foo, "collection") - - class MyColl(object):pass - try: - attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True) - assert False - except sa_exc.ArgumentError, e: - assert str(e) == "Type MyColl must elect an appender method to be a collection class" - - class MyColl(object): - @collection.iterator - def __iter__(self): - return iter([]) - @collection.appender - def append(self, item): - pass - @collection.remover - def remove(self, item): - pass - attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True) - try: - Foo().collection - assert True - except sa_exc.ArgumentError, e: - assert False - - -class BackrefTest(_base.ORMTest): - - def 
test_manytomany(self): - class Student(object):pass - class Course(object):pass - - attributes.register_class(Student) - attributes.register_class(Course) - attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True) - attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True) - - s = Student() - c = Course() - s.courses.append(c) - self.assert_(c.students == [s]) - s.courses.remove(c) - self.assert_(c.students == []) - - (s1, s2, s3) = (Student(), Student(), Student()) - - c.students = [s1, s2, s3] - self.assert_(s2.courses == [c]) - self.assert_(s1.courses == [c]) - s1.courses.remove(c) - self.assert_(c.students == [s2,s3]) - - def test_onetomany(self): - class Post(object):pass - class Blog(object):pass - - attributes.register_class(Post) - attributes.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) - b = Blog() - (p1, p2, p3) = (Post(), Post(), Post()) - b.posts.append(p1) - b.posts.append(p2) - b.posts.append(p3) - self.assert_(b.posts == [p1, p2, p3]) - self.assert_(p2.blog is b) - - p3.blog = None - self.assert_(b.posts == [p1, p2]) - p4 = Post() - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - p4.blog = b - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - # assert no failure removing None - p5 = Post() - p5.blog = None - del p5.blog - - def test_onetoone(self): - class Port(object):pass - class Jack(object):pass - attributes.register_class(Port) - attributes.register_class(Jack) - attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True) - 
attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True) - p = Port() - j = Jack() - p.jack = j - self.assert_(j.port is p) - self.assert_(p.jack is not None) - - j.port = None - self.assert_(p.jack is None) - -class PendingBackrefTest(_base.ORMTest): - def setUp(self): - global Post, Blog, called, lazy_load - - class Post(object): - def __init__(self, name): - self.name = name - __hash__ = None - def __eq__(self, other): - return other.name == self.name - - class Blog(object): - def __init__(self, name): - self.name = name - __hash__ = None - def __eq__(self, other): - return other.name == self.name - - called = [0] - - lazy_load = [] - def lazy_posts(instance): - def load(): - called[0] += 1 - return lazy_load - return load - - attributes.register_class(Post) - attributes.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True) - - def test_lazy_add(self): - global lazy_load - - p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3") - lazy_load = [p1, p2, p3] - - b = Blog("blog 1") - p = Post("post 4") - - p.blog = b - p = Post("post 5") - p.blog = b - # setting blog doesnt call 'posts' callable - assert called[0] == 0 - - # calling backref calls the callable, populates extra posts - assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")] - assert called[0] == 1 - - def test_lazy_history(self): - global lazy_load - - p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3") - lazy_load = [p1, p2, p3] - - b = Blog("blog 1") - p = Post("post 4") - p.blog = b - - p4 = Post("post 5") - p4.blog = b - assert called[0] == 0 - eq_(attributes.instance_state(b).get_history('posts'), ([p, p4], [p1, p2, p3], 
[])) - assert called[0] == 1 - - def test_lazy_remove(self): - global lazy_load - called[0] = 0 - lazy_load = [] - - b = Blog("blog 1") - p = Post("post 1") - p.blog = b - assert called[0] == 0 - - lazy_load = [p] - - p.blog = None - p2 = Post("post 2") - p2.blog = b - assert called[0] == 0 - assert b.posts == [p2] - assert called[0] == 1 - - def test_normal_load(self): - global lazy_load - lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] - called[0] = 0 - - b = Blog("blog 1") - - # assign without using backref system - p2.__dict__['blog'] = b - - assert b.posts == [Post("post 1"), Post("post 2"), Post("post 3")] - assert called[0] == 1 - p2.blog = None - p4 = Post("post 4") - p4.blog = b - assert b.posts == [Post("post 1"), Post("post 3"), Post("post 4")] - assert called[0] == 1 - - called[0] = 0 - lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] - - def test_commit_removes_pending(self): - global lazy_load - lazy_load = (p1, ) = [Post("post 1"), ] - called[0] = 0 - - b = Blog("blog 1") - p1.blog = b - attributes.instance_state(b).commit_all(attributes.instance_dict(b)) - attributes.instance_state(p1).commit_all(attributes.instance_dict(p1)) - assert b.posts == [Post("post 1")] - -class HistoryTest(_base.ORMTest): - - def test_get_committed_value(self): - class Foo(_base.BasicEntity): - pass - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) - - f = Foo() - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) - - f.someattr = 3 - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) - - f = Foo() - f.someattr = 3 - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - 
eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), 3) - - def test_scalar(self): - class Foo(_base.BasicEntity): - pass - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) - - # case 1. new object - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ())) - - f.someattr = "hi" - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], (), ())) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['hi'], ())) - - f.someattr = 'there' - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], (), ['hi'])) - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['there'], ())) - - del f.someattr - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ['there'])) - - # case 2. 
object with direct dictionary settings (similar to a load operation) - f = Foo() - f.__dict__['someattr'] = 'new' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) - - f.someattr = 'old' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], (), ['new'])) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['old'], ())) - - # setting None on uninitialized is currently a change for a scalar attribute - # no lazyload occurs so this allows overwrite operation to proceed - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ())) - f.someattr = None - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), ())) - - f = Foo() - f.__dict__['someattr'] = 'new' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) - f.someattr = None - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), ['new'])) - - # set same value twice - f = Foo() - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - f.someattr = 'one' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) - f.someattr = 'two' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], (), ())) - - - def test_mutable_scalar(self): - class Foo(_base.BasicEntity): - pass - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False, mutable_scalars=True, copy_function=dict) - - # case 1. 
new object - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ())) - - f.someattr = {'foo':'hi'} - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], (), ())) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'hi'}], ())) - eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) - - f.someattr['foo'] = 'there' - eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], (), [{'foo':'hi'}])) - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'there'}], ())) - - # case 2. object with direct dictionary settings (similar to a load operation) - f = Foo() - f.__dict__['someattr'] = {'foo':'new'} - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'new'}], ())) - - f.someattr = {'foo':'old'} - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], (), [{'foo':'new'}])) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'old'}], ())) - - - def test_use_object(self): - class Foo(_base.BasicEntity): - pass - - class Bar(_base.BasicEntity): - _state = None - def __nonzero__(self): - assert False - - hi = Bar(name='hi') - there = Bar(name='there') - new = Bar(name='new') - old = Bar(name='old') - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=True) - - # case 1. 
new object - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [None], ())) - - f.someattr = hi - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], (), ())) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) - - f.someattr = there - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], (), [hi])) - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) - - del f.someattr - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), [there])) - - # case 2. object with direct dictionary settings (similar to a load operation) - f = Foo() - f.__dict__['someattr'] = 'new' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) - - f.someattr = old - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], (), ['new'])) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) - - # setting None on uninitialized is currently not a change for an object attribute - # (this is different than scalar attribute). 
a lazyload has occured so if its - # None, its really None - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [None], ())) - f.someattr = None - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [None], ())) - - f = Foo() - f.__dict__['someattr'] = 'new' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) - f.someattr = None - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), ['new'])) - - # set same value twice - f = Foo() - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - f.someattr = 'one' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) - f.someattr = 'two' - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], (), ())) - - def test_object_collections_set(self): - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - def __nonzero__(self): - assert False - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) - - hi = Bar(name='hi') - there = Bar(name='there') - old = Bar(name='old') - new = Bar(name='new') - - # case 1. 
new object - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [], ())) - - f.someattr = [hi] - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) - - f.someattr = [there] - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi])) - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) - - f.someattr = [hi] - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [there])) - - f.someattr = [old, new] - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [], [there])) - - # case 2. object with direct settings (similar to a load operation) - f = Foo() - collection = attributes.init_collection(attributes.instance_state(f), 'someattr') - collection.append_without_event(new) - attributes.instance_state(f).commit_all(attributes.instance_dict(f)) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) - - f.someattr = [old] - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new])) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) - - def test_dict_collections(self): - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - pass - - from sqlalchemy.orm.collections import attribute_mapped_collection - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True, typecallable=attribute_mapped_collection('name')) - - hi = Bar(name='hi') - there = Bar(name='there') - old = 
Bar(name='old') - new = Bar(name='new') - - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [], ())) - - f.someattr['hi'] = hi - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - - f.someattr['there'] = there - eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set(), set())) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set(), set([hi, there]), set())) - - def test_object_collections_mutate(self): - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - pass - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) - attributes.register_attribute(Foo, 'id', uselist=False, useobject=False) - - hi = Bar(name='hi') - there = Bar(name='there') - old = Bar(name='old') - new = Bar(name='new') - - # case 1. 
new object - f = Foo(id=1) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [], ())) - - f.someattr.append(hi) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) - - f.someattr.append(there) - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], [])) - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there], ())) - - f.someattr.remove(there) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], [there])) - - f.someattr.append(old) - f.someattr.append(new) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there])) - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, old, new], ())) - - f.someattr.pop(0) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old, new], [hi])) - - # case 2. 
object with direct settings (similar to a load operation) - f = Foo() - f.__dict__['id'] = 1 - collection = attributes.init_collection(attributes.instance_state(f), 'someattr') - collection.append_without_event(new) - attributes.instance_state(f).commit_all(attributes.instance_dict(f)) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) - - f.someattr.append(old) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], [])) - - attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new, old], ())) - - f = Foo() - collection = attributes.init_collection(attributes.instance_state(f), 'someattr') - collection.append_without_event(new) - attributes.instance_state(f).commit_all(attributes.instance_dict(f)) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) - - f.id = 1 - f.someattr.remove(new) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [new])) - - # case 3. mixing appends with sets - f = Foo() - f.someattr.append(hi) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - f.someattr.append(there) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there], [], [])) - f.someattr = [there] - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [])) - - # case 4. 
ensure duplicates show up, order is maintained - f = Foo() - f.someattr.append(hi) - f.someattr.append(there) - f.someattr.append(hi) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], [])) - - attributes.instance_state(f).commit_all(attributes.instance_dict(f)) - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ())) - - f.someattr = [] - eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [hi, there, hi])) - - def test_collections_via_backref(self): - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - pass - - attributes.register_class(Foo) - attributes.register_class(Bar) - attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, useobject=True) - attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) - - f1 = Foo() - b1 = Bar() - eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [], ())) - eq_(attributes.get_history(attributes.instance_state(b1), 'foo'), ((), [None], ())) - - #b1.foo = f1 - f1.bars.append(b1) - eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) - eq_(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], (), ())) - - b2 = Bar() - f1.bars.append(b2) - eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1, b2], [], [])) - eq_(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], (), ())) - eq_(attributes.get_history(attributes.instance_state(b2), 'foo'), ([f1], (), ())) - - def test_lazy_backref_collections(self): - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - pass - - lazy_load = [] - def lazyload(instance): - def load(): - return lazy_load - return load - - attributes.register_class(Foo) - attributes.register_class(Bar) - 
attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, callable_=lazyload, useobject=True) - attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) - - bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] - lazy_load = [bar1, bar2, bar3] - - f = Foo() - bar4 = Bar() - bar4.foo = f - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], [])) - - lazy_load = None - f = Foo() - bar4 = Bar() - bar4.foo = f - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [], [])) - - lazy_load = [bar1, bar2, bar3] - attributes.instance_state(f).expire_attributes(['bars']) - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ((), [bar1, bar2, bar3], ())) - - def test_collections_via_lazyload(self): - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - pass - - lazy_load = [] - def lazyload(instance): - def load(): - return lazy_load - return load - - attributes.register_class(Foo) - attributes.register_class(Bar) - attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lazyload, trackparent=True, useobject=True) - - bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] - lazy_load = [bar1, bar2, bar3] - - f = Foo() - f.bars = [] - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [], [bar1, bar2, bar3])) - - f = Foo() - f.bars.append(bar4) - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []) ) - - f = Foo() - f.bars.remove(bar2) - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2])) - f.bars.append(bar4) - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar3], [bar2])) - - f = Foo() - del f.bars[1] - 
eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2])) - - lazy_load = None - f = Foo() - f.bars.append(bar2) - eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar2], [], [])) - - def test_scalar_via_lazyload(self): - class Foo(_base.BasicEntity): - pass - - lazy_load = None - def lazyload(instance): - def load(): - return lazy_load - return load - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, useobject=False) - lazy_load = "hi" - - # with scalar non-object and active_history=False, the lazy callable is only executed on gets, not history - # operations - - f = Foo() - eq_(f.bar, "hi") - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), ["hi"], ())) - - f = Foo() - f.bar = None - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ())) - - f = Foo() - f.bar = "there" - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), (["there"], (), ())) - f.bar = "hi" - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), (["hi"], (), ())) - - f = Foo() - eq_(f.bar, "hi") - del f.bar - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), (), ["hi"])) - assert f.bar is None - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ["hi"])) - - def test_scalar_via_lazyload_with_active(self): - class Foo(_base.BasicEntity): - pass - - lazy_load = None - def lazyload(instance): - def load(): - return lazy_load - return load - - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, useobject=False, active_history=True) - lazy_load = "hi" - - # active_history=True means the lazy callable is executed on set as well as get, - # causing the old value to appear in the history - - f = Foo() - eq_(f.bar, "hi") - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), ["hi"], ())) - - f = 
Foo() - f.bar = None - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ['hi'])) - - f = Foo() - f.bar = "there" - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), (["there"], (), ['hi'])) - f.bar = "hi" - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), ["hi"], ())) - - f = Foo() - eq_(f.bar, "hi") - del f.bar - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), (), ["hi"])) - assert f.bar is None - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ["hi"])) - - def test_scalar_object_via_lazyload(self): - class Foo(_base.BasicEntity): - pass - class Bar(_base.BasicEntity): - pass - - lazy_load = None - def lazyload(instance): - def load(): - return lazy_load - return load - - attributes.register_class(Foo) - attributes.register_class(Bar) - attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, trackparent=True, useobject=True) - bar1, bar2 = [Bar(id=1), Bar(id=2)] - lazy_load = bar1 - - # with scalar object, the lazy callable is only executed on gets and history - # operations - - f = Foo() - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), [bar1], ())) - - f = Foo() - f.bar = None - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1])) - - f = Foo() - f.bar = bar2 - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([bar2], (), [bar1])) - f.bar = bar1 - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), [bar1], ())) - - f = Foo() - eq_(f.bar, bar1) - del f.bar - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1])) - assert f.bar is None - eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1])) - -class ListenerTest(_base.ORMTest): - def test_receive_changes(self): - """test that Listeners can mutate the given value. 
- - This is a rudimentary test which would be better suited by a full-blown inclusion - into collection.py. - - """ - class Foo(object): - pass - class Bar(object): - pass - - class AlteringListener(AttributeExtension): - def append(self, state, child, initiator): - b2 = Bar() - b2.data = b1.data + " appended" - return b2 - - def set(self, state, value, oldvalue, initiator): - return value + " modified" - - attributes.register_class(Foo) - attributes.register_class(Bar) - attributes.register_attribute(Foo, 'data', uselist=False, useobject=False, extension=AlteringListener()) - attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True, extension=AlteringListener()) - attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True, extension=AlteringListener()) - attributes.register_attribute(Bar, 'data', uselist=False, useobject=False) - - f1 = Foo() - f1.data = "some data" - eq_(f1.data, "some data modified") - b1 = Bar() - b1.data = "some bar" - f1.barlist.append(b1) - assert b1.data == "some bar" - assert f1.barlist[0].data == "some bar appended" - - f1.barset.add(b1) - assert f1.barset.pop().data == "some bar appended" - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/bind.py b/test/orm/bind.py deleted file mode 100644 index 33d028d22..000000000 --- a/test/orm/bind.py +++ /dev/null @@ -1,55 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib.sa import MetaData, Table, Column, Integer -from testlib.sa.orm import mapper, create_session -from testlib import sa, testing -from orm import _base - - -class BindTest(_base.MappedTest): - def define_tables(self, metadata): - Table('test_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', Integer)) - - def setup_classes(self): - class Foo(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - meta = MetaData() - test_table.tometadata(meta) - - 
assert meta.tables['test_table'].bind is None - mapper(Foo, meta.tables['test_table']) - - @testing.resolve_artifact_names - def test_session_bind(self): - engine = self.metadata.bind - - for bind in (engine, engine.connect()): - try: - sess = create_session(bind=bind) - assert sess.bind is bind - f = Foo() - sess.add(f) - sess.flush() - assert sess.query(Foo).get(f.id) is f - finally: - if hasattr(bind, 'close'): - bind.close() - - @testing.resolve_artifact_names - def test_session_unbound(self): - sess = create_session() - sess.add(Foo()) - self.assertRaisesMessage( - sa.exc.UnboundExecutionError, - ('Could not locate a bind configured on Mapper|Foo|test_table ' - 'or this Session'), - sess.flush) - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/cascade.py b/test/orm/cascade.py deleted file mode 100644 index c827a85ce..000000000 --- a/test/orm/cascade.py +++ /dev/null @@ -1,1292 +0,0 @@ -import testenv; testenv.configure_for_tests() - -from testlib.sa import Table, Column, Integer, String, ForeignKey, Sequence, exc as sa_exc -from testlib.sa.orm import mapper, relation, create_session, class_mapper, backref -from testlib.sa.orm import attributes, exc as orm_exc -from testlib import testing -from testlib.testing import eq_ -from orm import _base, _fixtures - - -class O2MCascadeTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Address, addresses) - mapper(User, users, properties = dict( - addresses = relation(Address, cascade="all, delete-orphan", backref="user"), - orders = relation( - mapper(Order, orders), cascade="all, delete-orphan") - )) - mapper(Dingaling,dingalings, properties={ - 'address':relation(Address) - }) - - @testing.resolve_artifact_names - def test_list_assignment(self): - sess = create_session() - u = User(name='jack', orders=[ - Order(description='someorder'), - Order(description='someotherorder')]) - sess.add(u) - sess.flush() - sess.expunge_all() - - 
u = sess.query(User).get(u.id) - eq_(u, User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')])) - - u.orders=[Order(description="order 3"), Order(description="order 4")] - sess.flush() - sess.expunge_all() - - u = sess.query(User).get(u.id) - eq_(u, User(name='jack', - orders=[Order(description="order 3"), - Order(description="order 4")])) - - eq_(sess.query(Order).all(), - [Order(description="order 3"), Order(description="order 4")]) - - o5 = Order(description="order 5") - sess.add(o5) - try: - sess.flush() - assert False - except orm_exc.FlushError, e: - assert "is an orphan" in str(e) - - - @testing.resolve_artifact_names - def test_delete(self): - sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) - sess.add(u) - sess.flush() - - sess.delete(u) - sess.flush() - assert users.count().scalar() == 0 - assert orders.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_delete_unloaded_collections(self): - """Unloaded collections are still included in a delete-cascade by default.""" - sess = create_session() - u = User(name='jack', - addresses=[Address(email_address="address1"), - Address(email_address="address2")]) - sess.add(u) - sess.flush() - sess.expunge_all() - assert addresses.count().scalar() == 2 - assert users.count().scalar() == 1 - - u = sess.query(User).get(u.id) - - assert 'addresses' not in u.__dict__ - sess.delete(u) - sess.flush() - assert addresses.count().scalar() == 0 - assert users.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_cascades_onlycollection(self): - """Cascade only reaches instances that are still part of the collection, - not those that have been removed""" - - sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) - sess.add(u) - sess.flush() - - o = u.orders[0] - del u.orders[0] - sess.delete(u) - 
assert u in sess.deleted - assert o not in sess.deleted - assert o in sess - - u2 = User(name='newuser', orders=[o]) - sess.add(u2) - sess.flush() - sess.expunge_all() - assert users.count().scalar() == 1 - assert orders.count().scalar() == 1 - eq_(sess.query(User).all(), - [User(name='newuser', - orders=[Order(description='someorder')])]) - - @testing.resolve_artifact_names - def test_cascade_nosideeffects(self): - """test that cascade leaves the state of unloaded scalars/collections unchanged.""" - - sess = create_session() - u = User(name='jack') - sess.add(u) - assert 'orders' not in u.__dict__ - - sess.flush() - - assert 'orders' not in u.__dict__ - - a = Address(email_address='foo@bar.com') - sess.add(a) - assert 'user' not in a.__dict__ - a.user = u - sess.flush() - - d = Dingaling(data='d1') - d.address_id = a.id - sess.add(d) - assert 'address' not in d.__dict__ - sess.flush() - assert d.address is a - - @testing.resolve_artifact_names - def test_cascade_delete_plusorphans(self): - sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) - sess.add(u) - sess.flush() - assert users.count().scalar() == 1 - assert orders.count().scalar() == 2 - - del u.orders[0] - sess.delete(u) - sess.flush() - assert users.count().scalar() == 0 - assert orders.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_collection_orphans(self): - sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) - sess.add(u) - sess.flush() - - assert users.count().scalar() == 1 - assert orders.count().scalar() == 2 - - u.orders[:] = [] - - sess.flush() - - assert users.count().scalar() == 1 - assert orders.count().scalar() == 0 - -class O2OCascadeTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Address, addresses) - mapper(User, users, properties = { - 
'address':relation(Address, backref=backref("user", single_parent=True), uselist=False) - }) - - @testing.resolve_artifact_names - def test_single_parent_raise(self): - a1 = Address(email_address='some address') - u1 = User(name='u1', address=a1) - - self.assertRaises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1) - - a2 = Address(email_address='asd') - u1.address = a2 - assert u1.address is not a1 - assert a1.user is None - - - -class O2MBackrefTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users, properties = dict( - orders = relation( - mapper(Order, orders), cascade="all, delete-orphan", backref="user") - )) - - @testing.resolve_artifact_names - def test_lazyload_bug(self): - sess = create_session() - - u = User(name="jack") - sess.add(u) - sess.expunge(u) - - o1 = Order(description='someorder') - o1.user = u - sess.add(u) - assert u in sess - assert o1 in sess - - -class NoSaveCascadeTest(_fixtures.FixtureTest): - """test that backrefs don't force save-update cascades to occur - when the cascade initiated from the forwards side.""" - - @testing.resolve_artifact_names - def test_unidirectional_cascade_o2m(self): - mapper(Order, orders) - mapper(User, users, properties = dict( - orders = relation( - Order, backref=backref("user", cascade=None)) - )) - - sess = create_session() - - o1 = Order() - sess.add(o1) - u1 = User(orders=[o1]) - assert u1 not in sess - assert o1 in sess - - sess.expunge_all() - - o1 = Order() - u1 = User(orders=[o1]) - sess.add(o1) - assert u1 not in sess - assert o1 in sess - - @testing.resolve_artifact_names - def test_unidirectional_cascade_m2o(self): - mapper(Order, orders, properties={ - 'user':relation(User, backref=backref("orders", cascade=None)) - }) - mapper(User, users) - - sess = create_session() - - u1 = User() - sess.add(u1) - o1 = Order() - o1.user = u1 - assert o1 not in sess - assert u1 in sess - - sess.expunge_all() - - u1 
= User() - o1 = Order() - o1.user = u1 - sess.add(u1) - assert o1 not in sess - assert u1 in sess - - @testing.resolve_artifact_names - def test_unidirectional_cascade_m2m(self): - mapper(Item, items, properties={ - 'keywords':relation(Keyword, secondary=item_keywords, cascade="none", backref="items") - }) - mapper(Keyword, keywords) - - sess = create_session() - - i1 = Item() - k1 = Keyword() - sess.add(i1) - i1.keywords.append(k1) - assert i1 in sess - assert k1 not in sess - - sess.expunge_all() - - i1 = Item() - k1 = Keyword() - sess.add(i1) - k1.items.append(i1) - assert i1 in sess - assert k1 not in sess - - -class O2MCascadeNoOrphanTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users, properties = dict( - orders = relation( - mapper(Order, orders), cascade="all") - )) - - @testing.resolve_artifact_names - def test_cascade_delete_noorphans(self): - sess = create_session() - u = User(name='jack', - orders=[Order(description='someorder'), - Order(description='someotherorder')]) - sess.add(u) - sess.flush() - assert users.count().scalar() == 1 - assert orders.count().scalar() == 2 - - del u.orders[0] - sess.delete(u) - sess.flush() - assert users.count().scalar() == 0 - assert orders.count().scalar() == 1 - - -class M2OCascadeTest(_base.MappedTest): - def define_tables(self, metadata): - Table("extra", metadata, - Column("id", Integer, Sequence("extra_id_seq", optional=True), - primary_key=True), - Column("prefs_id", Integer, ForeignKey("prefs.id"))) - - Table('prefs', metadata, - Column('id', Integer, Sequence('prefs_id_seq', optional=True), - primary_key=True), - Column('data', String(40))) - - Table('users', metadata, - Column('id', Integer, Sequence('user_id_seq', optional=True), - primary_key=True), - Column('name', String(40)), - Column('pref_id', Integer, ForeignKey('prefs.id'))) - - def setup_classes(self): - class User(_fixtures.Base): - pass - class Pref(_fixtures.Base): - 
pass - class Extra(_fixtures.Base): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Extra, extra) - mapper(Pref, prefs, properties=dict( - extra = relation(Extra, cascade="all, delete") - )) - mapper(User, users, properties = dict( - pref = relation(Pref, lazy=False, cascade="all, delete-orphan", single_parent=True ) - )) - - @testing.resolve_artifact_names - def insert_data(self): - u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()])) - u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()])) - u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()])) - sess = create_session() - sess.add_all((u1, u2, u3)) - sess.flush() - sess.close() - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_orphan(self): - sess = create_session() - assert prefs.count().scalar() == 3 - assert extra.count().scalar() == 3 - jack = sess.query(User).filter_by(name="jack").one() - jack.pref = None - sess.flush() - assert prefs.count().scalar() == 2 - assert extra.count().scalar() == 2 - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_orphan_on_update(self): - sess = create_session() - jack = sess.query(User).filter_by(name="jack").one() - p = jack.pref - e = jack.pref.extra[0] - sess.expunge_all() - - jack.pref = None - sess.add(jack) - sess.add(p) - sess.add(e) - assert p in sess - assert e in sess - sess.flush() - assert prefs.count().scalar() == 2 - assert extra.count().scalar() == 2 - - @testing.resolve_artifact_names - def test_pending_expunge(self): - sess = create_session() - someuser = User(name='someuser') - sess.add(someuser) - sess.flush() - someuser.pref = p1 = Pref(data='somepref') - assert p1 in sess - someuser.pref = Pref(data='someotherpref') - assert p1 not in sess - sess.flush() - eq_(sess.query(Pref).with_parent(someuser).all(), - [Pref(data="someotherpref")]) - - @testing.resolve_artifact_names - def test_double_assignment(self): 
- """Double assignment will not accidentally reset the 'parent' flag.""" - - sess = create_session() - jack = sess.query(User).filter_by(name="jack").one() - - newpref = Pref(data="newpref") - jack.pref = newpref - jack.pref = newpref - sess.flush() - eq_(sess.query(Pref).all(), - [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) - -class M2OCascadeDeleteTest(_base.MappedTest): - def define_tables(self, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t2id', Integer, ForeignKey('t2.id'))) - Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t3id', Integer, ForeignKey('t3.id'))) - Table('t3', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - - def setup_classes(self): - class T1(_fixtures.Base): - pass - class T2(_fixtures.Base): - pass - class T3(_fixtures.Base): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(T1, t1, properties={'t2': relation(T2, cascade="all")}) - mapper(T2, t2, properties={'t3': relation(T3, cascade="all")}) - mapper(T3, t3) - - @testing.resolve_artifact_names - def test_cascade_delete(self): - sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) - sess.add(x) - sess.flush() - - sess.delete(x) - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), []) - eq_(sess.query(T3).all(), []) - - @testing.resolve_artifact_names - def test_cascade_delete_postappend_onelevel(self): - sess = create_session() - x1 = T1(data='t1', ) - x2 = T2(data='t2') - x3 = T3(data='t3') - sess.add_all((x1, x2, x3)) - sess.flush() - - sess.delete(x1) - x1.t2 = x2 - x2.t3 = x3 - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), []) - eq_(sess.query(T3).all(), []) - - @testing.resolve_artifact_names - def test_cascade_delete_postappend_twolevel(self): - sess = create_session() - x1 = T1(data='t1', 
t2=T2(data='t2')) - x3 = T3(data='t3') - sess.add_all((x1, x3)) - sess.flush() - - sess.delete(x1) - x1.t2.t3 = x3 - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), []) - eq_(sess.query(T3).all(), []) - - @testing.resolve_artifact_names - def test_preserves_orphans_onelevel(self): - sess = create_session() - x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) - sess.add(x2) - sess.flush() - x2.t2 = None - - sess.delete(x2) - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), [T2()]) - eq_(sess.query(T3).all(), [T3()]) - - @testing.future - @testing.resolve_artifact_names - def test_preserves_orphans_onelevel_postremove(self): - sess = create_session() - x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) - sess.add(x2) - sess.flush() - - sess.delete(x2) - x2.t2 = None - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), [T2()]) - eq_(sess.query(T3).all(), [T3()]) - - @testing.resolve_artifact_names - def test_preserves_orphans_twolevel(self): - sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) - sess.add(x) - sess.flush() - - x.t2.t3 = None - sess.delete(x) - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), []) - eq_(sess.query(T3).all(), [T3()]) - - -class M2OCascadeDeleteOrphanTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t2id', Integer, ForeignKey('t2.id'))) - Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t3id', Integer, ForeignKey('t3.id'))) - Table('t3', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - - def setup_classes(self): - class T1(_fixtures.Base): - pass - class T2(_fixtures.Base): - pass - class T3(_fixtures.Base): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(T1, t1, 
properties=dict( - t2=relation(T2, cascade="all, delete-orphan", single_parent=True))) - mapper(T2, t2, properties=dict( - t3=relation(T3, cascade="all, delete-orphan", single_parent=True, backref=backref('t2', uselist=False)))) - mapper(T3, t3) - - @testing.resolve_artifact_names - def test_cascade_delete(self): - sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) - sess.add(x) - sess.flush() - - sess.delete(x) - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), []) - eq_(sess.query(T3).all(), []) - - @testing.resolve_artifact_names - def test_deletes_orphans_onelevel(self): - sess = create_session() - x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) - sess.add(x2) - sess.flush() - x2.t2 = None - - sess.delete(x2) - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), []) - eq_(sess.query(T3).all(), []) - - @testing.resolve_artifact_names - def test_deletes_orphans_twolevel(self): - sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) - sess.add(x) - sess.flush() - - x.t2.t3 = None - sess.delete(x) - sess.flush() - eq_(sess.query(T1).all(), []) - eq_(sess.query(T2).all(), []) - eq_(sess.query(T3).all(), []) - - @testing.resolve_artifact_names - def test_finds_orphans_twolevel(self): - sess = create_session() - x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) - sess.add(x) - sess.flush() - - x.t2.t3 = None - sess.flush() - eq_(sess.query(T1).all(), [T1()]) - eq_(sess.query(T2).all(), [T2()]) - eq_(sess.query(T3).all(), []) - - @testing.resolve_artifact_names - def test_single_parent_raise(self): - - sess = create_session() - - y = T2(data='T2a') - x = T1(data='T1a', t2=y) - self.assertRaises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y) - - @testing.resolve_artifact_names - def test_single_parent_backref(self): - - sess = create_session() - - y = T3(data='T3a') - x = T2(data='T2a', t3=y) - - # cant attach the T3 to another T2 - 
self.assertRaises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y) - - # set via backref tho is OK, unsets from previous parent - # first - z = T2(data='T2b') - y.t2 = z - - assert z.t3 is y - assert x.t3 is None - -class M2MCascadeTest(_base.MappedTest): - def define_tables(self, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - test_needs_fk=True - ) - Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - test_needs_fk=True - - ) - Table('atob', metadata, - Column('aid', Integer, ForeignKey('a.id')), - Column('bid', Integer, ForeignKey('b.id')), - test_needs_fk=True - - ) - Table('c', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('bid', Integer, ForeignKey('b.id')), - test_needs_fk=True - - ) - - def setup_classes(self): - class A(_fixtures.Base): - pass - class B(_fixtures.Base): - pass - class C(_fixtures.Base): - pass - - @testing.resolve_artifact_names - def test_delete_orphan(self): - mapper(A, a, properties={ - # if no backref here, delete-orphan failed until [ticket:427] was - # fixed - 'bs': relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) - }) - mapper(B, b) - - sess = create_session() - b1 = B(data='b1') - a1 = A(data='a1', bs=[b1]) - sess.add(a1) - sess.flush() - - a1.bs.remove(b1) - sess.flush() - assert atob.count().scalar() ==0 - assert b.count().scalar() == 0 - assert a.count().scalar() == 1 - - @testing.resolve_artifact_names - def test_delete_orphan_cascades(self): - mapper(A, a, properties={ - # if no backref here, delete-orphan failed until [ticket:427] was - # fixed - 'bs':relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) - }) - mapper(B, b, properties={'cs':relation(C, cascade="all, delete-orphan")}) - mapper(C, c) - - sess = create_session() - b1 = B(data='b1', cs=[C(data='c1')]) - a1 = A(data='a1', bs=[b1]) - sess.add(a1) - sess.flush() - 
- a1.bs.remove(b1) - sess.flush() - assert atob.count().scalar() ==0 - assert b.count().scalar() == 0 - assert a.count().scalar() == 1 - assert c.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_cascade_delete(self): - mapper(A, a, properties={ - 'bs':relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) - }) - mapper(B, b) - - sess = create_session() - a1 = A(data='a1', bs=[B(data='b1')]) - sess.add(a1) - sess.flush() - - sess.delete(a1) - sess.flush() - assert atob.count().scalar() ==0 - assert b.count().scalar() == 0 - assert a.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_single_parent_raise(self): - mapper(A, a, properties={ - 'bs':relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) - }) - mapper(B, b) - - sess = create_session() - b1 =B(data='b1') - a1 = A(data='a1', bs=[b1]) - - self.assertRaises(sa_exc.InvalidRequestError, - A, data='a2', bs=[b1] - ) - - @testing.resolve_artifact_names - def test_single_parent_backref(self): - """test that setting m2m via a uselist=False backref bypasses the single_parent raise""" - - mapper(A, a, properties={ - 'bs':relation(B, - secondary=atob, - cascade="all, delete-orphan", single_parent=True, - backref=backref('a', uselist=False)) - }) - mapper(B, b) - - sess = create_session() - b1 =B(data='b1') - a1 = A(data='a1', bs=[b1]) - - self.assertRaises( - sa_exc.InvalidRequestError, - A, data='a2', bs=[b1] - ) - - a2 = A(data='a2') - b1.a = a2 - assert b1 not in a1.bs - assert b1 in a2.bs - -class UnsavedOrphansTest(_base.MappedTest): - """Pending entities that are orphans""" - - def define_tables(self, metadata): - Table('users', metadata, - Column('user_id', Integer, - Sequence('user_id_seq', optional=True), - primary_key=True), - Column('name', String(40))) - - Table('addresses', metadata, - Column('address_id', Integer, - Sequence('address_id_seq', optional=True), - primary_key=True), - Column('user_id', Integer, 
ForeignKey('users.user_id')), - Column('email_address', String(40))) - - def setup_classes(self): - class User(_fixtures.Base): - pass - class Address(_fixtures.Base): - pass - - @testing.resolve_artifact_names - def test_pending_standalone_orphan(self): - """An entity that never had a parent on a delete-orphan cascade can't be saved.""" - - mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relation(Address, cascade="all,delete-orphan", backref="user") - )) - s = create_session() - a = Address() - s.add(a) - try: - s.flush() - except orm_exc.FlushError, e: - pass - assert a.address_id is None, "Error: address should not be persistent" - - @testing.resolve_artifact_names - def test_pending_collection_expunge(self): - """Removing a pending item from a collection expunges it from the session.""" - - mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relation(Address, cascade="all,delete-orphan", backref="user") - )) - s = create_session() - - u = User() - s.add(u) - s.flush() - a = Address() - - u.addresses.append(a) - assert a in s - - u.addresses.remove(a) - assert a not in s - - s.delete(u) - s.flush() - - assert a.address_id is None, "Error: address should not be persistent" - - @testing.resolve_artifact_names - def test_nonorphans_ok(self): - mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relation(Address, cascade="all,delete", backref="user") - )) - s = create_session() - u = User(name='u1', addresses=[Address(email_address='ad1')]) - s.add(u) - a1 = u.addresses[0] - u.addresses.remove(a1) - assert a1 in s - s.flush() - s.expunge_all() - eq_(s.query(Address).all(), [Address(email_address='ad1')]) - - -class UnsavedOrphansTest2(_base.MappedTest): - """same test as UnsavedOrphans only three levels deep""" - - def define_tables(self, meta): - Table('orders', meta, - Column('id', Integer, Sequence('order_id_seq'), - primary_key=True), - Column('name', String(50))) - - 
Table('items', meta, - Column('id', Integer, Sequence('item_id_seq'), - primary_key=True), - Column('order_id', Integer, ForeignKey('orders.id'), - nullable=False), - Column('name', String(50))) - - Table('attributes', meta, - Column('id', Integer, Sequence('attribute_id_seq'), - primary_key=True), - Column('item_id', Integer, ForeignKey('items.id'), - nullable=False), - Column('name', String(50))) - - @testing.resolve_artifact_names - def test_pending_expunge(self): - class Order(_fixtures.Base): - pass - class Item(_fixtures.Base): - pass - class Attribute(_fixtures.Base): - pass - - mapper(Attribute, attributes) - mapper(Item, items, properties=dict( - attributes=relation(Attribute, cascade="all,delete-orphan", backref="item") - )) - mapper(Order, orders, properties=dict( - items=relation(Item, cascade="all,delete-orphan", backref="order") - )) - - s = create_session() - order = Order(name="order1") - s.add(order) - - attr = Attribute(name="attr1") - item = Item(name="item1", attributes=[attr]) - - order.items.append(item) - order.items.remove(item) - - assert item not in s - assert attr not in s - - s.flush() - assert orders.count().scalar() == 1 - assert items.count().scalar() == 0 - assert attributes.count().scalar() == 0 - -class UnsavedOrphansTest3(_base.MappedTest): - """test not expunging double parents""" - - def define_tables(self, meta): - Table('sales_reps', meta, - Column('sales_rep_id', Integer, - Sequence('sales_rep_id_seq'), - primary_key=True), - Column('name', String(50))) - Table('accounts', meta, - Column('account_id', Integer, - Sequence('account_id_seq'), - primary_key=True), - Column('balance', Integer)) - Table('customers', meta, - Column('customer_id', Integer, - Sequence('customer_id_seq'), - primary_key=True), - Column('name', String(50)), - Column('sales_rep_id', Integer, - ForeignKey('sales_reps.sales_rep_id')), - Column('account_id', Integer, - ForeignKey('accounts.account_id'))) - - @testing.resolve_artifact_names - def 
test_double_parent_expunge_o2m(self): - """test the delete-orphan uow event for multiple delete-orphan parent relations.""" - - class Customer(_fixtures.Base): - pass - class Account(_fixtures.Base): - pass - class SalesRep(_fixtures.Base): - pass - - mapper(Customer, customers) - mapper(Account, accounts, properties=dict( - customers=relation(Customer, - cascade="all,delete-orphan", - backref="account"))) - mapper(SalesRep, sales_reps, properties=dict( - customers=relation(Customer, - cascade="all,delete-orphan", - backref="sales_rep"))) - s = create_session() - - a = Account(balance=0) - sr = SalesRep(name="John") - s.add_all((a, sr)) - s.flush() - - c = Customer(name="Jane") - - a.customers.append(c) - sr.customers.append(c) - assert c in s - - a.customers.remove(c) - assert c in s, "Should not expunge customer yet, still has one parent" - - sr.customers.remove(c) - assert c not in s, "Should expunge customer when both parents are gone" - - @testing.resolve_artifact_names - def test_double_parent_expunge_o2o(self): - """test the delete-orphan uow event for multiple delete-orphan parent relations.""" - - class Customer(_fixtures.Base): - pass - class Account(_fixtures.Base): - pass - class SalesRep(_fixtures.Base): - pass - - mapper(Customer, customers) - mapper(Account, accounts, properties=dict( - customer=relation(Customer, - cascade="all,delete-orphan", - backref="account", uselist=False))) - mapper(SalesRep, sales_reps, properties=dict( - customer=relation(Customer, - cascade="all,delete-orphan", - backref="sales_rep", uselist=False))) - s = create_session() - - a = Account(balance=0) - sr = SalesRep(name="John") - s.add_all((a, sr)) - s.flush() - - c = Customer(name="Jane") - - a.customer = c - sr.customer = c - assert c in s - - a.customer = None - assert c in s, "Should not expunge customer yet, still has one parent" - - sr.customer = None - assert c not in s, "Should expunge customer when both parents are gone" - - - -class 
DoubleParentOrphanTest(_base.MappedTest): - """test orphan detection for an entity with two parent relations""" - - def define_tables(self, metadata): - Table('addresses', metadata, - Column('address_id', Integer, primary_key=True), - Column('street', String(30)), - ) - - Table('homes', metadata, - Column('home_id', Integer, primary_key=True, key="id"), - Column('description', String(30)), - Column('address_id', Integer, ForeignKey('addresses.address_id'), - nullable=False), - ) - - Table('businesses', metadata, - Column('business_id', Integer, primary_key=True, key="id"), - Column('description', String(30), key="description"), - Column('address_id', Integer, ForeignKey('addresses.address_id'), - nullable=False), - ) - - @testing.resolve_artifact_names - def test_non_orphan(self): - """test that an entity can have two parent delete-orphan cascades, and persists normally.""" - - class Address(_fixtures.Base): - pass - class Home(_fixtures.Base): - pass - class Business(_fixtures.Base): - pass - - mapper(Address, addresses) - mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) - mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) - - session = create_session() - h1 = Home(description='home1', address=Address(street='address1')) - b1 = Business(description='business1', address=Address(street='address2')) - session.add_all((h1,b1)) - session.flush() - session.expunge_all() - - eq_(session.query(Home).get(h1.id), Home(description='home1', address=Address(street='address1'))) - eq_(session.query(Business).get(b1.id), Business(description='business1', address=Address(street='address2'))) - - @testing.resolve_artifact_names - def test_orphan(self): - """test that an entity can have two parent delete-orphan cascades, and is detected as an orphan - when saved without a parent.""" - - class Address(_fixtures.Base): - pass - class Home(_fixtures.Base): 
- pass - class Business(_fixtures.Base): - pass - - mapper(Address, addresses) - mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) - mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) - - session = create_session() - a1 = Address() - session.add(a1) - try: - session.flush() - assert False - except orm_exc.FlushError, e: - assert True - -class CollectionAssignmentOrphanTest(_base.MappedTest): - def define_tables(self, metadata): - Table('table_a', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30))) - Table('table_b', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30)), - Column('a_id', Integer, ForeignKey('table_a.id'))) - - @testing.resolve_artifact_names - def test_basic(self): - class A(_fixtures.Base): - pass - class B(_fixtures.Base): - pass - - mapper(A, table_a, properties={ - 'bs':relation(B, cascade="all, delete-orphan") - }) - mapper(B, table_b) - - a1 = A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]) - - sess = create_session() - sess.add(a1) - sess.flush() - - sess.expunge_all() - - eq_(sess.query(A).get(a1.id), - A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) - - a1 = sess.query(A).get(a1.id) - assert not class_mapper(B)._is_orphan( - attributes.instance_state(a1.bs[0])) - a1.bs[0].foo='b2modified' - a1.bs[1].foo='b3modified' - sess.flush() - - sess.expunge_all() - eq_(sess.query(A).get(a1.id), - A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) - - -class PartialFlushTest(_base.MappedTest): - """test cascade behavior as it relates to object lists passed to flush(). 
- - """ - def define_tables(self, metadata): - Table("base", metadata, - Column("id", Integer, primary_key=True), - Column("descr", String(50)) - ) - - Table("noninh_child", metadata, - Column('id', Integer, primary_key=True), - Column('base_id', Integer, ForeignKey('base.id')) - ) - - Table("parent", metadata, - Column("id", Integer, ForeignKey("base.id"), primary_key=True) - ) - Table("inh_child", metadata, - Column("id", Integer, ForeignKey("base.id"), primary_key=True), - Column("parent_id", Integer, ForeignKey("parent.id")) - ) - - @testing.uses_deprecated() - @testing.resolve_artifact_names - def test_o2m_m2o(self): - class Base(_base.ComparableEntity): - pass - class Child(_base.ComparableEntity): - pass - - mapper(Base, base, properties={ - 'children':relation(Child, backref='parent') - }) - mapper(Child, noninh_child) - - sess = create_session() - - c1, c2 = Child(), Child() - b1 = Base(descr='b1', children=[c1, c2]) - sess.add(b1) - - assert c1 in sess.new - assert c2 in sess.new - sess.flush([b1]) - - # c1, c2 get cascaded into the session on o2m. - # not sure if this is how I like this - # to work but that's how it works for now. - assert c1 in sess and c1 not in sess.new - assert c2 in sess and c2 not in sess.new - assert b1 in sess and b1 not in sess.new - - sess = create_session() - c1, c2 = Child(), Child() - b1 = Base(descr='b1', children=[c1, c2]) - sess.add(b1) - sess.flush([c1]) - # m2o, otoh, doesn't cascade up the other way. - assert c1 in sess and c1 not in sess.new - assert c2 in sess and c2 in sess.new - assert b1 in sess and b1 in sess.new - - sess = create_session() - c1, c2 = Child(), Child() - b1 = Base(descr='b1', children=[c1, c2]) - sess.add(b1) - sess.flush([c1, c2]) - # m2o, otoh, doesn't cascade up the other way. 
- assert c1 in sess and c1 not in sess.new - assert c2 in sess and c2 not in sess.new - assert b1 in sess and b1 in sess.new - - @testing.uses_deprecated() - @testing.resolve_artifact_names - def test_circular_sort(self): - """test ticket 1306""" - - class Base(_base.ComparableEntity): - pass - class Parent(Base): - pass - class Child(Base): - pass - - mapper(Base,base) - - mapper(Child, inh_child, - inherits=Base, - properties={'parent': relation( - Parent, - backref='children', - primaryjoin=inh_child.c.parent_id == parent.c.id - )} - ) - - - mapper(Parent,parent, inherits=Base) - - sess = create_session() - p1 = Parent() - - c1, c2, c3 = Child(), Child(), Child() - p1.children = [c1, c2, c3] - sess.add(p1) - - sess.flush([c1]) - assert p1 in sess.new - assert c1 not in sess.new - assert c2 in sess.new - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/collection.py b/test/orm/collection.py deleted file mode 100644 index 23f643597..000000000 --- a/test/orm/collection.py +++ /dev/null @@ -1,1834 +0,0 @@ -import testenv; testenv.configure_for_tests() -import sys -from operator import and_ - -import sqlalchemy.orm.collections as collections -from sqlalchemy.orm.collections import collection - -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa import util, exc as sa_exc -from testlib.sa.orm import create_session, mapper, relation, \ - attributes -from orm import _base -from testlib.testing import eq_ - -class Canary(sa.orm.interfaces.AttributeExtension): - def __init__(self): - self.data = set() - self.added = set() - self.removed = set() - def append(self, obj, value, initiator): - assert value not in self.added - self.data.add(value) - self.added.add(value) - return value - def remove(self, obj, value, initiator): - assert value not in self.removed - self.data.remove(value) - self.removed.add(value) - def set(self, obj, value, oldvalue, initiator): - if isinstance(value, str): - 
value = CollectionsTest.entity_maker() - - if oldvalue is not None: - self.remove(obj, oldvalue, None) - self.append(obj, value, None) - return value - -class CollectionsTest(_base.ORMTest): - class Entity(object): - def __init__(self, a=None, b=None, c=None): - self.a = a - self.b = b - self.c = c - def __repr__(self): - return str((id(self), self.a, self.b, self.c)) - - def setUpAll(self): - attributes.register_class(self.Entity) - - def tearDownAll(self): - attributes.unregister_class(self.Entity) - _base.ORMTest.tearDownAll(self) - - _entity_id = 1 - - @classmethod - def entity_maker(cls): - cls._entity_id += 1 - return cls.Entity(cls._entity_id) - - @classmethod - def dictable_entity(cls, a=None, b=None, c=None): - id = cls._entity_id = (cls._entity_id + 1) - return cls.Entity(a or str(id), b or 'value %s' % id, c) - - def _test_adapter(self, typecallable, creator=None, to_set=None): - if creator is None: - creator = self.entity_maker - - class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - adapter = collections.collection_adapter(obj.attr) - direct = obj.attr - if to_set is None: - to_set = lambda col: set(col) - - def assert_eq(): - self.assert_(to_set(direct) == canary.data) - self.assert_(set(adapter) == canary.data) - assert_ne = lambda: self.assert_(to_set(direct) != canary.data) - - e1, e2 = creator(), creator() - - adapter.append_with_event(e1) - assert_eq() - - adapter.append_without_event(e2) - assert_ne() - canary.data.add(e2) - assert_eq() - - adapter.remove_without_event(e2) - assert_ne() - canary.data.remove(e2) - assert_eq() - - adapter.remove_with_event(e1) - assert_eq() - - def _test_list(self, typecallable, creator=None): - if creator is None: - creator = self.entity_maker - - class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - 
attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - adapter = collections.collection_adapter(obj.attr) - direct = obj.attr - control = list() - - def assert_eq(): - self.assert_(set(direct) == canary.data) - self.assert_(set(adapter) == canary.data) - self.assert_(direct == control) - - # assume append() is available for list tests - e = creator() - direct.append(e) - control.append(e) - assert_eq() - - if hasattr(direct, 'pop'): - direct.pop() - control.pop() - assert_eq() - - if hasattr(direct, '__setitem__'): - e = creator() - direct.append(e) - control.append(e) - - e = creator() - direct[0] = e - control[0] = e - assert_eq() - - if util.reduce(and_, [hasattr(direct, a) for a in - ('__delitem__', 'insert', '__len__')], True): - values = [creator(), creator(), creator(), creator()] - direct[slice(0,1)] = values - control[slice(0,1)] = values - assert_eq() - - values = [creator(), creator()] - direct[slice(0,-1,2)] = values - control[slice(0,-1,2)] = values - assert_eq() - - values = [creator()] - direct[slice(0,-1)] = values - control[slice(0,-1)] = values - assert_eq() - - if hasattr(direct, '__delitem__'): - e = creator() - direct.append(e) - control.append(e) - del direct[-1] - del control[-1] - assert_eq() - - if hasattr(direct, '__getslice__'): - for e in [creator(), creator(), creator(), creator()]: - direct.append(e) - control.append(e) - - del direct[:-3] - del control[:-3] - assert_eq() - - del direct[0:1] - del control[0:1] - assert_eq() - - del direct[::2] - del control[::2] - assert_eq() - - if hasattr(direct, 'remove'): - e = creator() - direct.append(e) - control.append(e) - - direct.remove(e) - control.remove(e) - assert_eq() - - if hasattr(direct, '__setslice__'): - values = [creator(), creator()] - direct[0:1] = values - control[0:1] = values - assert_eq() - - values = [creator()] - direct[0:] = values - control[0:] = values - assert_eq() - - values = 
[creator()] - direct[:1] = values - control[:1] = values - assert_eq() - - values = [creator()] - direct[-1::2] = values - control[-1::2] = values - assert_eq() - - values = [creator()] * len(direct[1::2]) - direct[1::2] = values - control[1::2] = values - assert_eq() - - if hasattr(direct, '__delslice__'): - for i in range(1, 4): - e = creator() - direct.append(e) - control.append(e) - - del direct[-1:] - del control[-1:] - assert_eq() - - del direct[1:2] - del control[1:2] - assert_eq() - - del direct[:] - del control[:] - assert_eq() - - if hasattr(direct, 'extend'): - values = [creator(), creator(), creator()] - - direct.extend(values) - control.extend(values) - assert_eq() - - if hasattr(direct, '__iadd__'): - values = [creator(), creator(), creator()] - - direct += values - control += values - assert_eq() - - direct += [] - control += [] - assert_eq() - - values = [creator(), creator()] - obj.attr += values - control += values - assert_eq() - - if hasattr(direct, '__imul__'): - direct *= 2 - control *= 2 - assert_eq() - - obj.attr *= 2 - control *= 2 - assert_eq() - - def _test_list_bulk(self, typecallable, creator=None): - if creator is None: - creator = self.entity_maker - - class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - direct = obj.attr - - e1 = creator() - obj.attr.append(e1) - - like_me = typecallable() - e2 = creator() - like_me.append(e2) - - self.assert_(obj.attr is direct) - obj.attr = like_me - self.assert_(obj.attr is not direct) - self.assert_(obj.attr is not like_me) - self.assert_(set(obj.attr) == set([e2])) - self.assert_(e1 in canary.removed) - self.assert_(e2 in canary.added) - - e3 = creator() - real_list = [e3] - obj.attr = real_list - self.assert_(obj.attr is not real_list) - self.assert_(set(obj.attr) == set([e3])) - self.assert_(e2 in canary.removed) - 
self.assert_(e3 in canary.added) - - e4 = creator() - try: - obj.attr = set([e4]) - self.assert_(False) - except TypeError: - self.assert_(e4 not in canary.data) - self.assert_(e3 in canary.data) - - e5 = creator() - e6 = creator() - e7 = creator() - obj.attr = [e5, e6, e7] - self.assert_(e5 in canary.added) - self.assert_(e6 in canary.added) - self.assert_(e7 in canary.added) - - obj.attr = [e6, e7] - self.assert_(e5 in canary.removed) - self.assert_(e6 in canary.added) - self.assert_(e7 in canary.added) - self.assert_(e6 not in canary.removed) - self.assert_(e7 not in canary.removed) - - def test_list(self): - self._test_adapter(list) - self._test_list(list) - self._test_list_bulk(list) - - def test_list_subclass(self): - class MyList(list): - pass - self._test_adapter(MyList) - self._test_list(MyList) - self._test_list_bulk(MyList) - self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList)) - - def test_list_duck(self): - class ListLike(object): - def __init__(self): - self.data = list() - def append(self, item): - self.data.append(item) - def remove(self, item): - self.data.remove(item) - def insert(self, index, item): - self.data.insert(index, item) - def pop(self, index=-1): - return self.data.pop(index) - def extend(self): - assert False - def __iter__(self): - return iter(self.data) - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - def __repr__(self): - return 'ListLike(%s)' % repr(self.data) - - self._test_adapter(ListLike) - self._test_list(ListLike) - self._test_list_bulk(ListLike) - self.assert_(getattr(ListLike, '_sa_instrumented') == id(ListLike)) - - def test_list_emulates(self): - class ListIsh(object): - __emulates__ = list - def __init__(self): - self.data = list() - def append(self, item): - self.data.append(item) - def remove(self, item): - self.data.remove(item) - def insert(self, index, item): - self.data.insert(index, item) - def pop(self, index=-1): - return self.data.pop(index) - def extend(self): - 
assert False - def __iter__(self): - return iter(self.data) - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - def __repr__(self): - return 'ListIsh(%s)' % repr(self.data) - - self._test_adapter(ListIsh) - self._test_list(ListIsh) - self._test_list_bulk(ListIsh) - self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh)) - - def _test_set(self, typecallable, creator=None): - if creator is None: - creator = self.entity_maker - - class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - adapter = collections.collection_adapter(obj.attr) - direct = obj.attr - control = set() - - def assert_eq(): - self.assert_(set(direct) == canary.data) - self.assert_(set(adapter) == canary.data) - self.assert_(direct == control) - - def addall(*values): - for item in values: - direct.add(item) - control.add(item) - assert_eq() - def zap(): - for item in list(direct): - direct.remove(item) - control.clear() - - addall(creator()) - - e = creator() - addall(e) - addall(e) - - if hasattr(direct, 'pop'): - direct.pop() - control.pop() - assert_eq() - - if hasattr(direct, 'remove'): - e = creator() - addall(e) - - direct.remove(e) - control.remove(e) - assert_eq() - - e = creator() - try: - direct.remove(e) - except KeyError: - assert_eq() - self.assert_(e not in canary.removed) - else: - self.assert_(False) - - if hasattr(direct, 'discard'): - e = creator() - addall(e) - - direct.discard(e) - control.discard(e) - assert_eq() - - e = creator() - direct.discard(e) - self.assert_(e not in canary.removed) - assert_eq() - - if hasattr(direct, 'update'): - zap() - e = creator() - addall(e) - - values = set([e, creator(), creator()]) - - direct.update(values) - control.update(values) - assert_eq() - - if hasattr(direct, '__ior__'): - zap() - e = creator() - addall(e) - - values = set([e, 
creator(), creator()]) - - direct |= values - control |= values - assert_eq() - - # cover self-assignment short-circuit - values = set([e, creator(), creator()]) - obj.attr |= values - control |= values - assert_eq() - - values = frozenset([e, creator()]) - obj.attr |= values - control |= values - assert_eq() - - try: - direct |= [e, creator()] - assert False - except TypeError: - assert True - - if hasattr(direct, 'clear'): - addall(creator(), creator()) - direct.clear() - control.clear() - assert_eq() - - if hasattr(direct, 'difference_update'): - zap() - e = creator() - addall(creator(), creator()) - values = set([creator()]) - - direct.difference_update(values) - control.difference_update(values) - assert_eq() - values.update(set([e, creator()])) - direct.difference_update(values) - control.difference_update(values) - assert_eq() - - if hasattr(direct, '__isub__'): - zap() - e = creator() - addall(creator(), creator()) - values = set([creator()]) - - direct -= values - control -= values - assert_eq() - values.update(set([e, creator()])) - direct -= values - control -= values - assert_eq() - - values = set([creator()]) - obj.attr -= values - control -= values - assert_eq() - - values = frozenset([creator()]) - obj.attr -= values - control -= values - assert_eq() - - try: - direct -= [e, creator()] - assert False - except TypeError: - assert True - - if hasattr(direct, 'intersection_update'): - zap() - e = creator() - addall(e, creator(), creator()) - values = set(control) - - direct.intersection_update(values) - control.intersection_update(values) - assert_eq() - - values.update(set([e, creator()])) - direct.intersection_update(values) - control.intersection_update(values) - assert_eq() - - if hasattr(direct, '__iand__'): - zap() - e = creator() - addall(e, creator(), creator()) - values = set(control) - - direct &= values - control &= values - assert_eq() - - values.update(set([e, creator()])) - direct &= values - control &= values - assert_eq() - - 
values.update(set([creator()])) - obj.attr &= values - control &= values - assert_eq() - - try: - direct &= [e, creator()] - assert False - except TypeError: - assert True - - if hasattr(direct, 'symmetric_difference_update'): - zap() - e = creator() - addall(e, creator(), creator()) - - values = set([e, creator()]) - direct.symmetric_difference_update(values) - control.symmetric_difference_update(values) - assert_eq() - - e = creator() - addall(e) - values = set([e]) - direct.symmetric_difference_update(values) - control.symmetric_difference_update(values) - assert_eq() - - values = set() - direct.symmetric_difference_update(values) - control.symmetric_difference_update(values) - assert_eq() - - if hasattr(direct, '__ixor__'): - zap() - e = creator() - addall(e, creator(), creator()) - - values = set([e, creator()]) - direct ^= values - control ^= values - assert_eq() - - e = creator() - addall(e) - values = set([e]) - direct ^= values - control ^= values - assert_eq() - - values = set() - direct ^= values - control ^= values - assert_eq() - - values = set([creator()]) - obj.attr ^= values - control ^= values - assert_eq() - - try: - direct ^= [e, creator()] - assert False - except TypeError: - assert True - - def _test_set_bulk(self, typecallable, creator=None): - if creator is None: - creator = self.entity_maker - - class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - direct = obj.attr - - e1 = creator() - obj.attr.add(e1) - - like_me = typecallable() - e2 = creator() - like_me.add(e2) - - self.assert_(obj.attr is direct) - obj.attr = like_me - self.assert_(obj.attr is not direct) - self.assert_(obj.attr is not like_me) - self.assert_(obj.attr == set([e2])) - self.assert_(e1 in canary.removed) - self.assert_(e2 in canary.added) - - e3 = creator() - real_set = set([e3]) - obj.attr = real_set - 
self.assert_(obj.attr is not real_set) - self.assert_(obj.attr == set([e3])) - self.assert_(e2 in canary.removed) - self.assert_(e3 in canary.added) - - e4 = creator() - try: - obj.attr = [e4] - self.assert_(False) - except TypeError: - self.assert_(e4 not in canary.data) - self.assert_(e3 in canary.data) - - def test_set(self): - self._test_adapter(set) - self._test_set(set) - self._test_set_bulk(set) - - def test_set_subclass(self): - class MySet(set): - pass - self._test_adapter(MySet) - self._test_set(MySet) - self._test_set_bulk(MySet) - self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet)) - - def test_set_duck(self): - class SetLike(object): - def __init__(self): - self.data = set() - def add(self, item): - self.data.add(item) - def remove(self, item): - self.data.remove(item) - def discard(self, item): - self.data.discard(item) - def pop(self): - return self.data.pop() - def update(self, other): - self.data.update(other) - def __iter__(self): - return iter(self.data) - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - - self._test_adapter(SetLike) - self._test_set(SetLike) - self._test_set_bulk(SetLike) - self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike)) - - def test_set_emulates(self): - class SetIsh(object): - __emulates__ = set - def __init__(self): - self.data = set() - def add(self, item): - self.data.add(item) - def remove(self, item): - self.data.remove(item) - def discard(self, item): - self.data.discard(item) - def pop(self): - return self.data.pop() - def update(self, other): - self.data.update(other) - def __iter__(self): - return iter(self.data) - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - - self._test_adapter(SetIsh) - self._test_set(SetIsh) - self._test_set_bulk(SetIsh) - self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh)) - - def _test_dict(self, typecallable, creator=None): - if creator is None: - creator = self.dictable_entity - - 
class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - adapter = collections.collection_adapter(obj.attr) - direct = obj.attr - control = dict() - - def assert_eq(): - self.assert_(set(direct.values()) == canary.data) - self.assert_(set(adapter) == canary.data) - self.assert_(direct == control) - - def addall(*values): - for item in values: - direct.set(item) - control[item.a] = item - assert_eq() - def zap(): - for item in list(adapter): - direct.remove(item) - control.clear() - - # assume an 'set' method is available for tests - addall(creator()) - - if hasattr(direct, '__setitem__'): - e = creator() - direct[e.a] = e - control[e.a] = e - assert_eq() - - e = creator(e.a, e.b) - direct[e.a] = e - control[e.a] = e - assert_eq() - - if hasattr(direct, '__delitem__'): - e = creator() - addall(e) - - del direct[e.a] - del control[e.a] - assert_eq() - - e = creator() - try: - del direct[e.a] - except KeyError: - self.assert_(e not in canary.removed) - - if hasattr(direct, 'clear'): - addall(creator(), creator(), creator()) - - direct.clear() - control.clear() - assert_eq() - - direct.clear() - control.clear() - assert_eq() - - if hasattr(direct, 'pop'): - e = creator() - addall(e) - - direct.pop(e.a) - control.pop(e.a) - assert_eq() - - e = creator() - try: - direct.pop(e.a) - except KeyError: - self.assert_(e not in canary.removed) - - if hasattr(direct, 'popitem'): - zap() - e = creator() - addall(e) - - direct.popitem() - control.popitem() - assert_eq() - - if hasattr(direct, 'setdefault'): - e = creator() - - val_a = direct.setdefault(e.a, e) - val_b = control.setdefault(e.a, e) - assert_eq() - self.assert_(val_a is val_b) - - val_a = direct.setdefault(e.a, e) - val_b = control.setdefault(e.a, e) - assert_eq() - self.assert_(val_a is val_b) - - if hasattr(direct, 'update'): - e = creator() - d = 
dict([(ee.a, ee) for ee in [e, creator(), creator()]]) - addall(e, creator()) - - direct.update(d) - control.update(d) - assert_eq() - - if sys.version_info >= (2, 4): - kw = dict([(ee.a, ee) for ee in [e, creator()]]) - direct.update(**kw) - control.update(**kw) - assert_eq() - - def _test_dict_bulk(self, typecallable, creator=None): - if creator is None: - creator = self.dictable_entity - - class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - direct = obj.attr - - e1 = creator() - collections.collection_adapter(direct).append_with_event(e1) - - like_me = typecallable() - e2 = creator() - like_me.set(e2) - - self.assert_(obj.attr is direct) - obj.attr = like_me - self.assert_(obj.attr is not direct) - self.assert_(obj.attr is not like_me) - self.assert_(set(collections.collection_adapter(obj.attr)) == set([e2])) - self.assert_(e1 in canary.removed) - self.assert_(e2 in canary.added) - - - # key validity on bulk assignment is a basic feature of MappedCollection - # but is not present in basic, @converter-less dict collections. 
- e3 = creator() - if isinstance(obj.attr, collections.MappedCollection): - real_dict = dict(badkey=e3) - try: - obj.attr = real_dict - self.assert_(False) - except TypeError: - pass - self.assert_(obj.attr is not real_dict) - self.assert_('badkey' not in obj.attr) - self.assertEquals(set(collections.collection_adapter(obj.attr)), - set([e2])) - self.assert_(e3 not in canary.added) - else: - real_dict = dict(keyignored1=e3) - obj.attr = real_dict - self.assert_(obj.attr is not real_dict) - self.assert_('keyignored1' not in obj.attr) - self.assertEquals(set(collections.collection_adapter(obj.attr)), - set([e3])) - self.assert_(e2 in canary.removed) - self.assert_(e3 in canary.added) - - obj.attr = typecallable() - self.assertEquals(list(collections.collection_adapter(obj.attr)), []) - - e4 = creator() - try: - obj.attr = [e4] - self.assert_(False) - except TypeError: - self.assert_(e4 not in canary.data) - - def test_dict(self): - try: - self._test_adapter(dict, self.dictable_entity, - to_set=lambda c: set(c.values())) - self.assert_(False) - except sa_exc.ArgumentError, e: - self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class') - - try: - self._test_dict(dict) - self.assert_(False) - except sa_exc.ArgumentError, e: - self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class') - - def test_dict_subclass(self): - class MyDict(dict): - @collection.appender - @collection.internally_instrumented - def set(self, item, _sa_initiator=None): - self.__setitem__(item.a, item, _sa_initiator=_sa_initiator) - @collection.remover - @collection.internally_instrumented - def _remove(self, item, _sa_initiator=None): - self.__delitem__(item.a, _sa_initiator=_sa_initiator) - - self._test_adapter(MyDict, self.dictable_entity, - to_set=lambda c: set(c.values())) - self._test_dict(MyDict) - self._test_dict_bulk(MyDict) - self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict)) - - def 
test_dict_subclass2(self): - class MyEasyDict(collections.MappedCollection): - def __init__(self): - super(MyEasyDict, self).__init__(lambda e: e.a) - - self._test_adapter(MyEasyDict, self.dictable_entity, - to_set=lambda c: set(c.values())) - self._test_dict(MyEasyDict) - self._test_dict_bulk(MyEasyDict) - self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict)) - - def test_dict_subclass3(self): - class MyOrdered(util.OrderedDict, collections.MappedCollection): - def __init__(self): - collections.MappedCollection.__init__(self, lambda e: e.a) - util.OrderedDict.__init__(self) - - self._test_adapter(MyOrdered, self.dictable_entity, - to_set=lambda c: set(c.values())) - self._test_dict(MyOrdered) - self._test_dict_bulk(MyOrdered) - self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered)) - - def test_dict_duck(self): - class DictLike(object): - def __init__(self): - self.data = dict() - - @collection.appender - @collection.replaces(1) - def set(self, item): - current = self.data.get(item.a, None) - self.data[item.a] = item - return current - @collection.remover - def _remove(self, item): - del self.data[item.a] - def __setitem__(self, key, value): - self.data[key] = value - def __getitem__(self, key): - return self.data[key] - def __delitem__(self, key): - del self.data[key] - def values(self): - return self.data.values() - def __contains__(self, key): - return key in self.data - @collection.iterator - def itervalues(self): - return self.data.itervalues() - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - def __repr__(self): - return 'DictLike(%s)' % repr(self.data) - - self._test_adapter(DictLike, self.dictable_entity, - to_set=lambda c: set(c.itervalues())) - self._test_dict(DictLike) - self._test_dict_bulk(DictLike) - self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike)) - - def test_dict_emulates(self): - class DictIsh(object): - __emulates__ = dict - def __init__(self): - self.data = 
dict() - - @collection.appender - @collection.replaces(1) - def set(self, item): - current = self.data.get(item.a, None) - self.data[item.a] = item - return current - @collection.remover - def _remove(self, item): - del self.data[item.a] - def __setitem__(self, key, value): - self.data[key] = value - def __getitem__(self, key): - return self.data[key] - def __delitem__(self, key): - del self.data[key] - def values(self): - return self.data.values() - def __contains__(self, key): - return key in self.data - @collection.iterator - def itervalues(self): - return self.data.itervalues() - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - def __repr__(self): - return 'DictIsh(%s)' % repr(self.data) - - self._test_adapter(DictIsh, self.dictable_entity, - to_set=lambda c: set(c.itervalues())) - self._test_dict(DictIsh) - self._test_dict_bulk(DictIsh) - self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh)) - - def _test_object(self, typecallable, creator=None): - if creator is None: - creator = self.entity_maker - - class Foo(object): - pass - - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=typecallable, useobject=True) - - obj = Foo() - adapter = collections.collection_adapter(obj.attr) - direct = obj.attr - control = set() - - def assert_eq(): - self.assert_(set(direct) == canary.data) - self.assert_(set(adapter) == canary.data) - self.assert_(direct == control) - - # There is no API for object collections. We'll make one up - # for the purposes of the test. 
- e = creator() - direct.push(e) - control.add(e) - assert_eq() - - direct.zark(e) - control.remove(e) - assert_eq() - - e = creator() - direct.maybe_zark(e) - control.discard(e) - assert_eq() - - e = creator() - direct.push(e) - control.add(e) - assert_eq() - - e = creator() - direct.maybe_zark(e) - control.discard(e) - assert_eq() - - def test_object_duck(self): - class MyCollection(object): - def __init__(self): - self.data = set() - @collection.appender - def push(self, item): - self.data.add(item) - @collection.remover - def zark(self, item): - self.data.remove(item) - @collection.removes_return() - def maybe_zark(self, item): - if item in self.data: - self.data.remove(item) - return item - @collection.iterator - def __iter__(self): - return iter(self.data) - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - - self._test_adapter(MyCollection) - self._test_object(MyCollection) - self.assert_(getattr(MyCollection, '_sa_instrumented') == - id(MyCollection)) - - def test_object_emulates(self): - class MyCollection2(object): - __emulates__ = None - def __init__(self): - self.data = set() - # looks like a list - def append(self, item): - assert False - @collection.appender - def push(self, item): - self.data.add(item) - @collection.remover - def zark(self, item): - self.data.remove(item) - @collection.removes_return() - def maybe_zark(self, item): - if item in self.data: - self.data.remove(item) - return item - @collection.iterator - def __iter__(self): - return iter(self.data) - __hash__ = object.__hash__ - def __eq__(self, other): - return self.data == other - - self._test_adapter(MyCollection2) - self._test_object(MyCollection2) - self.assert_(getattr(MyCollection2, '_sa_instrumented') == - id(MyCollection2)) - - def test_recipes(self): - class Custom(object): - def __init__(self): - self.data = [] - @collection.appender - @collection.adds('entity') - def put(self, entity): - self.data.append(entity) - - @collection.remover - 
@collection.removes(1) - def remove(self, entity): - self.data.remove(entity) - - @collection.adds(1) - def push(self, *args): - self.data.append(args[0]) - - @collection.removes('entity') - def yank(self, entity, arg): - self.data.remove(entity) - - @collection.replaces(2) - def replace(self, arg, entity, **kw): - self.data.insert(0, entity) - return self.data.pop() - - @collection.removes_return() - def pop(self, key): - return self.data.pop() - - @collection.iterator - def __iter__(self): - return iter(self.data) - - class Foo(object): - pass - canary = Canary() - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, - typecallable=Custom, useobject=True) - - obj = Foo() - adapter = collections.collection_adapter(obj.attr) - direct = obj.attr - control = list() - def assert_eq(): - self.assert_(set(direct) == canary.data) - self.assert_(set(adapter) == canary.data) - self.assert_(list(direct) == control) - creator = self.entity_maker - - e1 = creator() - direct.put(e1) - control.append(e1) - assert_eq() - - e2 = creator() - direct.put(entity=e2) - control.append(e2) - assert_eq() - - direct.remove(e2) - control.remove(e2) - assert_eq() - - direct.remove(entity=e1) - control.remove(e1) - assert_eq() - - e3 = creator() - direct.push(e3) - control.append(e3) - assert_eq() - - direct.yank(e3, 'blah') - control.remove(e3) - assert_eq() - - e4, e5, e6, e7 = creator(), creator(), creator(), creator() - direct.put(e4) - direct.put(e5) - control.append(e4) - control.append(e5) - - dr1 = direct.replace('foo', e6, bar='baz') - control.insert(0, e6) - cr1 = control.pop() - assert_eq() - self.assert_(dr1 is cr1) - - dr2 = direct.replace(arg=1, entity=e7) - control.insert(0, e7) - cr2 = control.pop() - assert_eq() - self.assert_(dr2 is cr2) - - dr3 = direct.pop('blah') - cr3 = control.pop() - assert_eq() - self.assert_(dr3 is cr3) - - def test_lifecycle(self): - class Foo(object): - pass - - canary = Canary() - creator = 
self.entity_maker - attributes.register_class(Foo) - attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, useobject=True) - - obj = Foo() - col1 = obj.attr - - e1 = creator() - obj.attr.append(e1) - - e2 = creator() - bulk1 = [e2] - # empty & sever col1 from obj - obj.attr = bulk1 - self.assert_(len(col1) == 0) - self.assert_(len(canary.data) == 1) - self.assert_(obj.attr is not col1) - self.assert_(obj.attr is not bulk1) - self.assert_(obj.attr == bulk1) - - e3 = creator() - col1.append(e3) - self.assert_(e3 not in canary.data) - self.assert_(collections.collection_adapter(col1) is None) - - obj.attr[0] = e3 - self.assert_(e3 in canary.data) - -class DictHelpersTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('parents', metadata, - Column('id', Integer, primary_key=True), - Column('label', String(128))) - Table('children', metadata, - Column('id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('parents.id'), - nullable=False), - Column('a', String(128)), - Column('b', String(128)), - Column('c', String(128))) - - def setup_classes(self): - class Parent(_base.BasicEntity): - def __init__(self, label=None): - self.label = label - - class Child(_base.BasicEntity): - def __init__(self, a=None, b=None, c=None): - self.a = a - self.b = b - self.c = c - - @testing.resolve_artifact_names - def _test_scalar_mapped(self, collection_class): - mapper(Child, children) - mapper(Parent, parents, properties={ - 'children': relation(Child, collection_class=collection_class, - cascade="all, delete-orphan")}) - - p = Parent() - p.children['foo'] = Child('foo', 'value') - p.children['bar'] = Child('bar', 'value') - session = create_session() - session.add(p) - session.flush() - pid = p.id - session.expunge_all() - - p = session.query(Parent).get(pid) - - - self.assertEquals(set(p.children.keys()), set(['foo', 'bar'])) - cid = p.children['foo'].id - - collections.collection_adapter(p.children).append_with_event( - 
Child('foo', 'newvalue')) - - session.flush() - session.expunge_all() - - p = session.query(Parent).get(pid) - - self.assert_(set(p.children.keys()) == set(['foo', 'bar'])) - self.assert_(p.children['foo'].id != cid) - - self.assert_(len(list(collections.collection_adapter(p.children))) == 2) - session.flush() - session.expunge_all() - - p = session.query(Parent).get(pid) - self.assert_(len(list(collections.collection_adapter(p.children))) == 2) - - collections.collection_adapter(p.children).remove_with_event( - p.children['foo']) - - self.assert_(len(list(collections.collection_adapter(p.children))) == 1) - session.flush() - session.expunge_all() - - p = session.query(Parent).get(pid) - self.assert_(len(list(collections.collection_adapter(p.children))) == 1) - - del p.children['bar'] - self.assert_(len(list(collections.collection_adapter(p.children))) == 0) - session.flush() - session.expunge_all() - - p = session.query(Parent).get(pid) - self.assert_(len(list(collections.collection_adapter(p.children))) == 0) - - - @testing.resolve_artifact_names - def _test_composite_mapped(self, collection_class): - mapper(Child, children) - mapper(Parent, parents, properties={ - 'children': relation(Child, collection_class=collection_class, - cascade="all, delete-orphan") - }) - - p = Parent() - p.children[('foo', '1')] = Child('foo', '1', 'value 1') - p.children[('foo', '2')] = Child('foo', '2', 'value 2') - - session = create_session() - session.add(p) - session.flush() - pid = p.id - session.expunge_all() - - p = session.query(Parent).get(pid) - - self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) - cid = p.children[('foo', '1')].id - - collections.collection_adapter(p.children).append_with_event( - Child('foo', '1', 'newvalue')) - - session.flush() - session.expunge_all() - - p = session.query(Parent).get(pid) - - self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) - self.assert_(p.children[('foo', '1')].id != cid) - - 
self.assert_(len(list(collections.collection_adapter(p.children))) == 2) - - def test_mapped_collection(self): - collection_class = collections.mapped_collection(lambda c: c.a) - self._test_scalar_mapped(collection_class) - - def test_mapped_collection2(self): - collection_class = collections.mapped_collection(lambda c: (c.a, c.b)) - self._test_composite_mapped(collection_class) - - def test_attr_mapped_collection(self): - collection_class = collections.attribute_mapped_collection('a') - self._test_scalar_mapped(collection_class) - - def test_declarative_column_mapped(self): - """test that uncompiled attribute usage works with column_mapped_collection""" - - from sqlalchemy.ext.declarative import declarative_base - - BaseObject = declarative_base() - - class Foo(BaseObject): - __tablename__ = "foo" - id = Column(Integer(), primary_key=True) - bar_id = Column(Integer, ForeignKey('bar.id')) - - class Bar(BaseObject): - __tablename__ = "bar" - id = Column(Integer(), primary_key=True) - foos = relation(Foo, collection_class=collections.column_mapped_collection(Foo.id)) - foos2 = relation(Foo, collection_class=collections.column_mapped_collection((Foo.id, Foo.bar_id))) - - eq_(Bar.foos.property.collection_class().keyfunc(Foo(id=3)), 3) - eq_(Bar.foos2.property.collection_class().keyfunc(Foo(id=3, bar_id=12)), (3, 12)) - - @testing.resolve_artifact_names - def test_column_mapped_collection(self): - collection_class = collections.column_mapped_collection( - children.c.a) - self._test_scalar_mapped(collection_class) - - @testing.resolve_artifact_names - def test_column_mapped_collection2(self): - collection_class = collections.column_mapped_collection( - (children.c.a, children.c.b)) - self._test_composite_mapped(collection_class) - - def test_mixin(self): - class Ordered(util.OrderedDict, collections.MappedCollection): - def __init__(self): - collections.MappedCollection.__init__(self, lambda v: v.a) - util.OrderedDict.__init__(self) - collection_class = Ordered - 
self._test_scalar_mapped(collection_class) - - def test_mixin2(self): - class Ordered2(util.OrderedDict, collections.MappedCollection): - def __init__(self, keyfunc): - collections.MappedCollection.__init__(self, keyfunc) - util.OrderedDict.__init__(self) - collection_class = lambda: Ordered2(lambda v: (v.a, v.b)) - self._test_composite_mapped(collection_class) - -# TODO: are these tests redundant vs. the above tests ? -# remove if so -class CustomCollectionsTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('sometable', metadata, - Column('col1',Integer, primary_key=True), - Column('data', String(30))) - Table('someothertable', metadata, - Column('col1', Integer, primary_key=True), - Column('scol1', Integer, - ForeignKey('sometable.col1')), - Column('data', String(20))) - - @testing.resolve_artifact_names - def test_basic(self): - class MyList(list): - pass - class Foo(object): - pass - class Bar(object): - pass - - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, collection_class=MyList) - }) - mapper(Bar, someothertable) - f = Foo() - assert isinstance(f.bars, MyList) - - @testing.resolve_artifact_names - def test_lazyload(self): - """test that a 'set' can be used as a collection and can lazyload.""" - class Foo(object): - pass - class Bar(object): - pass - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, collection_class=set) - }) - mapper(Bar, someothertable) - f = Foo() - f.bars.add(Bar()) - f.bars.add(Bar()) - sess = create_session() - sess.add(f) - sess.flush() - sess.expunge_all() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - f.bars.clear() - - @testing.resolve_artifact_names - def test_dict(self): - """test that a 'dict' can be used as a collection and can lazyload.""" - - class Foo(object): - pass - class Bar(object): - pass - class AppenderDict(dict): - @collection.appender - def set(self, item): - self[id(item)] = item - @collection.remover - def remove(self, item): - if id(item) in self: - 
del self[id(item)] - - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, collection_class=AppenderDict) - }) - mapper(Bar, someothertable) - f = Foo() - f.bars.set(Bar()) - f.bars.set(Bar()) - sess = create_session() - sess.add(f) - sess.flush() - sess.expunge_all() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - f.bars.clear() - - @testing.resolve_artifact_names - def test_dict_wrapper(self): - """test that the supplied 'dict' wrapper can be used as a collection and can lazyload.""" - - class Foo(object): - pass - class Bar(object): - def __init__(self, data): self.data = data - - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, - collection_class=collections.column_mapped_collection( - someothertable.c.data)) - }) - mapper(Bar, someothertable) - - f = Foo() - col = collections.collection_adapter(f.bars) - col.append_with_event(Bar('a')) - col.append_with_event(Bar('b')) - sess = create_session() - sess.add(f) - sess.flush() - sess.expunge_all() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - - existing = set([id(b) for b in f.bars.values()]) - - col = collections.collection_adapter(f.bars) - col.append_with_event(Bar('b')) - f.bars['a'] = Bar('a') - sess.flush() - sess.expunge_all() - f = sess.query(Foo).get(f.col1) - assert len(list(f.bars)) == 2 - - replaced = set([id(b) for b in f.bars.values()]) - self.assert_(existing != replaced) - - @testing.resolve_artifact_names - def test_list(self): - class Parent(object): - pass - class Child(object): - pass - - mapper(Parent, sometable, properties={ - 'children':relation(Child, collection_class=list) - }) - mapper(Child, someothertable) - - control = list() - p = Parent() - - o = Child() - control.append(o) - p.children.append(o) - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control.extend(o) - p.children.extend(o) - assert control == p.children - assert control == list(p.children) - - 
assert control[0] == p.children[0] - assert control[-1] == p.children[-1] - assert control[1:3] == p.children[1:3] - - del control[1] - del p.children[1] - assert control == p.children - assert control == list(p.children) - - o = [Child()] - control[1:3] = o - - p.children[1:3] = o - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control[1:3] = o - p.children[1:3] = o - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control[-1:-2] = o - p.children[-1:-2] = o - assert control == p.children - assert control == list(p.children) - - o = [Child(), Child(), Child(), Child()] - control[4:] = o - p.children[4:] = o - assert control == p.children - assert control == list(p.children) - - o = Child() - control.insert(0, o) - p.children.insert(0, o) - assert control == p.children - assert control == list(p.children) - - o = Child() - control.insert(3, o) - p.children.insert(3, o) - assert control == p.children - assert control == list(p.children) - - o = Child() - control.insert(999, o) - p.children.insert(999, o) - assert control == p.children - assert control == list(p.children) - - del control[0:1] - del p.children[0:1] - assert control == p.children - assert control == list(p.children) - - del control[1:1] - del p.children[1:1] - assert control == p.children - assert control == list(p.children) - - del control[1:3] - del p.children[1:3] - assert control == p.children - assert control == list(p.children) - - del control[7:] - del p.children[7:] - assert control == p.children - assert control == list(p.children) - - assert control.pop() == p.children.pop() - assert control == p.children - assert control == list(p.children) - - assert control.pop(0) == p.children.pop(0) - assert control == p.children - assert control == list(p.children) - - assert control.pop(2) == p.children.pop(2) - assert control == p.children - assert control == 
list(p.children) - - o = Child() - control.insert(2, o) - p.children.insert(2, o) - assert control == p.children - assert control == list(p.children) - - control.remove(o) - p.children.remove(o) - assert control == p.children - assert control == list(p.children) - - @testing.resolve_artifact_names - def test_custom(self): - class Parent(object): - pass - class Child(object): - pass - - class MyCollection(object): - def __init__(self): - self.data = [] - @collection.appender - def append(self, value): - self.data.append(value) - @collection.remover - def remove(self, value): - self.data.remove(value) - @collection.iterator - def __iter__(self): - return iter(self.data) - - mapper(Parent, sometable, properties={ - 'children':relation(Child, collection_class=MyCollection) - }) - mapper(Child, someothertable) - - control = list() - p1 = Parent() - - o = Child() - control.append(o) - p1.children.append(o) - assert control == list(p1.children) - - o = Child() - control.append(o) - p1.children.append(o) - assert control == list(p1.children) - - o = Child() - control.append(o) - p1.children.append(o) - assert control == list(p1.children) - - sess = create_session() - sess.add(p1) - sess.flush() - sess.expunge_all() - - p2 = sess.query(Parent).get(p1.col1) - o = list(p2.children) - assert len(o) == 3 - - -class InstrumentationTest(_base.ORMTest): - def test_uncooperative_descriptor_in_sweep(self): - class DoNotTouch(object): - def __get__(self, obj, owner): - raise AttributeError - - class Touchy(list): - no_touch = DoNotTouch() - - assert 'no_touch' in Touchy.__dict__ - assert not hasattr(Touchy, 'no_touch') - assert 'no_touch' in dir(Touchy) - - instrumented = collections._instrument_class(Touchy) - assert True - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/compile.py b/test/orm/compile.py deleted file mode 100644 index 7c9bed4ec..000000000 --- a/test/orm/compile.py +++ /dev/null @@ -1,186 +0,0 @@ -import testenv; testenv.configure_for_tests() -from 
sqlalchemy import * -from sqlalchemy import exc as sa_exc -from sqlalchemy.orm import * -from testlib import * -from orm import _base - - -class CompileTest(_base.ORMTest): - """test various mapper compilation scenarios""" - - def tearDown(self): - clear_mappers() - - def testone(self): - metadata = MetaData(testing.db) - - order = Table('orders', metadata, - Column('id', Integer, primary_key=True), - Column('employee_id', Integer, ForeignKey('employees.id'), nullable=False), - Column('type', Unicode(16))) - - employee = Table('employees', metadata, - Column('id', Integer, primary_key=True), - Column('name', Unicode(16), unique=True, nullable=False)) - - product = Table('products', metadata, - Column('id', Integer, primary_key=True), - ) - - orderproduct = Table('orderproducts', metadata, - Column('id', Integer, primary_key=True), - Column('order_id', Integer, ForeignKey("orders.id"), nullable=False), - Column('product_id', Integer, ForeignKey("products.id"), nullable=False), - ) - - class Order(object): - pass - - class Employee(object): - pass - - class Product(object): - pass - - class OrderProduct(object): - pass - - order_join = order.select().alias('pjoin') - - order_mapper = mapper(Order, order, - with_polymorphic=('*', order_join), - polymorphic_on=order_join.c.type, - polymorphic_identity='order', - properties={ - 'orderproducts': relation(OrderProduct, lazy=True, backref='order')} - ) - - mapper(Product, product, - properties={ - 'orderproducts': relation(OrderProduct, lazy=True, backref='product')} - ) - - mapper(Employee, employee, - properties={ - 'orders': relation(Order, lazy=True, backref='employee')}) - - mapper(OrderProduct, orderproduct) - - # this requires that the compilation of order_mapper's "surrogate - # mapper" occur after the initial setup of MapperProperty objects on - # the mapper. 
- class_mapper(Product).compile() - - def testtwo(self): - """test that conflicting backrefs raises an exception""" - metadata = MetaData(testing.db) - - order = Table('orders', metadata, - Column('id', Integer, primary_key=True), - Column('type', Unicode(16))) - - product = Table('products', metadata, - Column('id', Integer, primary_key=True), - ) - - orderproduct = Table('orderproducts', metadata, - Column('id', Integer, primary_key=True), - Column('order_id', Integer, ForeignKey("orders.id"), nullable=False), - Column('product_id', Integer, ForeignKey("products.id"), nullable=False), - ) - - class Order(object): - pass - - class Product(object): - pass - - class OrderProduct(object): - pass - - order_join = order.select().alias('pjoin') - - order_mapper = mapper(Order, order, - with_polymorphic=('*', order_join), - polymorphic_on=order_join.c.type, - polymorphic_identity='order', - properties={ - 'orderproducts': relation(OrderProduct, lazy=True, backref='product')} - ) - - mapper(Product, product, - properties={ - 'orderproducts': relation(OrderProduct, lazy=True, backref='product')} - ) - - mapper(OrderProduct, orderproduct) - - try: - class_mapper(Product).compile() - assert False - except sa_exc.ArgumentError, e: - assert str(e).index("Error creating backref ") > -1 - - def testthree(self): - metadata = MetaData(testing.db) - node_table = Table("node", metadata, - Column('node_id', Integer, primary_key=True), - Column('name_index', Integer, nullable=True), - ) - node_name_table = Table("node_name", metadata, - Column('node_name_id', Integer, primary_key=True), - Column('node_id', Integer, ForeignKey('node.node_id')), - Column('host_id', Integer, ForeignKey('host.host_id')), - Column('name', String(64), nullable=False), - ) - host_table = Table("host", metadata, - Column('host_id', Integer, primary_key=True), - Column('hostname', String(64), nullable=False, - unique=True), - ) - metadata.create_all() - try: - node_table.insert().execute(node_id=1, 
node_index=5) - class Node(object):pass - class NodeName(object):pass - class Host(object):pass - - node_mapper = mapper(Node, node_table) - host_mapper = mapper(Host, host_table) - node_name_mapper = mapper(NodeName, node_name_table, - properties = { - 'node' : relation(Node, backref=backref('names')), - 'host' : relation(Host), - } - ) - sess = create_session() - assert sess.query(Node).get(1).names == [] - finally: - metadata.drop_all() - - def testfour(self): - meta = MetaData() - - a = Table('a', meta, Column('id', Integer, primary_key=True)) - b = Table('b', meta, Column('id', Integer, primary_key=True), Column('a_id', Integer, ForeignKey('a.id'))) - - class A(object):pass - class B(object):pass - - mapper(A, a, properties={ - 'b':relation(B, backref='a') - }) - mapper(B, b, properties={ - 'a':relation(A, backref='b') - }) - - try: - compile_mappers() - assert False - except sa_exc.ArgumentError, e: - assert str(e).index("Error creating backref") > -1 - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/cycles.py b/test/orm/cycles.py deleted file mode 100644 index 3e3636085..000000000 --- a/test/orm/cycles.py +++ /dev/null @@ -1,862 +0,0 @@ -"""Tests cyclical mapper relationships. - -We might want to try an automated generate of much of this, all combos of -T1<->T2, with o2m or m2o between them, and a third T3 with o2m/m2o to one/both -T1/T2. 
- -""" -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, backref, create_session -from testlib.testing import eq_ -from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf -from orm import _base - - -class SelfReferentialTest(_base.MappedTest): - """A self-referential mapper with an additional list of child objects.""" - - def define_tables(self, metadata): - Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_c1', Integer, ForeignKey('t1.c1')), - Column('data', String(20))) - Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c1id', Integer, ForeignKey('t1.c1')), - Column('data', String(20))) - - def setup_classes(self): - class C1(_base.BasicEntity): - def __init__(self, data=None): - self.data = data - - class C2(_base.BasicEntity): - def __init__(self, data=None): - self.data = data - - @testing.resolve_artifact_names - def testsingle(self): - mapper(C1, t1, properties = { - 'c1s':relation(C1, cascade="all"), - 'parent':relation(C1, - primaryjoin=t1.c.parent_c1 == t1.c.c1, - remote_side=t1.c.c1, - lazy=True, - uselist=False)}) - a = C1('head c1') - a.c1s.append(C1('another c1')) - - sess = create_session( ) - sess.add(a) - sess.flush() - sess.delete(a) - sess.flush() - - @testing.resolve_artifact_names - def testmanytooneonly(self): - """ - - test that the circular dependency sort can assemble a many-to-one - dependency processor when only the object on the "many" side is - actually in the list of modified objects. this requires that the - circular sort add the other side of the relation into the - UOWTransaction so that the dependency operation can be tacked onto it. - - This also affects inheritance relationships since they rely upon - circular sort as well. 
- - """ - mapper(C1, t1, properties={ - 'parent':relation(C1, - primaryjoin=t1.c.parent_c1 == t1.c.c1, - remote_side=t1.c.c1)}) - - c1 = C1() - - sess = create_session() - sess.add(c1) - sess.flush() - sess.expunge_all() - c1 = sess.query(C1).get(c1.c1) - c2 = C1() - c2.parent = c1 - sess.add(c2) - sess.flush() - assert c2.parent_c1==c1.c1 - - @testing.resolve_artifact_names - def testcycle(self): - mapper(C1, t1, properties = { - 'c1s' : relation(C1, cascade="all"), - 'c2s' : relation(mapper(C2, t2), cascade="all, delete-orphan")}) - - a = C1('head c1') - a.c1s.append(C1('child1')) - a.c1s.append(C1('child2')) - a.c1s[0].c1s.append(C1('subchild1')) - a.c1s[0].c1s.append(C1('subchild2')) - a.c1s[1].c2s.append(C2('child2 data1')) - a.c1s[1].c2s.append(C2('child2 data2')) - sess = create_session( ) - sess.add(a) - sess.flush() - - sess.delete(a) - sess.flush() - - @testing.resolve_artifact_names - def test_setnull_ondelete(self): - mapper(C1, t1, properties={ - 'children':relation(C1) - }) - - sess = create_session() - c1 = C1() - c2 = C1() - c1.children.append(c2) - sess.add(c1) - sess.flush() - assert c2.parent_c1 == c1.c1 - - sess.delete(c1) - sess.flush() - assert c2.parent_c1 is None - - sess.expire_all() - assert c2.parent_c1 is None - -class SelfReferentialNoPKTest(_base.MappedTest): - """A self-referential relationship that joins on a column other than the primary key column""" - - def define_tables(self, metadata): - Table('item', metadata, - Column('id', Integer, primary_key=True), - Column('uuid', String(32), unique=True, nullable=False), - Column('parent_uuid', String(32), ForeignKey('item.uuid'), - nullable=True)) - - def setup_classes(self): - class TT(_base.BasicEntity): - def __init__(self): - self.uuid = hex(id(self)) - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(TT, item, properties={ - 'children': relation( - TT, - remote_side=[item.c.parent_uuid], - backref=backref('parent', remote_side=[item.c.uuid]))}) - - 
@testing.resolve_artifact_names - def testbasic(self): - t1 = TT() - t1.children.append(TT()) - t1.children.append(TT()) - - s = create_session() - s.add(t1) - s.flush() - s.expunge_all() - t = s.query(TT).filter_by(id=t1.id).one() - eq_(t.children[0].parent_uuid, t1.uuid) - - @testing.resolve_artifact_names - def testlazyclause(self): - s = create_session() - t1 = TT() - t2 = TT() - t1.children.append(t2) - s.add(t1) - s.flush() - s.expunge_all() - - t = s.query(TT).filter_by(id=t2.id).one() - eq_(t.uuid, t2.uuid) - eq_(t.parent.uuid, t1.uuid) - - -class InheritTestOne(_base.MappedTest): - def define_tables(self, metadata): - Table("parent", metadata, - Column("id", Integer, primary_key=True), - Column("parent_data", String(50)), - Column("type", String(10))) - - Table("child1", metadata, - Column("id", Integer, ForeignKey("parent.id"), - primary_key=True), - Column("child1_data", String(50))) - - Table("child2", metadata, - Column("id", Integer, ForeignKey("parent.id"), - primary_key=True), - Column("child1_id", Integer, ForeignKey("child1.id"), - nullable=False), - Column("child2_data", String(50))) - - def setup_classes(self): - class Parent(_base.BasicEntity): - pass - - class Child1(Parent): - pass - - class Child2(Parent): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Parent, parent) - mapper(Child1, child1, inherits=Parent) - mapper(Child2, child2, inherits=Parent, properties=dict( - child1=relation(Child1, - primaryjoin=child2.c.child1_id == child1.c.id))) - - @testing.resolve_artifact_names - def testmanytooneonly(self): - """test similar to SelfReferentialTest.testmanytooneonly""" - - session = create_session() - - c1 = Child1() - c1.child1_data = "qwerty" - session.add(c1) - session.flush() - session.expunge_all() - - c1 = session.query(Child1).filter_by(child1_data="qwerty").one() - c2 = Child2() - c2.child1 = c1 - c2.child2_data = "asdfgh" - session.add(c2) - - # the flush will fail if the UOW does not set up a 
many-to-one DP - # attached to a task corresponding to c1, since "child1_id" is not - # nullable - session.flush() - - -class InheritTestTwo(_base.MappedTest): - """ - - The fix in BiDirectionalManyToOneTest raised this issue, regarding the - 'circular sort' containing UOWTasks that were still polymorphic, which - could create duplicate entries in the final sort - - """ - - def define_tables(self, metadata): - Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('cid', Integer, ForeignKey('c.id'))) - - Table('b', metadata, - Column('id', Integer, ForeignKey("a.id"), primary_key=True), - Column('data', String(30))) - - Table('c', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('aid', Integer, - ForeignKey('a.id', use_alter=True, name="foo"))) - - def setup_classes(self): - class A(_base.BasicEntity): - pass - - class B(A): - pass - - class C(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_flush(self): - mapper(A, a, properties={ - 'cs':relation(C, primaryjoin=a.c.cid==c.c.id)}) - - mapper(B, b, inherits=A, inherit_condition=b.c.id == a.c.id) - - mapper(C, c, properties={ - 'arel':relation(A, primaryjoin=a.c.id == c.c.aid)}) - - sess = create_session() - bobj = B() - sess.add(bobj) - cobj = C() - sess.add(cobj) - sess.flush() - - -class BiDirectionalManyToOneTest(_base.MappedTest): - run_define_tables = 'each' - - def define_tables(self, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('t2id', Integer, ForeignKey('t2.id'))) - Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('t1id', Integer, - ForeignKey('t1.id', use_alter=True, name="foo_fk"))) - Table('t3', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), - Column('t2id', Integer, 
ForeignKey('t2.id'), nullable=False)) - - def setup_classes(self): - class T1(_base.BasicEntity): - pass - class T2(_base.BasicEntity): - pass - class T3(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(T1, t1, properties={ - 't2':relation(T2, primaryjoin=t1.c.t2id == t2.c.id)}) - mapper(T2, t2, properties={ - 't1':relation(T1, primaryjoin=t2.c.t1id == t1.c.id)}) - mapper(T3, t3, properties={ - 't1':relation(T1), - 't2':relation(T2)}) - - @testing.resolve_artifact_names - def test_reflush(self): - o1 = T1() - o1.t2 = T2() - sess = create_session() - sess.add(o1) - sess.flush() - - # the bug here is that the dependency sort comes up with T1/T2 in a - # cycle, but there are no T1/T2 objects to be saved. therefore no - # "cyclical subtree" gets generated, and one or the other of T1/T2 - # gets lost, and processors on T3 dont fire off. the test will then - # fail because the FK's on T3 are not nullable. - o3 = T3() - o3.t1 = o1 - o3.t2 = o1.t2 - sess.add(o3) - sess.flush() - - - @testing.resolve_artifact_names - def test_reflush_2(self): - """A variant on test_reflush()""" - o1 = T1() - o1.t2 = T2() - sess = create_session() - sess.add(o1) - sess.flush() - - # in this case, T1, T2, and T3 tasks will all be in the cyclical - # tree normally. the dependency processors for T3 are part of the - # 'extradeps' collection so they all get assembled into the tree - # as well. 
- o1a = T1() - o2a = T2() - sess.add(o1a) - sess.add(o2a) - o3b = T3() - o3b.t1 = o1a - o3b.t2 = o2a - sess.add(o3b) - - o3 = T3() - o3.t1 = o1 - o3.t2 = o1.t2 - sess.add(o3) - sess.flush() - - -class BiDirectionalOneToManyTest(_base.MappedTest): - """tests two mappers with a one-to-many relation to each other.""" - - run_define_tables = 'each' - - def define_tables(self, metadata): - Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', Integer, ForeignKey('t2.c1'))) - - Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('c2', Integer, - ForeignKey('t1.c1', use_alter=True, name='t1c1_fk'))) - - def setup_classes(self): - class C1(_base.BasicEntity): - pass - - class C2(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def testcycle(self): - mapper(C2, t2, properties={ - 'c1s': relation(C1, - primaryjoin=t2.c.c1 == t1.c.c2, - uselist=True)}) - mapper(C1, t1, properties={ - 'c2s': relation(C2, - primaryjoin=t1.c.c1 == t2.c.c2, - uselist=True)}) - - a = C1() - b = C2() - c = C1() - d = C2() - e = C2() - f = C2() - a.c2s.append(b) - d.c1s.append(c) - b.c1s.append(c) - sess = create_session() - sess.add_all((a, b, c, d, e, f)) - sess.flush() - - -class BiDirectionalOneToManyTest2(_base.MappedTest): - """Two mappers with a one-to-many relation to each other, with a second one-to-many on one of the mappers""" - - run_define_tables = 'each' - - def define_tables(self, metadata): - Table('t1', metadata, - Column('c1', Integer, primary_key=True), - Column('c2', Integer, ForeignKey('t2.c1')), - test_needs_autoincrement=True) - - Table('t2', metadata, - Column('c1', Integer, primary_key=True), - Column('c2', Integer, - ForeignKey('t1.c1', use_alter=True, name='t1c1_fq')), - test_needs_autoincrement=True) - - Table('t1_data', metadata, - Column('c1', Integer, primary_key=True), - Column('t1id', Integer, ForeignKey('t1.c1')), - Column('data', 
String(20)), - test_needs_autoincrement=True) - - def setup_classes(self): - class C1(_base.BasicEntity): - pass - - class C2(_base.BasicEntity): - pass - - class C1Data(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(C2, t2, properties={ - 'c1s': relation(C1, - primaryjoin=t2.c.c1 == t1.c.c2, - uselist=True)}) - mapper(C1, t1, properties={ - 'c2s': relation(C2, - primaryjoin=t1.c.c1 == t2.c.c2, - uselist=True), - 'data': relation(mapper(C1Data, t1_data))}) - - @testing.resolve_artifact_names - def testcycle(self): - a = C1() - b = C2() - c = C1() - d = C2() - e = C2() - f = C2() - a.c2s.append(b) - d.c1s.append(c) - b.c1s.append(c) - a.data.append(C1Data(data='c1data1')) - a.data.append(C1Data(data='c1data2')) - c.data.append(C1Data(data='c1data3')) - sess = create_session() - sess.add_all((a, b, c, d, e, f)) - sess.flush() - - sess.delete(d) - sess.delete(c) - sess.flush() - -class OneToManyManyToOneTest(_base.MappedTest): - """ - - Tests two mappers, one has a one-to-many on the other mapper, the other - has a separate many-to-one relationship to the first. two tests will have - a row for each item that is dependent on the other. without the - "post_update" flag, such relationships raise an exception when - dependencies are sorted. 
- - """ - run_define_tables = 'each' - - def define_tables(self, metadata): - Table('ball', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('person_id', Integer, - ForeignKey('person.id', use_alter=True, name='fk_person_id')), - Column('data', String(30))) - - Table('person', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('favorite_ball_id', Integer, ForeignKey('ball.id')), - Column('data', String(30))) - - def setup_classes(self): - class Person(_base.BasicEntity): - pass - - class Ball(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def testcycle(self): - """ - This test has a peculiar aspect in that it doesnt create as many - dependent relationships as the other tests, and revealed a small - glitch in the circular dependency sorting. - - """ - mapper(Ball, ball) - mapper(Person, person, properties=dict( - balls=relation(Ball, - primaryjoin=ball.c.person_id == person.c.id, - remote_side=ball.c.person_id), - favorite=relation(Ball, - primaryjoin=person.c.favorite_ball_id == ball.c.id, - remote_side=ball.c.id))) - - b = Ball() - p = Person() - p.balls.append(b) - sess = create_session() - sess.add(p) - sess.flush() - - @testing.resolve_artifact_names - def testpostupdate_m2o(self): - """A cycle between two rows, with a post_update on the many-to-one""" - mapper(Ball, ball) - mapper(Person, person, properties=dict( - balls=relation(Ball, - primaryjoin=ball.c.person_id == person.c.id, - remote_side=ball.c.person_id, - post_update=False, - cascade="all, delete-orphan"), - favorite=relation(Ball, - primaryjoin=person.c.favorite_ball_id == ball.c.id, - remote_side=person.c.favorite_ball_id, - post_update=True))) - - b = Ball(data='some data') - p = Person(data='some data') - p.balls.append(b) - p.balls.append(Ball(data='some data')) - p.balls.append(Ball(data='some data')) - p.balls.append(Ball(data='some data')) - p.favorite = b - sess = create_session() - 
sess.add(b) - sess.add(p) - - self.assert_sql_execution( - testing.db, - sess.flush, - RegexSQL("^INSERT INTO person", {'data':'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), - RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), - ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id} - ), - ) - - sess.delete(p) - - self.assert_sql_execution( - testing.db, - sess.flush, - ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}), - ExactSQL("DELETE FROM ball WHERE ball.id = :id", None), # lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]) - ExactSQL("DELETE FROM person WHERE person.id = :id", lambda ctx:[{'id': p.id}]) - ) - - @testing.resolve_artifact_names - def testpostupdate_o2m(self): - """A cycle between two rows, with a post_update on the one-to-many""" - - mapper(Ball, ball) - mapper(Person, person, properties=dict( - balls=relation(Ball, - primaryjoin=ball.c.person_id == person.c.id, - remote_side=ball.c.person_id, - cascade="all, delete-orphan", - post_update=True, - backref='person'), - favorite=relation(Ball, - primaryjoin=person.c.favorite_ball_id == ball.c.id, - remote_side=person.c.favorite_ball_id))) - - b = Ball(data='some data') - p = Person(data='some data') - p.balls.append(b) - b2 = Ball(data='some data') - p.balls.append(b2) - b3 = Ball(data='some data') - p.balls.append(b3) - b4 = Ball(data='some data') - p.balls.append(b4) - p.favorite = b - sess = create_session() - sess.add_all((b,p,b2,b3,b4)) - - self.assert_sql_execution( - testing.db, - sess.flush, - CompiledSQL("INSERT INTO 
ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id':None, 'data':'some data'}), - - CompiledSQL("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id':None, 'data':'some data'}), - - CompiledSQL("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id':None, 'data':'some data'}), - - CompiledSQL("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - {'person_id':None, 'data':'some data'}), - - CompiledSQL("INSERT INTO person (favorite_ball_id, data) " - "VALUES (:favorite_ball_id, :data)", - lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}), - - AllOf( - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b.id}), - - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b2.id}), - - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b3.id}), - - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b4.id}) - ), - ) - - sess.delete(p) - - self.assert_sql_execution(testing.db, sess.flush, - AllOf(CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id': None, 'ball_id': b.id}), - - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id': None, 'ball_id': b2.id}), - - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id': None, 'ball_id': b3.id}), - - CompiledSQL("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id': None, 'ball_id': b4.id})), - - CompiledSQL("DELETE FROM person WHERE person.id = :id", - lambda ctx:[{'id':p.id}]), - - CompiledSQL("DELETE FROM ball WHERE ball.id = :id", - lambda 
ctx:[{'id': b.id}, - {'id': b2.id}, - {'id': b3.id}, - {'id': b4.id}]) - ) - - -class SelfReferentialPostUpdateTest(_base.MappedTest): - """Post_update on a single self-referential mapper""" - - def define_tables(self, metadata): - Table('node', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('path', String(50), nullable=False), - Column('parent_id', Integer, - ForeignKey('node.id'), nullable=True), - Column('prev_sibling_id', Integer, - ForeignKey('node.id'), nullable=True), - Column('next_sibling_id', Integer, - ForeignKey('node.id'), nullable=True)) - - def setup_classes(self): - class Node(_base.BasicEntity): - def __init__(self, path=''): - self.path = path - - @testing.resolve_artifact_names - def testbasic(self): - """Post_update only fires off when needed. - - This test case used to produce many superfluous update statements, - particularly upon delete - - """ - - mapper(Node, node, properties={ - 'children': relation( - Node, - primaryjoin=node.c.id==node.c.parent_id, - lazy=True, - cascade="all", - backref=backref("parent", remote_side=node.c.id) - ), - 'prev_sibling': relation( - Node, - primaryjoin=node.c.prev_sibling_id==node.c.id, - remote_side=node.c.id, - lazy=True, - uselist=False), - 'next_sibling': relation( - Node, - primaryjoin=node.c.next_sibling_id==node.c.id, - remote_side=node.c.id, - lazy=True, - uselist=False, - post_update=True)}) - - session = create_session() - - def append_child(parent, child): - if parent.children: - parent.children[-1].next_sibling = child - child.prev_sibling = parent.children[-1] - parent.children.append(child) - - def remove_child(parent, child): - child.parent = None - node = child.next_sibling - node.prev_sibling = child.prev_sibling - child.prev_sibling.next_sibling = node - session.delete(child) - root = Node('root') - - about = Node('about') - cats = Node('cats') - stories = Node('stories') - bruce = Node('bruce') - - append_child(root, about) - 
assert(about.prev_sibling is None) - append_child(root, cats) - assert(cats.prev_sibling is about) - assert(cats.next_sibling is None) - assert(about.next_sibling is cats) - assert(about.prev_sibling is None) - append_child(root, stories) - append_child(root, bruce) - session.add(root) - session.flush() - - remove_child(root, cats) - # pre-trigger lazy loader on 'cats' to make the test easier - cats.children - self.assert_sql_execution( - testing.db, - session.flush, - CompiledSQL("UPDATE node SET prev_sibling_id=:prev_sibling_id " - "WHERE node.id = :node_id", - lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}), - - CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " - "WHERE node.id = :node_id", - lambda ctx:{'next_sibling_id':stories.id, 'node_id':about.id}), - - CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " - "WHERE node.id = :node_id", - lambda ctx:{'next_sibling_id':None, 'node_id':cats.id}), - - CompiledSQL("DELETE FROM node WHERE node.id = :id", - lambda ctx:[{'id':cats.id}]) - ) - - -class SelfReferentialPostUpdateTest2(_base.MappedTest): - - def define_tables(self, metadata): - Table("a_table", metadata, - Column("id", Integer(), primary_key=True), - Column("fui", String(128)), - Column("b", Integer(), ForeignKey("a_table.id"))) - - def setup_classes(self): - class A(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def testbasic(self): - """ - Test that post_update remembers to be involved in update operations as - well, since it replaces the normal dependency processing completely - [ticket:413] - - """ - - mapper(A, a_table, properties={ - 'foo': relation(A, - remote_side=[a_table.c.id], - post_update=True)}) - - session = create_session() - - f1 = A(fui="f1") - session.add(f1) - session.flush() - - f2 = A(fui="f2", foo=f1) - - # at this point f1 is already inserted. 
but we need post_update - # to fire off anyway - session.add(f2) - session.flush() - session.expunge_all() - - f1 = session.query(A).get(f1.id) - f2 = session.query(A).get(f2.id) - assert f2.foo is f1 - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/defaults.py b/test/orm/defaults.py deleted file mode 100644 index 8dc192519..000000000 --- a/test/orm/defaults.py +++ /dev/null @@ -1,129 +0,0 @@ -import testenv; testenv.configure_for_tests() - -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base -from testlib.testing import eq_ - - -class TriggerDefaultsTest(_base.MappedTest): - __requires__ = ('row_triggers',) - - def define_tables(self, metadata): - dt = Table('dt', metadata, - Column('id', Integer, primary_key=True), - Column('col1', String(20)), - Column('col2', String(20), - server_default=sa.schema.FetchedValue()), - Column('col3', String(20), - sa.schema.FetchedValue(for_update=True)), - Column('col4', String(20), - sa.schema.FetchedValue(), - sa.schema.FetchedValue(for_update=True))) - for ins in ( - sa.DDL("CREATE TRIGGER dt_ins AFTER INSERT ON dt " - "FOR EACH ROW BEGIN " - "UPDATE dt SET col2='ins', col4='ins' " - "WHERE dt.id = NEW.id; END", - on='sqlite'), - sa.DDL("CREATE TRIGGER dt_ins ON dt AFTER INSERT AS " - "UPDATE dt SET col2='ins', col4='ins' " - "WHERE dt.id IN (SELECT id FROM inserted);", - on='mssql'), - ): - if testing.against(ins.on): - break - else: - ins = sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt " - "FOR EACH ROW BEGIN " - "SET NEW.col2='ins'; SET NEW.col4='ins'; END") - ins.execute_at('after-create', dt) - sa.DDL("DROP TRIGGER dt_ins").execute_at('before-drop', dt) - - - for up in ( - sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt " - "FOR EACH ROW BEGIN " - "UPDATE dt SET col3='up', col4='up' " - "WHERE dt.id = OLD.id; END", - on='sqlite'), - sa.DDL("CREATE TRIGGER dt_up ON dt 
AFTER UPDATE AS " - "UPDATE dt SET col3='up', col4='up' " - "WHERE dt.id IN (SELECT id FROM deleted);", - on='mssql'), - ): - if testing.against(up.on): - break - else: - up = sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " - "FOR EACH ROW BEGIN " - "SET NEW.col3='up'; SET NEW.col4='up'; END") - up.execute_at('after-create', dt) - sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt) - - - def setup_classes(self): - class Default(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Default, dt) - - @testing.resolve_artifact_names - def test_insert(self): - - d1 = Default(id=1) - - eq_(d1.col1, None) - eq_(d1.col2, None) - eq_(d1.col3, None) - eq_(d1.col4, None) - - session = create_session() - session.add(d1) - session.flush() - - eq_(d1.col1, None) - eq_(d1.col2, 'ins') - eq_(d1.col3, None) - # don't care which trigger fired - assert d1.col4 in ('ins', 'up') - - @testing.resolve_artifact_names - def test_update(self): - d1 = Default(id=1) - - session = create_session() - session.add(d1) - session.flush() - d1.col1 = 'set' - session.flush() - - eq_(d1.col1, 'set') - eq_(d1.col2, 'ins') - eq_(d1.col3, 'up') - eq_(d1.col4, 'up') - -class ExcludedDefaultsTest(_base.MappedTest): - def define_tables(self, metadata): - dt = Table('dt', metadata, - Column('id', Integer, primary_key=True), - Column('col1', String(20), default="hello"), - ) - - @testing.resolve_artifact_names - def test_exclude(self): - class Foo(_base.ComparableEntity): - pass - mapper(Foo, dt, exclude_properties=('col1',)) - - f1 = Foo() - sess = create_session() - sess.add(f1) - sess.flush() - eq_(dt.select().execute().fetchall(), [(1, "hello")]) - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/deprecations.py b/test/orm/deprecations.py deleted file mode 100644 index 483e8f556..000000000 --- a/test/orm/deprecations.py +++ /dev/null @@ -1,483 +0,0 @@ -"""The collection of modern alternatives to deprecated & removed functionality. 
- -Collects specimens of old ORM code and explicitly covers the recommended -modern (i.e. not deprecated) alternative to them. The tests snippets here can -be migrated directly to the wiki, docs, etc. - -""" -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, func -from testlib.sa.orm import mapper, relation, create_session, sessionmaker -from orm import _base - - -class QueryAlternativesTest(_base.MappedTest): - '''Collects modern idioms for Queries - - The docstring for each test case serves as miniature documentation about - the deprecated use case, and the test body illustrates (and covers) the - intended replacement code to accomplish the same task. - - Documenting the "old way" including the argument signature helps these - cases remain useful to readers even after the deprecated method has been - removed from the modern codebase. - - Format: - - def test_deprecated_thing(self): - """Query.methodname(old, arg, **signature) - - output = session.query(User).deprecatedmethod(inputs) - - """ - # 0.4+ - output = session.query(User).newway(inputs) - assert output is correct - - # 0.5+ - output = session.query(User).evennewerway(inputs) - assert output is correct - - ''' - - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - Table('users_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(64))) - - Table('addresses_table', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('users_table.id')), - Column('email_address', String(128)), - Column('purpose', String(16)), - Column('bounces', Integer, default=0)) - - def setup_classes(self): - class User(_base.BasicEntity): - pass - - class Address(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users_table, properties=dict( - addresses=relation(Address, backref='user'), - 
)) - mapper(Address, addresses_table) - - def fixtures(self): - return dict( - users_table=( - ('id', 'name'), - (1, 'jack'), - (2, 'ed'), - (3, 'fred'), - (4, 'chuck')), - - addresses_table=( - ('id', 'user_id', 'email_address', 'purpose', 'bounces'), - (1, 1, 'jack@jack.home', 'Personal', 0), - (2, 1, 'jack@jack.bizz', 'Work', 1), - (3, 2, 'ed@foo.bar', 'Personal', 0), - (4, 3, 'fred@the.fred', 'Personal', 10))) - - - ###################################################################### - - @testing.resolve_artifact_names - def test_override_get(self): - """MapperExtension.get() - - x = session.query.get(5) - - """ - from sqlalchemy.orm.query import Query - cache = {} - class MyQuery(Query): - def get(self, ident, **kwargs): - if ident in cache: - return cache[ident] - else: - x = super(MyQuery, self).get(ident) - cache[ident] = x - return x - - session = sessionmaker(query_cls=MyQuery)() - - ad1 = session.query(Address).get(1) - assert ad1 in cache.values() - - @testing.resolve_artifact_names - def test_load(self): - """x = session.query(Address).load(1) - - x = session.load(Address, 1) - - """ - - session = create_session() - ad1 = session.query(Address).populate_existing().get(1) - assert bool(ad1) - - - @testing.resolve_artifact_names - def test_apply_max(self): - """Query.apply_max(col) - - max = session.query(Address).apply_max(Address.bounces) - - """ - session = create_session() - - # 0.5.0 - maxes = list(session.query(Address).values(func.max(Address.bounces))) - max = maxes[0][0] - assert max == 10 - - max = session.query(func.max(Address.bounces)).one()[0] - assert max == 10 - - @testing.resolve_artifact_names - def test_apply_min(self): - """Query.apply_min(col) - - min = session.query(Address).apply_min(Address.bounces) - - """ - session = create_session() - - # 0.5.0 - mins = list(session.query(Address).values(func.min(Address.bounces))) - min = mins[0][0] - assert min == 0 - - min = session.query(func.min(Address.bounces)).one()[0] - assert min == 
0 - - @testing.resolve_artifact_names - def test_apply_avg(self): - """Query.apply_avg(col) - - avg = session.query(Address).apply_avg(Address.bounces) - - """ - session = create_session() - - avgs = list(session.query(Address).values(func.avg(Address.bounces))) - avg = avgs[0][0] - assert avg > 0 and avg < 10 - - avg = session.query(func.avg(Address.bounces)).one()[0] - assert avg > 0 and avg < 10 - - @testing.resolve_artifact_names - def test_apply_sum(self): - """Query.apply_sum(col) - - avg = session.query(Address).apply_avg(Address.bounces) - - """ - session = create_session() - - avgs = list(session.query(Address).values(func.sum(Address.bounces))) - avg = avgs[0][0] - assert avg == 11 - - avg = session.query(func.sum(Address.bounces)).one()[0] - assert avg == 11 - - @testing.resolve_artifact_names - def test_count_by(self): - """Query.count_by(*args, **params) - - num = session.query(Address).count_by(purpose='Personal') - - # old-style implicit *_by join - num = session.query(User).count_by(purpose='Personal') - - """ - session = create_session() - - num = session.query(Address).filter_by(purpose='Personal').count() - assert num == 3, num - - num = (session.query(User).join('addresses'). 
- filter(Address.purpose=='Personal')).count() - assert num == 3, num - - @testing.resolve_artifact_names - def test_count_whereclause(self): - """Query.count(whereclause=None, params=None, **kwargs) - - num = session.query(Address).count(address_table.c.bounces > 1) - - """ - session = create_session() - - num = session.query(Address).filter(Address.bounces > 1).count() - assert num == 1, num - - @testing.resolve_artifact_names - def test_execute(self): - """Query.execute(clauseelement, params=None, *args, **kwargs) - - users = session.query(User).execute(users_table.select()) - - """ - session = create_session() - - users = session.query(User).from_statement(users_table.select()).all() - assert len(users) == 4 - - @testing.resolve_artifact_names - def test_get_by(self): - """Query.get_by(*args, **params) - - user = session.query(User).get_by(name='ed') - - # 0.3-style implicit *_by join - user = session.query(User).get_by(email_addresss='fred@the.fred') - - """ - session = create_session() - - user = session.query(User).filter_by(name='ed').first() - assert user.name == 'ed' - - user = (session.query(User).join('addresses'). 
- filter(Address.email_address=='fred@the.fred')).first() - assert user.name == 'fred' - - user = session.query(User).filter( - User.addresses.any(Address.email_address=='fred@the.fred')).first() - assert user.name == 'fred' - - @testing.resolve_artifact_names - def test_instances_entities(self): - """Query.instances(cursor, *mappers_or_columns, **kwargs) - - sel = users_table.join(addresses_table).select(use_labels=True) - res = session.query(User).instances(sel.execute(), Address) - - """ - session = create_session() - - sel = users_table.join(addresses_table).select(use_labels=True) - res = list(session.query(User, Address).instances(sel.execute())) - - assert len(res) == 4 - cola, colb = res[0] - assert isinstance(cola, User) and isinstance(colb, Address) - - @testing.resolve_artifact_names - def test_join_by(self): - """Query.join_by(*args, **params) - - TODO - """ - session = create_session() - - - @testing.resolve_artifact_names - def test_join_to(self): - """Query.join_to(key) - - TODO - """ - session = create_session() - - - @testing.resolve_artifact_names - def test_join_via(self): - """Query.join_via(keys) - - TODO - """ - session = create_session() - - - @testing.resolve_artifact_names - def test_list(self): - """Query.list() - - users = session.query(User).list() - - """ - session = create_session() - - users = session.query(User).all() - assert len(users) == 4 - - @testing.resolve_artifact_names - def test_scalar(self): - """Query.scalar() - - user = session.query(User).filter(User.id==1).scalar() - - """ - session = create_session() - - user = session.query(User).filter(User.id==1).first() - assert user.id==1 - - @testing.resolve_artifact_names - def test_select(self): - """Query.select(arg=None, **kwargs) - - users = session.query(User).select(users_table.c.name != None) - - """ - session = create_session() - - users = session.query(User).filter(User.name != None).all() - assert len(users) == 4 - - @testing.resolve_artifact_names - def 
test_select_by(self): - """Query.select_by(*args, **params) - - users = session.query(User).select_by(name='fred') - - # 0.3 magic join on *_by methods - users = session.query(User).select_by(email_address='fred@the.fred') - - """ - session = create_session() - - users = session.query(User).filter_by(name='fred').all() - assert len(users) == 1 - - users = session.query(User).filter(User.name=='fred').all() - assert len(users) == 1 - - users = (session.query(User).join('addresses'). - filter_by(email_address='fred@the.fred')).all() - assert len(users) == 1 - - users = session.query(User).filter(User.addresses.any( - Address.email_address == 'fred@the.fred')).all() - assert len(users) == 1 - - @testing.resolve_artifact_names - def test_selectfirst(self): - """Query.selectfirst(arg=None, **kwargs) - - bounced = session.query(Address).selectfirst( - addresses_table.c.bounces > 0) - - """ - session = create_session() - - bounced = session.query(Address).filter(Address.bounces > 0).first() - assert bounced.bounces > 0 - - @testing.resolve_artifact_names - def test_selectfirst_by(self): - """Query.selectfirst_by(*args, **params) - - onebounce = session.query(Address).selectfirst_by(bounces=1) - - # 0.3 magic join on *_by methods - onebounce_user = session.query(User).selectfirst_by(bounces=1) - - """ - session = create_session() - - onebounce = session.query(Address).filter_by(bounces=1).first() - assert onebounce.bounces == 1 - - onebounce_user = (session.query(User).join('addresses'). - filter_by(bounces=1)).first() - assert onebounce_user.name == 'jack' - - onebounce_user = (session.query(User).join('addresses'). 
- filter(Address.bounces == 1)).first() - assert onebounce_user.name == 'jack' - - onebounce_user = session.query(User).filter(User.addresses.any( - Address.bounces == 1)).first() - assert onebounce_user.name == 'jack' - - @testing.resolve_artifact_names - def test_selectone(self): - """Query.selectone(arg=None, **kwargs) - - ed = session.query(User).selectone(users_table.c.name == 'ed') - - """ - session = create_session() - - ed = session.query(User).filter(User.name == 'jack').one() - - @testing.resolve_artifact_names - def test_selectone_by(self): - """Query.selectone_by - - ed = session.query(User).selectone_by(name='ed') - - # 0.3 magic join on *_by methods - ed = session.query(User).selectone_by(email_address='ed@foo.bar') - - """ - session = create_session() - - ed = session.query(User).filter_by(name='jack').one() - - ed = session.query(User).filter(User.name == 'jack').one() - - ed = session.query(User).join('addresses').filter( - Address.email_address == 'ed@foo.bar').one() - - ed = session.query(User).filter(User.addresses.any( - Address.email_address == 'ed@foo.bar')).one() - - @testing.resolve_artifact_names - def test_select_statement(self): - """Query.select_statement(statement, **params) - - users = session.query(User).select_statement(users_table.select()) - - """ - session = create_session() - - users = session.query(User).from_statement(users_table.select()).all() - assert len(users) == 4 - - @testing.resolve_artifact_names - def test_select_text(self): - """Query.select_text(text, **params) - - users = session.query(User).select_text('SELECT * FROM users_table') - - """ - session = create_session() - - users = (session.query(User). 
- from_statement('SELECT * FROM users_table')).all() - assert len(users) == 4 - - @testing.resolve_artifact_names - def test_select_whereclause(self): - """Query.select_whereclause(whereclause=None, params=None, **kwargs) - - - users = session,query(User).select_whereclause(users.c.name=='ed') - users = session.query(User).select_whereclause("name='ed'") - - """ - session = create_session() - - users = session.query(User).filter(User.name=='ed').all() - assert len(users) == 1 and users[0].name == 'ed' - - users = session.query(User).filter("name='ed'").all() - assert len(users) == 1 and users[0].name == 'ed' - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/dynamic.py b/test/orm/dynamic.py deleted file mode 100644 index 3bd94b7c0..000000000 --- a/test/orm/dynamic.py +++ /dev/null @@ -1,559 +0,0 @@ -import testenv; testenv.configure_for_tests() -import operator -from sqlalchemy.orm import dynamic_loader, backref -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func -from testlib.sa.orm import mapper, relation, create_session, Query, attributes -from sqlalchemy.orm.dynamic import AppenderMixin -from testlib.testing import eq_ -from testlib.compat import _function_named -from orm import _base, _fixtures - - -class DynamicTest(_fixtures.FixtureTest): - @testing.resolve_artifact_names - def test_basic(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session() - q = sess.query(User) - - u = q.filter(User.id==7).first() - eq_([User(id=7, - addresses=[Address(id=1, email_address='jack@bean.com')])], - q.filter(User.id==7).all()) - eq_(self.static.user_address_result, q.all()) - - @testing.resolve_artifact_names - def test_order_by(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session() - u = sess.query(User).get(8) - 
eq_(list(u.addresses.order_by(desc(Address.email_address))), [Address(email_address=u'ed@wood.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@bettyboop.com')]) - - @testing.resolve_artifact_names - def test_configured_order_by(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), order_by=desc(Address.email_address)) - }) - sess = create_session() - u = sess.query(User).get(8) - eq_(list(u.addresses), [Address(email_address=u'ed@wood.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@bettyboop.com')]) - - # test cancellation of None, replacement with something else - eq_( - list(u.addresses.order_by(None).order_by(Address.email_address)), - [Address(email_address=u'ed@bettyboop.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@wood.com')] - ) - - # test cancellation of None, replacement with nothing - eq_( - set(u.addresses.order_by(None)), - set([Address(email_address=u'ed@bettyboop.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@wood.com')]) - ) - - @testing.resolve_artifact_names - def test_count(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session() - u = sess.query(User).first() - eq_(u.addresses.count(), 1) - - @testing.resolve_artifact_names - def test_backref(self): - mapper(Address, addresses, properties={ - 'user':relation(User, backref=backref('addresses', lazy='dynamic')) - }) - mapper(User, users) - - sess = create_session() - ad = sess.query(Address).get(1) - def go(): - ad.user = None - self.assert_sql_count(testing.db, go, 1) - sess.flush() - u = sess.query(User).get(7) - assert ad not in u.addresses - - @testing.resolve_artifact_names - def test_no_count(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session() - q = sess.query(User) - - # dynamic collection cannot 
implement __len__() (at least one that - # returns a live database result), else additional count() queries are - # issued when evaluating in a list context - def go(): - eq_([User(id=7, - addresses=[Address(id=1, - email_address='jack@bean.com')])], - q.filter(User.id==7).all()) - self.assert_sql_count(testing.db, go, 2) - - @testing.resolve_artifact_names - def test_m2m(self): - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, lazy="dynamic", - backref=backref('orders', lazy="dynamic")) - }) - mapper(Item, items) - - sess = create_session() - o1 = Order(id=15, description="order 10") - i1 = Item(id=10, description="item 8") - o1.items.append(i1) - sess.add(o1) - sess.flush() - - assert o1 in i1.orders.all() - assert i1 in o1.items.all() - - @testing.resolve_artifact_names - def test_transient_detached(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session() - u1 = User() - u1.addresses.append(Address()) - assert u1.addresses.count() == 1 - assert u1.addresses[0] == Address() - - @testing.resolve_artifact_names - def test_custom_query(self): - class MyQuery(Query): - pass - - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), - query_class=MyQuery) - }) - sess = create_session() - u = User() - sess.add(u) - - col = u.addresses - assert isinstance(col, Query) - assert isinstance(col, MyQuery) - assert hasattr(col, 'append') - assert type(col).__name__ == 'AppenderMyQuery' - - q = col.limit(1) - assert isinstance(q, Query) - assert isinstance(q, MyQuery) - assert not hasattr(q, 'append') - assert type(q).__name__ == 'MyQuery' - - @testing.resolve_artifact_names - def test_custom_query_with_custom_mixin(self): - class MyAppenderMixin(AppenderMixin): - def add(self, items): - if isinstance(items, list): - for item in items: - self.append(item) - else: - self.append(items) - - class MyQuery(Query): - pass - - class 
MyAppenderQuery(MyAppenderMixin, MyQuery): - query_class = MyQuery - - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), - query_class=MyAppenderQuery) - }) - sess = create_session() - u = User() - sess.add(u) - - col = u.addresses - assert isinstance(col, Query) - assert isinstance(col, MyQuery) - assert hasattr(col, 'append') - assert hasattr(col, 'add') - assert type(col).__name__ == 'MyAppenderQuery' - - q = col.limit(1) - assert isinstance(q, Query) - assert isinstance(q, MyQuery) - assert not hasattr(q, 'append') - assert not hasattr(q, 'add') - assert type(q).__name__ == 'MyQuery' - - -class SessionTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_events(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session() - u1 = User(name='jack') - a1 = Address(email_address='foo') - sess.add_all([u1, a1]) - sess.flush() - - assert testing.db.scalar(select([func.count(1)]).where(addresses.c.user_id!=None)) == 0 - u1 = sess.query(User).get(u1.id) - u1.addresses.append(a1) - sess.flush() - - assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [ - (a1.id, u1.id, 'foo') - ] - - u1.addresses.remove(a1) - sess.flush() - assert testing.db.scalar(select([func.count(1)]).where(addresses.c.user_id!=None)) == 0 - - u1.addresses.append(a1) - sess.flush() - assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [ - (a1.id, u1.id, 'foo') - ] - - a2= Address(email_address='bar') - u1.addresses.remove(a1) - u1.addresses.append(a2) - sess.flush() - assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [ - (a2.id, u1.id, 'bar') - ] - - - @testing.resolve_artifact_names - def test_merge(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), 
order_by=addresses.c.email_address) - }) - sess = create_session() - u1 = User(name='jack') - a1 = Address(email_address='a1') - a2 = Address(email_address='a2') - a3 = Address(email_address='a3') - - u1.addresses.append(a2) - u1.addresses.append(a3) - - sess.add_all([u1, a1]) - sess.flush() - - u1 = User(id=u1.id, name='jack') - u1.addresses.append(a1) - u1.addresses.append(a3) - u1 = sess.merge(u1) - assert attributes.get_history(u1, 'addresses') == ( - [a1], - [a3], - [a2] - ) - - sess.flush() - - eq_( - list(u1.addresses), - [a1, a3] - ) - - @testing.resolve_artifact_names - def test_flush(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session() - u1 = User(name='jack') - u2 = User(name='ed') - u2.addresses.append(Address(email_address='foo@bar.com')) - u1.addresses.append(Address(email_address='lala@hoho.com')) - sess.add_all((u1, u2)) - sess.flush() - - from sqlalchemy.orm import attributes - self.assertEquals(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], [])) - - sess.expunge_all() - - # test the test fixture a little bit - assert User(name='jack', addresses=[Address(email_address='wrong')]) != sess.query(User).first() - assert User(name='jack', addresses=[Address(email_address='lala@hoho.com')]) == sess.query(User).first() - - assert [ - User(name='jack', addresses=[Address(email_address='lala@hoho.com')]), - User(name='ed', addresses=[Address(email_address='foo@bar.com')]) - ] == sess.query(User).all() - - @testing.resolve_artifact_names - def test_hasattr(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - u1 = User(name='jack') - - assert 'addresses' not in u1.__dict__.keys() - u1.addresses = [Address(email_address='test')] - assert 'addresses' in dir(u1) - - @testing.resolve_artifact_names - def test_collection_set(self): - mapper(User, users, properties={ - 
'addresses':dynamic_loader(mapper(Address, addresses), order_by=addresses.c.email_address) - }) - sess = create_session(autoflush=True, autocommit=False) - u1 = User(name='jack') - a1 = Address(email_address='a1') - a2 = Address(email_address='a2') - a3 = Address(email_address='a3') - a4 = Address(email_address='a4') - - sess.add(u1) - u1.addresses = [a1, a3] - assert list(u1.addresses) == [a1, a3] - u1.addresses = [a1, a2, a4] - assert list(u1.addresses) == [a1, a2, a4] - u1.addresses = [a2, a3] - assert list(u1.addresses) == [a2, a3] - u1.addresses = [] - assert list(u1.addresses) == [] - - - - - @testing.resolve_artifact_names - def test_rollback(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses)) - }) - sess = create_session(expire_on_commit=False, autocommit=False, autoflush=True) - u1 = User(name='jack') - u1.addresses.append(Address(email_address='lala@hoho.com')) - sess.add(u1) - sess.flush() - sess.commit() - u1.addresses.append(Address(email_address='foo@bar.com')) - eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')]) - sess.rollback() - eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com')]) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_delete_nocascade(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), order_by=Address.id, backref='user') - }) - sess = create_session(autoflush=True) - u = User(name='ed') - u.addresses.append(Address(email_address='a')) - u.addresses.append(Address(email_address='b')) - u.addresses.append(Address(email_address='c')) - u.addresses.append(Address(email_address='d')) - u.addresses.append(Address(email_address='e')) - u.addresses.append(Address(email_address='f')) - sess.add(u) - - assert Address(email_address='c') == u.addresses[2] - sess.delete(u.addresses[2]) - sess.delete(u.addresses[4]) - sess.delete(u.addresses[3]) - 
assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) - - sess.expunge_all() - u = sess.query(User).get(u.id) - - sess.delete(u) - - # u.addresses relation will have to force the load - # of all addresses so that they can be updated - sess.flush() - sess.close() - - assert testing.db.scalar(addresses.count(addresses.c.user_id != None)) ==0 - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_delete_cascade(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), order_by=Address.id, backref='user', cascade="all, delete-orphan") - }) - sess = create_session(autoflush=True) - u = User(name='ed') - u.addresses.append(Address(email_address='a')) - u.addresses.append(Address(email_address='b')) - u.addresses.append(Address(email_address='c')) - u.addresses.append(Address(email_address='d')) - u.addresses.append(Address(email_address='e')) - u.addresses.append(Address(email_address='f')) - sess.add(u) - - assert Address(email_address='c') == u.addresses[2] - sess.delete(u.addresses[2]) - sess.delete(u.addresses[4]) - sess.delete(u.addresses[3]) - assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) - - sess.expunge_all() - u = sess.query(User).get(u.id) - - sess.delete(u) - - # u.addresses relation will have to force the load - # of all addresses so that they can be updated - sess.flush() - sess.close() - - assert testing.db.scalar(addresses.count()) ==0 - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_remove_orphans(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), order_by=Address.id, cascade="all, delete-orphan", backref='user') - }) - sess = create_session(autoflush=True) - u = User(name='ed') - u.addresses.append(Address(email_address='a')) - 
u.addresses.append(Address(email_address='b')) - u.addresses.append(Address(email_address='c')) - u.addresses.append(Address(email_address='d')) - u.addresses.append(Address(email_address='e')) - u.addresses.append(Address(email_address='f')) - sess.add(u) - - assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='c'), - Address(email_address='d'), Address(email_address='e'), Address(email_address='f')] == sess.query(Address).all() - - assert Address(email_address='c') == u.addresses[2] - - try: - del u.addresses[3] - assert False - except TypeError, e: - assert "doesn't support item deletion" in str(e), str(e) - - for a in u.addresses.filter(Address.email_address.in_(['c', 'e', 'f'])): - u.addresses.remove(a) - - assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) - - assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == sess.query(Address).all() - - sess.delete(u) - sess.close() - - -def create_backref_test(autoflush, saveuser): - - @testing.resolve_artifact_names - def test_backref(self): - mapper(User, users, properties={ - 'addresses':dynamic_loader(mapper(Address, addresses), backref='user') - }) - sess = create_session(autoflush=autoflush) - - u = User(name='buffy') - - a = Address(email_address='foo@bar.com') - a.user = u - - if saveuser: - sess.add(u) - else: - sess.add(a) - - if not autoflush: - sess.flush() - - assert u in sess - assert a in sess - - self.assert_(list(u.addresses) == [a]) - - a.user = None - if not autoflush: - self.assert_(list(u.addresses) == [a]) - - if not autoflush: - sess.flush() - self.assert_(list(u.addresses) == []) - - test_backref = _function_named( - test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""), - (saveuser and "_saveuser" or "_savead"))) - setattr(SessionTest, test_backref.__name__, test_backref) - -for autoflush in (False, True): - for saveuser in (False, True): - 
create_backref_test(autoflush, saveuser) - -class DontDereferenceTest(_base.MappedTest): - def define_tables(self, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(40)), - Column('fullname', String(100)), - Column('password', String(15))) - - Table('addresses', metadata, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('user_id', Integer, ForeignKey('users.id'))) - - @testing.resolve_artifact_names - def setup_mappers(self): - class User(_base.ComparableEntity): - pass - - class Address(_base.ComparableEntity): - pass - - mapper(User, users, properties={ - 'addresses': relation(Address, backref='user', lazy='dynamic') - }) - mapper(Address, addresses) - - @testing.resolve_artifact_names - def test_no_deref(self): - session = create_session() - user = User() - user.name = 'joe' - user.fullname = 'Joe User' - user.password = 'Joe\'s secret' - address = Address() - address.email_address = 'joe@joesdomain.example' - address.user = user - session.add(user) - session.flush() - session.expunge_all() - - def query1(): - session = create_session(testing.db) - user = session.query(User).first() - return user.addresses.all() - - def query2(): - session = create_session(testing.db) - return session.query(User).first().addresses.all() - - def query3(): - session = create_session(testing.db) - user = session.query(User).first() - return session.query(User).first().addresses.all() - - eq_(query1(), [Address(email_address='joe@joesdomain.example')]) - eq_(query2(), [Address(email_address='joe@joesdomain.example')]) - eq_(query3(), [Address(email_address='joe@joesdomain.example')]) - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py deleted file mode 100644 index 87c2442cc..000000000 --- a/test/orm/eager_relations.py +++ /dev/null @@ -1,1595 +0,0 @@ -"""basic tests of eager loaded attributes""" - 
-import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from sqlalchemy.orm import eagerload, deferred, undefer -from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func -from testlib.sa.orm import mapper, relation, create_session, lazyload, aliased -from testlib.testing import eq_ -from testlib.assertsql import CompiledSQL -from orm import _base, _fixtures -import datetime - -class EagerTest(_fixtures.FixtureTest): - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def test_basic(self): - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=Address.id) - }) - sess = create_session() - q = sess.query(User) - - assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() - self.assertEquals(self.static.user_address_result, q.order_by(User.id).all()) - - @testing.resolve_artifact_names - def test_late_compile(self): - m = mapper(User, users) - sess = create_session() - sess.query(User).all() - m.add_property("addresses", relation(mapper(Address, addresses))) - - sess.expunge_all() - def go(): - eq_( - [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])], - sess.query(User).options(eagerload('addresses')).filter(User.id==7).all() - ) - self.assert_sql_count(testing.db, go, 1) - - - @testing.resolve_artifact_names - def test_no_orphan(self): - """An eagerly loaded child object is not marked as an orphan""" - mapper(User, users, properties={ - 'addresses':relation(Address, cascade="all,delete-orphan", lazy=False) - }) - mapper(Address, addresses) - - sess = create_session() - user = sess.query(User).get(7) - assert getattr(User, 'addresses').hasparent(sa.orm.attributes.instance_state(user.addresses[0]), optimistic=True) - assert not sa.orm.class_mapper(Address)._is_orphan(sa.orm.attributes.instance_state(user.addresses[0])) - - @testing.resolve_artifact_names - def 
test_orderby(self): - mapper(User, users, properties = { - 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.email_address), - }) - q = create_session().query(User) - assert [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ] == q.order_by(User.id).all() - - @testing.resolve_artifact_names - def test_orderby_multi(self): - mapper(User, users, properties = { - 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=[addresses.c.email_address, addresses.c.id]), - }) - q = create_session().query(User) - assert [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ] == q.order_by(User.id).all() - - @testing.resolve_artifact_names - def test_orderby_related(self): - """A regular mapper select on a single table can order by a relation to a second table""" - mapper(Address, addresses) - mapper(User, users, properties = dict( - addresses = relation(Address, lazy=False, order_by=addresses.c.id), - )) - - q = create_session().query(User) - l = q.filter(User.id==Address.user_id).order_by(Address.email_address).all() - - assert [ - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=7, addresses=[ - Address(id=1) - ]), - ] == l - - @testing.resolve_artifact_names - def test_orderby_desc(self): - mapper(Address, addresses) - mapper(User, users, properties = dict( - addresses 
= relation(Address, lazy=False, - order_by=[sa.desc(addresses.c.email_address)]), - )) - sess = create_session() - assert [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=3, email_address='ed@bettyboop.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ] == sess.query(User).order_by(User.id).all() - - @testing.resolve_artifact_names - def test_deferred_fk_col(self): - User, Address, Dingaling = self.classes.get_all( - 'User', 'Address', 'Dingaling') - users, addresses, dingalings = self.tables.get_all( - 'users', 'addresses', 'dingalings') - - mapper(Address, addresses, properties={ - 'user_id':deferred(addresses.c.user_id), - 'user':relation(User, lazy=False) - }) - mapper(User, users) - - sess = create_session() - - for q in [ - sess.query(Address).filter(Address.id.in_([1, 4, 5])), - sess.query(Address).filter(Address.id.in_([1, 4, 5])).limit(3) - ]: - sess.expunge_all() - eq_(q.all(), - [Address(id=1, user=User(id=7)), - Address(id=4, user=User(id=8)), - Address(id=5, user=User(id=9))] - ) - - a = sess.query(Address).filter(Address.id==1).first() - def go(): - eq_(a.user_id, 7) - # assert that the eager loader added 'user_id' to the row and deferred - # loading of that col was disabled - self.assert_sql_count(testing.db, go, 0) - - # do the mapping in reverse - # (we would have just used an "addresses" backref but the test - # fixtures then require the whole backref to be set up, lazy loaders - # trigger, etc.) 
- sa.orm.clear_mappers() - - mapper(Address, addresses, properties={ - 'user_id':deferred(addresses.c.user_id), - }) - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False)}) - - for q in [ - sess.query(User).filter(User.id==7), - sess.query(User).filter(User.id==7).limit(1) - ]: - sess.expunge_all() - eq_(q.all(), - [User(id=7, addresses=[Address(id=1)])] - ) - - sess.expunge_all() - u = sess.query(User).get(7) - def go(): - assert u.addresses[0].user_id==7 - # assert that the eager loader didn't have to affect 'user_id' here - # and that its still deferred - self.assert_sql_count(testing.db, go, 1) - - sa.orm.clear_mappers() - - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False)}) - mapper(Address, addresses, properties={ - 'user_id':deferred(addresses.c.user_id), - 'dingalings':relation(Dingaling, lazy=False)}) - mapper(Dingaling, dingalings, properties={ - 'address_id':deferred(dingalings.c.address_id)}) - sess.expunge_all() - def go(): - u = sess.query(User).get(8) - eq_(User(id=8, - addresses=[Address(id=2, dingalings=[Dingaling(id=1)]), - Address(id=3), - Address(id=4)]), - u) - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_many_to_many(self): - Keyword, Item = self.Keyword, self.Item - keywords, item_keywords, items = self.tables.get_all( - 'keywords', 'item_keywords', 'items') - - mapper(Keyword, keywords) - mapper(Item, items, properties = dict( - keywords = relation(Keyword, secondary=item_keywords, - lazy=False, order_by=keywords.c.id))) - - q = create_session().query(Item).order_by(Item.id) - def go(): - assert self.static.item_keyword_result == q.all() - self.assert_sql_count(testing.db, go, 1) - - def go(): - eq_(self.static.item_keyword_result[0:2], - q.join('keywords').filter(Keyword.name == 'red').all()) - self.assert_sql_count(testing.db, go, 1) - - def go(): - eq_(self.static.item_keyword_result[0:2], - (q.join('keywords', aliased=True). 
- filter(Keyword.name == 'red')).all()) - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_eager_option(self): - Keyword, Item = self.Keyword, self.Item - keywords, item_keywords, items = self.tables.get_all( - 'keywords', 'item_keywords', 'items') - - mapper(Keyword, keywords) - mapper(Item, items, properties = dict( - keywords = relation(Keyword, secondary=item_keywords, lazy=True, - order_by=keywords.c.id))) - - q = create_session().query(Item) - - def go(): - eq_(self.static.item_keyword_result[0:2], - (q.options(eagerload('keywords')). - join('keywords').filter(keywords.c.name == 'red')).order_by(Item.id).all()) - - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_cyclical(self): - """A circular eager relationship breaks the cycle with a lazy loader""" - User, Address = self.User, self.Address - users, addresses = self.tables.get_all('users', 'addresses') - - mapper(Address, addresses) - mapper(User, users, properties = dict( - addresses = relation(Address, lazy=False, - backref=sa.orm.backref('user', lazy=False), order_by=Address.id) - )) - assert sa.orm.class_mapper(User).get_property('addresses').lazy is False - assert sa.orm.class_mapper(Address).get_property('user').lazy is False - - sess = create_session() - self.assertEquals(self.static.user_address_result, sess.query(User).order_by(User.id).all()) - - @testing.resolve_artifact_names - def test_double(self): - """Eager loading with two relations simultaneously, from the same table, using aliases.""" - User, Address, Order = self.classes.get_all( - 'User', 'Address', 'Order') - users, addresses, orders = self.tables.get_all( - 'users', 'addresses', 'orders') - - openorders = sa.alias(orders, 'openorders') - closedorders = sa.alias(orders, 'closedorders') - - mapper(Address, addresses) - mapper(Order, orders) - - open_mapper = mapper(Order, openorders, non_primary=True) - closed_mapper = mapper(Order, closedorders, 
non_primary=True) - - mapper(User, users, properties = dict( - addresses = relation(Address, lazy=False, order_by=addresses.c.id), - open_orders = relation( - open_mapper, - primaryjoin=sa.and_(openorders.c.isopen == 1, - users.c.id==openorders.c.user_id), - lazy=False, order_by=openorders.c.id), - closed_orders = relation( - closed_mapper, - primaryjoin=sa.and_(closedorders.c.isopen == 0, - users.c.id==closedorders.c.user_id), - lazy=False, order_by=closedorders.c.id))) - - q = create_session().query(User).order_by(User.id) - - def go(): - assert [ - User( - id=7, - addresses=[Address(id=1)], - open_orders = [Order(id=3)], - closed_orders = [Order(id=1), Order(id=5)] - ), - User( - id=8, - addresses=[Address(id=2), Address(id=3), Address(id=4)], - open_orders = [], - closed_orders = [] - ), - User( - id=9, - addresses=[Address(id=5)], - open_orders = [Order(id=4)], - closed_orders = [Order(id=2)] - ), - User(id=10) - - ] == q.all() - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_double_same_mappers(self): - """Eager loading with two relations simulatneously, from the same table, using aliases.""" - User, Address, Order = self.classes.get_all( - 'User', 'Address', 'Order') - users, addresses, orders = self.tables.get_all( - 'users', 'addresses', 'orders') - - mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items': relation(Item, secondary=order_items, lazy=False, - order_by=items.c.id)}) - mapper(Item, items) - mapper(User, users, properties=dict( - addresses=relation(Address, lazy=False, order_by=addresses.c.id), - open_orders=relation( - Order, - primaryjoin=sa.and_(orders.c.isopen == 1, - users.c.id==orders.c.user_id), - lazy=False, order_by=orders.c.id), - closed_orders=relation( - Order, - primaryjoin=sa.and_(orders.c.isopen == 0, - users.c.id==orders.c.user_id), - lazy=False, order_by=orders.c.id))) - q = create_session().query(User).order_by(User.id) - - def go(): - assert [ - User(id=7, - 
addresses=[ - Address(id=1)], - open_orders=[Order(id=3, - items=[ - Item(id=3), - Item(id=4), - Item(id=5)])], - closed_orders=[Order(id=1, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)]), - Order(id=5, - items=[ - Item(id=5)])]), - User(id=8, - addresses=[ - Address(id=2), - Address(id=3), - Address(id=4)], - open_orders = [], - closed_orders = []), - User(id=9, - addresses=[ - Address(id=5)], - open_orders=[ - Order(id=4, - items=[ - Item(id=1), - Item(id=5)])], - closed_orders=[ - Order(id=2, - items=[ - Item(id=1), - Item(id=2), - Item(id=3)])]), - User(id=10) - ] == q.all() - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_no_false_hits(self): - """Eager loaders don't interpret main table columns as part of their eager load.""" - User, Address, Order = self.classes.get_all( - 'User', 'Address', 'Order') - users, addresses, orders = self.tables.get_all( - 'users', 'addresses', 'orders') - - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False), - 'orders':relation(Order, lazy=False) - }) - mapper(Address, addresses) - mapper(Order, orders) - - allusers = create_session().query(User).all() - - # using a textual select, the columns will be 'id' and 'name'. the - # eager loaders have aliases which should not hit on those columns, - # they should be required to locate only their aliased/fully table - # qualified column name. 
- noeagers = create_session().query(User).from_statement("select * from users").all() - assert 'orders' not in noeagers[0].__dict__ - assert 'addresses' not in noeagers[0].__dict__ - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_limit(self): - """Limit operations combined with lazy-load relationships.""" - User, Item, Address, Order = self.classes.get_all( - 'User', 'Item', 'Address', 'Order') - users, items, order_items, orders, addresses = self.tables.get_all( - 'users', 'items', 'order_items', 'orders', 'addresses') - - mapper(Item, items) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) - }) - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id), - 'orders':relation(Order, lazy=True) - }) - - sess = create_session() - q = sess.query(User) - - if testing.against('mysql'): - l = q.limit(2).all() - assert self.static.user_all_result[:2] == l - else: - l = q.order_by(User.id).limit(2).offset(1).all() - print self.static.user_all_result[1:3] - print l - assert self.static.user_all_result[1:3] == l - - @testing.resolve_artifact_names - def test_distinct(self): - # this is an involved 3x union of the users table to get a lot of rows. - # then see if the "distinct" works its way out. you actually get the same - # result with or without the distinct, just via less or more rows. 
- u2 = users.alias('u2') - s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') - - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=False), - }) - - sess = create_session() - q = sess.query(User) - - def go(): - l = q.filter(s.c.u2_id==User.id).distinct().all() - assert self.static.user_address_result == l - self.assert_sql_count(testing.db, go, 1) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_limit_2(self): - mapper(Keyword, keywords) - mapper(Item, items, properties = dict( - keywords = relation(Keyword, secondary=item_keywords, lazy=False, order_by=[keywords.c.id]), - )) - - sess = create_session() - q = sess.query(Item) - l = q.filter((Item.description=='item 2') | (Item.description=='item 5') | (Item.description=='item 3')).\ - order_by(Item.id).limit(2).all() - - assert self.static.item_keyword_result[1:3] == l - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_limit_3(self): - """test that the ORDER BY is propagated from the inner select to the outer select, when using the - 'wrapped' select statement resulting from the combination of eager loading and limit/offset clauses.""" - - mapper(Item, items) - mapper(Order, orders, properties = dict( - items = relation(Item, secondary=order_items, lazy=False) - )) - - mapper(Address, addresses) - mapper(User, users, properties = dict( - addresses = relation(Address, lazy=False, order_by=addresses.c.id), - orders = relation(Order, lazy=False, order_by=orders.c.id), - )) - sess = create_session() - - q = sess.query(User) - - if not testing.against('maxdb', 'mssql'): - l = q.join('orders').order_by(Order.user_id.desc()).limit(2).offset(1) - assert [ - User(id=9, - orders=[Order(id=2), Order(id=4)], - addresses=[Address(id=5)] - ), - User(id=7, - orders=[Order(id=1), Order(id=3), Order(id=5)], - addresses=[Address(id=1)] - ) - ] == 
l.all() - - l = q.join('addresses').order_by(Address.email_address.desc()).limit(1).offset(0) - assert [ - User(id=7, - orders=[Order(id=1), Order(id=3), Order(id=5)], - addresses=[Address(id=1)] - ) - ] == l.all() - - @testing.resolve_artifact_names - def test_limit_4(self): - # tests the LIMIT/OFFSET aliasing on a mapper against a select. original issue from ticket #904 - sel = sa.select([users, addresses.c.email_address], users.c.id==addresses.c.user_id).alias('useralias') - mapper(User, sel, properties={ - 'orders':relation(Order, primaryjoin=sel.c.id==orders.c.user_id, lazy=False) - }) - mapper(Order, orders) - - sess = create_session() - eq_(sess.query(User).first(), - User(name=u'jack',orders=[ - Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), - Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), - Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5)], - email_address=u'jack@bean.com',id=7) - ) - - @testing.resolve_artifact_names - def test_one_to_many_scalar(self): - mapper(User, users, properties = dict( - address = relation(mapper(Address, addresses), lazy=False, uselist=False) - )) - q = create_session().query(User) - - def go(): - l = q.filter(users.c.id == 7).all() - assert [User(id=7, address=Address(id=1))] == l - self.assert_sql_count(testing.db, go, 1) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_many_to_one(self): - mapper(Address, addresses, properties = dict( - user = relation(mapper(User, users), lazy=False) - )) - sess = create_session() - q = sess.query(Address) - - def go(): - a = q.filter(addresses.c.id==1).one() - assert a.user is not None - u1 = sess.query(User).get(7) - assert a.user is u1 - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_many_to_one_null(self): - """test that a many-to-one eager load which loads None does - not later trigger a lazy load. 
- - """ - - # use a primaryjoin intended to defeat SA's usage of - # query.get() for a many-to-one lazyload - mapper(Order, orders, properties = dict( - address = relation(mapper(Address, addresses), - primaryjoin=and_( - addresses.c.id==orders.c.address_id, - addresses.c.email_address != None - ), - - lazy=False) - )) - sess = create_session() - - def go(): - o1 = sess.query(Order).options(lazyload('address')).filter(Order.id==5).one() - self.assertEquals(o1.address, None) - self.assert_sql_count(testing.db, go, 2) - - sess.expunge_all() - def go(): - o1 = sess.query(Order).filter(Order.id==5).one() - self.assertEquals(o1.address, None) - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_one_and_many(self): - """tests eager load for a parent object with a child object that - contains a many-to-many relationship to a third object.""" - - mapper(User, users, properties={ - 'orders':relation(Order, lazy=False, order_by=orders.c.id) - }) - mapper(Item, items) - mapper(Order, orders, properties = dict( - items = relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) - )) - - q = create_session().query(User) - - l = q.filter("users.id in (7, 8, 9)").order_by("users.id") - - def go(): - assert self.static.user_order_result[0:3] == l.all() - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_double_with_aggregate(self): - max_orders_by_user = sa.select([sa.func.max(orders.c.id).label('order_id')], group_by=[orders.c.user_id]).alias('max_orders_by_user') - - max_orders = orders.select(orders.c.id==max_orders_by_user.c.order_id).alias('max_orders') - - mapper(Order, orders) - mapper(User, users, properties={ - 'orders':relation(Order, backref='user', lazy=False), - 'max_order':relation(mapper(Order, max_orders, non_primary=True), lazy=False, uselist=False) - }) - q = create_session().query(User) - - def go(): - assert [ - User(id=7, orders=[ - Order(id=1), - Order(id=3), - 
Order(id=5), - ], - max_order=Order(id=5) - ), - User(id=8, orders=[]), - User(id=9, orders=[Order(id=2),Order(id=4)], - max_order=Order(id=4) - ), - User(id=10), - ] == q.all() - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_wide(self): - mapper(Order, orders, properties={'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id)}) - mapper(Item, items) - mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy = False, order_by=addresses.c.id), - orders = relation(Order, lazy = False, order_by=orders.c.id), - )) - q = create_session().query(User) - l = q.all() - assert self.static.user_all_result == q.order_by(User.id).all() - - @testing.resolve_artifact_names - def test_against_select(self): - """test eager loading of a mapper which is against a select""" - - s = sa.select([orders], orders.c.isopen==1).alias('openorders') - - mapper(Order, s, properties={ - 'user':relation(User, lazy=False) - }) - mapper(User, users) - mapper(Item, items) - - q = create_session().query(Order) - assert [ - Order(id=3, user=User(id=7)), - Order(id=4, user=User(id=9)) - ] == q.all() - - q = q.select_from(s.join(order_items).join(items)).filter(~Item.id.in_([1, 2, 5])) - assert [ - Order(id=3, user=User(id=7)), - ] == q.all() - - @testing.resolve_artifact_names - def test_aliasing(self): - """test that eager loading uses aliases to insulate the eager load from regular criterion against those tables.""" - - mapper(User, users, properties = dict( - addresses = relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id) - )) - q = create_session().query(User) - l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id).order_by(User.id) - assert self.static.user_address_result[1:2] == l.all() - -class AddEntityTest(_fixtures.FixtureTest): - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def 
_assert_result(self): - return [ - ( - User(id=7, - addresses=[Address(id=1)] - ), - Order(id=1, - items=[Item(id=1), Item(id=2), Item(id=3)] - ), - ), - ( - User(id=7, - addresses=[Address(id=1)] - ), - Order(id=3, - items=[Item(id=3), Item(id=4), Item(id=5)] - ), - ), - ( - User(id=7, - addresses=[Address(id=1)] - ), - Order(id=5, - items=[Item(id=5)] - ), - ), - ( - User(id=9, - addresses=[Address(id=5)] - ), - Order(id=2, - items=[Item(id=1), Item(id=2), Item(id=3)] - ), - ), - ( - User(id=9, - addresses=[Address(id=5)] - ), - Order(id=4, - items=[Item(id=1), Item(id=5)] - ), - ) - ] - - @testing.resolve_artifact_names - def test_mapper_configured(self): - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False), - 'orders':relation(Order) - }) - mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) - }) - mapper(Item, items) - - - sess = create_session() - oalias = sa.orm.aliased(Order) - def go(): - ret = sess.query(User, oalias).join(('orders', oalias)).order_by(User.id, oalias.id).all() - eq_(ret, self._assert_result()) - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_options(self): - mapper(User, users, properties={ - 'addresses':relation(Address), - 'orders':relation(Order) - }) - mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, order_by=items.c.id) - }) - mapper(Item, items) - - sess = create_session() - - oalias = sa.orm.aliased(Order) - def go(): - ret = sess.query(User, oalias).options(eagerload('addresses')).join(('orders', oalias)).order_by(User.id, oalias.id).all() - eq_(ret, self._assert_result()) - self.assert_sql_count(testing.db, go, 6) - - sess.expunge_all() - def go(): - ret = sess.query(User, oalias).options(eagerload('addresses'), eagerload(oalias.items)).join(('orders', oalias)).order_by(User.id, oalias.id).all() - eq_(ret, 
self._assert_result()) - self.assert_sql_count(testing.db, go, 1) - -class OrderBySecondaryTest(_base.MappedTest): - def define_tables(self, metadata): - Table('m2m', metadata, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey('a.id')), - Column('bid', Integer, ForeignKey('b.id'))) - - Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - Table('b', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - - def fixtures(self): - return dict( - a=(('id', 'data'), - (1, 'a1'), - (2, 'a2')), - - b=(('id', 'data'), - (1, 'b1'), - (2, 'b2'), - (3, 'b3'), - (4, 'b4')), - - m2m=(('id', 'aid', 'bid'), - (2, 1, 1), - (4, 2, 4), - (1, 1, 3), - (6, 2, 2), - (3, 1, 2), - (5, 2, 3))) - - @testing.resolve_artifact_names - def test_ordering(self): - class A(_base.ComparableEntity):pass - class B(_base.ComparableEntity):pass - - mapper(A, a, properties={ - 'bs':relation(B, secondary=m2m, lazy=False, order_by=m2m.c.id) - }) - mapper(B, b) - - sess = create_session() - eq_(sess.query(A).all(), [A(data='a1', bs=[B(data='b3'), B(data='b1'), B(data='b2')]), A(bs=[B(data='b4'), B(data='b3'), B(data='b2')])]) - - -class SelfReferentialEagerTest(_base.MappedTest): - def define_tables(self, metadata): - Table('nodes', metadata, - Column('id', Integer, sa.Sequence('node_id_seq', optional=True), - primary_key=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30))) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_basic(self): - class Node(_base.ComparableEntity): - def append(self, node): - self.children.append(node) - - mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=False, join_depth=3, order_by=nodes.c.id) - }) - sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - 
n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - sess.add(n1) - sess.flush() - sess.expunge_all() - def go(): - d = sess.query(Node).filter_by(data='n1').all()[0] - assert Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]) == d - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - def go(): - d = sess.query(Node).filter_by(data='n1').first() - assert Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]) == d - self.assert_sql_count(testing.db, go, 1) - - - @testing.resolve_artifact_names - def test_lazy_fallback_doesnt_affect_eager(self): - class Node(_base.ComparableEntity): - def append(self, node): - self.children.append(node) - - mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=False, join_depth=1, order_by=nodes.c.id) - }) - sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - sess.add(n1) - sess.flush() - sess.expunge_all() - - # eager load with join depth 1. when eager load of 'n1' hits the - # children of 'n12', no columns are present, eager loader degrades to - # lazy loader; fine. but then, 'n12' is *also* in the first level of - # columns since we're loading the whole table. when those rows - # arrive, now we *can* eager load its children and an eager collection - # should be initialized. essentially the 'n12' instance is present in - # not just two different rows but two distinct sets of columns in this - # result set. 
- def go(): - allnodes = sess.query(Node).order_by(Node.data).all() - n12 = allnodes[2] - assert n12.data == 'n12' - assert [ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ] == list(n12.children) - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_with_deferred(self): - class Node(_base.ComparableEntity): - def append(self, node): - self.children.append(node) - - mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=False, join_depth=3, order_by=nodes.c.id), - 'data':deferred(nodes.c.data) - }) - sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - sess.add(n1) - sess.flush() - sess.expunge_all() - - def go(): - self.assertEquals( - Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), - sess.query(Node).order_by(Node.id).first(), - ) - self.assert_sql_count(testing.db, go, 4) - - sess.expunge_all() - - def go(): - assert Node(data='n1', children=[Node(data='n11'), Node(data='n12')]) == sess.query(Node).options(undefer('data')).order_by(Node.id).first() - self.assert_sql_count(testing.db, go, 3) - - sess.expunge_all() - - def go(): - assert Node(data='n1', children=[Node(data='n11'), Node(data='n12')]) == sess.query(Node).options(undefer('data'), undefer('children.data')).first() - self.assert_sql_count(testing.db, go, 1) - - - @testing.resolve_artifact_names - def test_options(self): - class Node(_base.ComparableEntity): - def append(self, node): - self.children.append(node) - - mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=True, order_by=nodes.c.id) - }, order_by=nodes.c.id) - sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - sess.add(n1) - sess.flush() - sess.expunge_all() - def go(): - d = 
sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() - assert Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]) == d - self.assert_sql_count(testing.db, go, 2) - - def go(): - d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() - - # test that the query isn't wrapping the initial query for eager loading. - self.assert_sql_execution(testing.db, go, - CompiledSQL( - "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes " - "WHERE nodes.data = :data_1 ORDER BY nodes.id LIMIT 1 OFFSET 0", - {'data_1': 'n1'} - ) - ) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_no_depth(self): - class Node(_base.ComparableEntity): - def append(self, node): - self.children.append(node) - - mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=False) - }) - sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - sess.add(n1) - sess.flush() - sess.expunge_all() - def go(): - d = sess.query(Node).filter_by(data='n1').first() - assert Node(data='n1', children=[ - Node(data='n11'), - Node(data='n12', children=[ - Node(data='n121'), - Node(data='n122'), - Node(data='n123') - ]), - Node(data='n13') - ]) == d - self.assert_sql_count(testing.db, go, 3) - -class MixedSelfReferentialEagerTest(_base.MappedTest): - def define_tables(self, metadata): - Table('a_table', metadata, - Column('id', Integer, primary_key=True) - ) - - Table('b_table', metadata, - Column('id', Integer, primary_key=True), - Column('parent_b1_id', Integer, ForeignKey('b_table.id')), - Column('parent_a_id', Integer, 
ForeignKey('a_table.id')), - Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) - - - @testing.resolve_artifact_names - def setup_mappers(self): - class A(_base.ComparableEntity): - pass - class B(_base.ComparableEntity): - pass - - mapper(A,a_table) - mapper(B,b_table,properties = { - 'parent_b1': relation(B, - remote_side = [b_table.c.id], - primaryjoin = (b_table.c.parent_b1_id ==b_table.c.id), - order_by = b_table.c.id - ), - 'parent_z': relation(A,lazy = True), - 'parent_b2': relation(B, - remote_side = [b_table.c.id], - primaryjoin = (b_table.c.parent_b2_id ==b_table.c.id), - order_by = b_table.c.id - ) - }); - - @testing.resolve_artifact_names - def insert_data(self): - a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3)) - b_table.insert().execute( - dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None), - dict(id=2, parent_a_id=1, parent_b1_id=1, parent_b2_id=None), - dict(id=3, parent_a_id=1, parent_b1_id=1, parent_b2_id=2), - dict(id=4, parent_a_id=3, parent_b1_id=1, parent_b2_id=None), - dict(id=5, parent_a_id=3, parent_b1_id=None, parent_b2_id=2), - dict(id=6, parent_a_id=1, parent_b1_id=1, parent_b2_id=3), - dict(id=7, parent_a_id=2, parent_b1_id=None, parent_b2_id=3), - dict(id=8, parent_a_id=2, parent_b1_id=1, parent_b2_id=2), - dict(id=9, parent_a_id=None, parent_b1_id=1, parent_b2_id=None), - dict(id=10, parent_a_id=3, parent_b1_id=7, parent_b2_id=2), - dict(id=11, parent_a_id=3, parent_b1_id=1, parent_b2_id=8), - dict(id=12, parent_a_id=2, parent_b1_id=5, parent_b2_id=2), - dict(id=13, parent_a_id=3, parent_b1_id=4, parent_b2_id=4), - dict(id=14, parent_a_id=3, parent_b1_id=7, parent_b2_id=2), - ) - - @testing.resolve_artifact_names - def test_eager_load(self): - session = create_session() - def go(): - eq_( - session.query(B).options(eagerload('parent_b1'),eagerload('parent_b2'),eagerload('parent_z')). 
- filter(B.id.in_([2, 8, 11])).order_by(B.id).all(), - [ - B(id=2, parent_z=A(id=1), parent_b1=B(id=1), parent_b2=None), - B(id=8, parent_z=A(id=2), parent_b1=B(id=1), parent_b2=B(id=2)), - B(id=11, parent_z=A(id=3), parent_b1=B(id=1), parent_b2=B(id=8)) - ] - ) - self.assert_sql_count(testing.db, go, 1) - -class SelfReferentialM2MEagerTest(_base.MappedTest): - def define_tables(self, metadata): - Table('widget', metadata, - Column('id', Integer, primary_key=True), - Column('name', sa.Unicode(40), nullable=False, unique=True), - ) - - Table('widget_rel', metadata, - Column('parent_id', Integer, ForeignKey('widget.id')), - Column('child_id', Integer, ForeignKey('widget.id')), - sa.UniqueConstraint('parent_id', 'child_id'), - ) - - @testing.resolve_artifact_names - def test_basic(self): - class Widget(_base.ComparableEntity): - pass - - mapper(Widget, widget, properties={ - 'children': relation(Widget, secondary=widget_rel, - primaryjoin=widget_rel.c.parent_id==widget.c.id, - secondaryjoin=widget_rel.c.child_id==widget.c.id, - lazy=False, join_depth=1, - ) - }) - - sess = create_session() - w1 = Widget(name=u'w1') - w2 = Widget(name=u'w2') - w1.children.append(w2) - sess.add(w1) - sess.flush() - sess.expunge_all() - - assert [Widget(name='w1', children=[Widget(name='w2')])] == sess.query(Widget).filter(Widget.name==u'w1').all() - -class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref='user'), - 'orders':relation(Order, backref='user'), # o2m, m2o - }) - mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m - }) - mapper(Item, items, properties={ - 'keywords':relation(Keyword, secondary=item_keywords) #m2m - }) - mapper(Keyword, keywords) - - 
@testing.resolve_artifact_names - def test_two_entities(self): - sess = create_session() - - # two FROM clauses - def go(): - eq_( - [ - (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])), - (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])), - ], - sess.query(User, Order).filter(User.id==Order.user_id).\ - options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\ - order_by(User.id, Order.id).all(), - ) - self.assert_sql_count(testing.db, go, 1) - - # one FROM clause - def go(): - eq_( - [ - (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])), - (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])), - ], - sess.query(User, Order).join(User.orders).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\ - order_by(User.id, Order.id).all(), - ) - self.assert_sql_count(testing.db, go, 1) - - @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs") - @testing.resolve_artifact_names - def test_two_entities_with_joins(self): - sess = create_session() - - # two FROM clauses where there's a join on each one - def go(): - u1 = aliased(User) - o1 = aliased(Order) - eq_( - [ - ( - User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), - Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]), - User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), - Order(description=u'order 3', isopen=1, items=[Item(description=u'item 3'), Item(description=u'item 4'), Item(description=u'item 5')]) - ), - - ( - User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), - Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]), - 
User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), - Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')]) - ), - - ( - User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), - Order(description=u'order 4', isopen=1, items=[Item(description=u'item 1'), Item(description=u'item 5')]), - User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), - Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')]) - ), - ], - sess.query(User, Order, u1, o1).\ - join((Order, User.orders)).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\ - join((o1, u1.orders)).options(eagerload(u1.addresses), eagerload(o1.items)).filter(u1.id==7).\ - filter(Order.id 10) - assert res2.count() == 19 - - @testing.resolve_artifact_names - def test_options(self): - query = create_session().query(Foo) - class ext1(sa.orm.MapperExtension): - def populate_instance(self, mapper, selectcontext, row, instance, **flags): - instance.TEST = "hello world" - return sa.orm.EXT_CONTINUE - assert query.options(sa.orm.extension(ext1()))[0].TEST == "hello world" - - @testing.resolve_artifact_names - def test_order_by(self): - query = create_session().query(Foo) - assert query.order_by([Foo.bar])[0].bar == 0 - assert query.order_by([sa.desc(Foo.bar)])[0].bar == 99 - - @testing.resolve_artifact_names - def test_offset(self): - query = create_session().query(Foo) - assert list(query.order_by([Foo.bar]).offset(10))[0].bar == 10 - - @testing.resolve_artifact_names - def test_offset(self): - query = create_session().query(Foo) - assert len(list(query.limit(10))) == 10 - - -class GenerativeTest2(_base.MappedTest): - - def define_tables(self, metadata): - Table('Table1', metadata, - Column('id', Integer, primary_key=True)) - Table('Table2', metadata, - Column('t1id', Integer, ForeignKey("Table1.id"), - primary_key=True), - Column('num', Integer, primary_key=True)) - 
- @testing.resolve_artifact_names - def setup_mappers(self): - class Obj1(_base.BasicEntity): - pass - class Obj2(_base.BasicEntity): - pass - - mapper(Obj1, Table1) - mapper(Obj2, Table2) - - def fixtures(self): - return dict( - Table1=(('id',), - (1,), - (2,), - (3,), - (4,)), - Table2=(('num', 't1id'), - (1, 1), - (2, 1), - (3, 1), - (4, 2), - (5, 2), - (6, 3))) - - @testing.resolve_artifact_names - def test_distinct_count(self): - query = create_session().query(Obj1) - eq_(query.count(), 4) - - res = query.filter(sa.and_(Table1.c.id == Table2.c.t1id, - Table2.c.t1id == 1)) - eq_(res.count(), 3) - res = query.filter(sa.and_(Table1.c.id == Table2.c.t1id, - Table2.c.t1id == 1)).distinct() - eq_(res.count(), 1) - - -class RelationsTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users, properties={ - 'orders':relation(mapper(Order, orders, properties={ - 'addresses':relation(mapper(Address, addresses))}))}) - - - @testing.resolve_artifact_names - def test_join(self): - """Query.join""" - - session = create_session() - q = (session.query(User).join(['orders', 'addresses']). - filter(Address.id == 1)) - eq_([User(id=7)], q.all()) - - @testing.resolve_artifact_names - def test_outer_join(self): - """Query.outerjoin""" - - session = create_session() - q = (session.query(User).outerjoin(['orders', 'addresses']). - filter(sa.or_(Order.id == None, Address.id == 1))) - eq_(set([User(id=7), User(id=8), User(id=10)]), - set(q.all())) - - @testing.resolve_artifact_names - def test_outer_join_count(self): - """test the join and outerjoin functions on Query""" - - session = create_session() - - q = (session.query(User).outerjoin(['orders', 'addresses']). 
- filter(sa.or_(Order.id == None, Address.id == 1))) - eq_(q.count(), 4) - - @testing.resolve_artifact_names - def test_from(self): - session = create_session() - - sel = users.outerjoin(orders).outerjoin( - addresses, orders.c.address_id == addresses.c.id) - q = (session.query(User).select_from(sel). - filter(sa.or_(Order.id == None, Address.id == 1))) - eq_(set([User(id=7), User(id=8), User(id=10)]), - set(q.all())) - - -class CaseSensitiveTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('Table1', metadata, - Column('ID', Integer, primary_key=True)) - Table('Table2', metadata, - Column('T1ID', Integer, ForeignKey("Table1.ID"), - primary_key=True), - Column('NUM', Integer, primary_key=True)) - - @testing.resolve_artifact_names - def setup_mappers(self): - class Obj1(_base.BasicEntity): - pass - class Obj2(_base.BasicEntity): - pass - - mapper(Obj1, Table1) - mapper(Obj2, Table2) - - def fixtures(self): - return dict( - Table1=(('ID',), - (1,), - (2,), - (3,), - (4,)), - Table2=(('NUM', 'T1ID'), - (1, 1), - (2, 1), - (3, 1), - (4, 2), - (5, 2), - (6, 3))) - - @testing.resolve_artifact_names - def test_distinct_count(self): - q = create_session(bind=testing.db).query(Obj1) - assert q.count() == 4 - res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)) - assert res.count() == 3 - res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)).distinct() - self.assertEqual(res.count(), 1) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/abc_inheritance.py deleted file mode 100644 index ee324e381..000000000 --- a/test/orm/inheritance/abc_inheritance.py +++ /dev/null @@ -1,171 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE - -from testlib import testing -from orm import _base - - -def produce_test(parent, child, direction): - 
"""produce a testcase for A->B->C inheritance with a self-referential - relationship between two of the classes, using either one-to-many or - many-to-one.""" - class ABCTest(_base.MappedTest): - def define_tables(self, meta): - global ta, tb, tc - ta = ["a", meta] - ta.append(Column('id', Integer, primary_key=True)), - ta.append(Column('a_data', String(30))) - if "a"== parent and direction == MANYTOONE: - ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) - elif "a" == child and direction == ONETOMANY: - ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) - ta = Table(*ta) - - tb = ["b", meta] - tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, )) - - tb.append(Column('b_data', String(30))) - - if "b"== parent and direction == MANYTOONE: - tb.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) - elif "b" == child and direction == ONETOMANY: - tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) - tb = Table(*tb) - - tc = ["c", meta] - tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, )) - - tc.append(Column('c_data', String(30))) - - if "c"== parent and direction == MANYTOONE: - tc.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) - elif "c" == child and direction == ONETOMANY: - tc.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) - tc = Table(*tc) - - def tearDown(self): - if direction == MANYTOONE: - parent_table = {"a":ta, "b":tb, "c": tc}[parent] - parent_table.update(values={parent_table.c.child_id:None}).execute() - elif direction == ONETOMANY: - child_table = {"a":ta, "b":tb, "c": tc}[child] - child_table.update(values={child_table.c.parent_id:None}).execute() - super(ABCTest, self).tearDown() - - def test_roundtrip(self): - parent_table = {"a":ta, "b":tb, 
"c": tc}[parent] - child_table = {"a":ta, "b":tb, "c": tc}[child] - - remote_side = None - - if direction == MANYTOONE: - foreign_keys = [parent_table.c.child_id] - elif direction == ONETOMANY: - foreign_keys = [child_table.c.parent_id] - - atob = ta.c.id==tb.c.id - btoc = tc.c.id==tb.c.id - - if direction == ONETOMANY: - relationjoin = parent_table.c.id==child_table.c.parent_id - elif direction == MANYTOONE: - relationjoin = parent_table.c.child_id==child_table.c.id - if parent is child: - remote_side = [child_table.c.id] - - abcjoin = polymorphic_union( - {"a":ta.select(tb.c.id==None, from_obj=[ta.outerjoin(tb, onclause=atob)]), - "b":ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc).select(tc.c.id==None, fold_equivalents=True), - "c":tc.join(tb, onclause=btoc).join(ta, onclause=atob) - },"type", "abcjoin" - ) - - bcjoin = polymorphic_union( - { - "b":ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc).select(tc.c.id==None, fold_equivalents=True), - "c":tc.join(tb, onclause=btoc).join(ta, onclause=atob) - },"type", "bcjoin" - ) - class A(object): - def __init__(self, name): - self.a_data = name - class B(A):pass - class C(B):pass - - mapper(A, ta, polymorphic_on=abcjoin.c.type, with_polymorphic=('*', abcjoin), polymorphic_identity="a") - mapper(B, tb, polymorphic_on=bcjoin.c.type, with_polymorphic=('*', bcjoin), polymorphic_identity="b", inherits=A, inherit_condition=atob) - mapper(C, tc, polymorphic_identity="c", inherits=B, inherit_condition=btoc) - - parent_mapper = class_mapper({ta:A, tb:B, tc:C}[parent_table]) - child_mapper = class_mapper({ta:A, tb:B, tc:C}[child_table]) - - parent_class = parent_mapper.class_ - child_class = child_mapper.class_ - - parent_mapper.add_property("collection", relation(child_mapper, primaryjoin=relationjoin, foreign_keys=foreign_keys, remote_side=remote_side, uselist=True)) - - sess = create_session() - - parent_obj = parent_class('parent1') - child_obj = child_class('child1') - somea = A('somea') - someb = B('someb') 
- somec = C('somec') - - #print "APPENDING", parent.__class__.__name__ , "TO", child.__class__.__name__ - - sess.add(parent_obj) - parent_obj.collection.append(child_obj) - if direction == ONETOMANY: - child2 = child_class('child2') - parent_obj.collection.append(child2) - sess.add(child2) - elif direction == MANYTOONE: - parent2 = parent_class('parent2') - parent2.collection.append(child_obj) - sess.add(parent2) - sess.add(somea) - sess.add(someb) - sess.add(somec) - sess.flush() - sess.expunge_all() - - # assert result via direct get() of parent object - result = sess.query(parent_class).get(parent_obj.id) - assert result.id == parent_obj.id - assert result.collection[0].id == child_obj.id - if direction == ONETOMANY: - assert result.collection[1].id == child2.id - elif direction == MANYTOONE: - result2 = sess.query(parent_class).get(parent2.id) - assert result2.id == parent2.id - assert result2.collection[0].id == child_obj.id - - sess.expunge_all() - - # assert result via polymorphic load of parent object - result = sess.query(A).filter_by(id=parent_obj.id).one() - assert result.id == parent_obj.id - assert result.collection[0].id == child_obj.id - if direction == ONETOMANY: - assert result.collection[1].id == child2.id - elif direction == MANYTOONE: - result2 = sess.query(A).filter_by(id=parent2.id).one() - assert result2.id == parent2.id - assert result2.collection[0].id == child_obj.id - - ABCTest.__name__ = "Test%sTo%s%s" % (parent, child, (direction is ONETOMANY and "O2M" or "M2O")) - return ABCTest - -# test all combinations of polymorphic a/b/c related to another of a/b/c -for parent in ["a", "b", "c"]: - for child in ["a", "b", "c"]: - for direction in [ONETOMANY, MANYTOONE]: - testclass = produce_test(parent, child, direction) - exec("%s = testclass" % testclass.__name__) - del testclass - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/abc_polymorphic.py b/test/orm/inheritance/abc_polymorphic.py deleted file mode 100644 
index 6fabbb24c..000000000 --- a/test/orm/inheritance/abc_polymorphic.py +++ /dev/null @@ -1,90 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy import util -from sqlalchemy.orm import * - -from testlib import _function_named -from orm import _base, _fixtures - -class ABCTest(_base.MappedTest): - def define_tables(self, metadata): - global a, b, c - a = Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('adata', String(30)), - Column('type', String(30)), - ) - b = Table('b', metadata, - Column('id', Integer, ForeignKey('a.id'), primary_key=True), - Column('bdata', String(30))) - c = Table('c', metadata, - Column('id', Integer, ForeignKey('b.id'), primary_key=True), - Column('cdata', String(30))) - - def make_test(fetchtype): - def test_roundtrip(self): - class A(_fixtures.Base):pass - class B(A):pass - class C(B):pass - - if fetchtype == 'union': - abc = a.outerjoin(b).outerjoin(c) - bc = a.join(b).outerjoin(c) - else: - abc = bc = None - - mapper(A, a, with_polymorphic=('*', abc), polymorphic_on=a.c.type, polymorphic_identity='a') - mapper(B, b, with_polymorphic=('*', bc), inherits=A, polymorphic_identity='b') - mapper(C, c, inherits=B, polymorphic_identity='c') - - a1 = A(adata='a1') - b1 = B(bdata='b1', adata='b1') - b2 = B(bdata='b2', adata='b2') - b3 = B(bdata='b3', adata='b3') - c1 = C(cdata='c1', bdata='c1', adata='c1') - c2 = C(cdata='c2', bdata='c2', adata='c2') - c3 = C(cdata='c2', bdata='c2', adata='c2') - - sess = create_session() - for x in (a1, b1, b2, b3, c1, c2, c3): - sess.add(x) - sess.flush() - sess.expunge_all() - - #for obj in sess.query(A).all(): - # print obj - assert [ - A(adata='a1'), - B(bdata='b1', adata='b1'), - B(bdata='b2', adata='b2'), - B(bdata='b3', adata='b3'), - C(cdata='c1', bdata='c1', adata='c1'), - C(cdata='c2', bdata='c2', adata='c2'), - C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(A).all() - - assert [ - B(bdata='b1', adata='b1'), - 
B(bdata='b2', adata='b2'), - B(bdata='b3', adata='b3'), - C(cdata='c1', bdata='c1', adata='c1'), - C(cdata='c2', bdata='c2', adata='c2'), - C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(B).all() - - assert [ - C(cdata='c1', bdata='c1', adata='c1'), - C(cdata='c2', bdata='c2', adata='c2'), - C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(C).all() - - test_roundtrip = _function_named( - test_roundtrip, 'test_%s' % fetchtype) - return test_roundtrip - - test_union = make_test('union') - test_none = make_test('none') - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py deleted file mode 100644 index 41f0521dd..000000000 --- a/test/orm/inheritance/alltests.py +++ /dev/null @@ -1,31 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - 'orm.inheritance.basic', - 'orm.inheritance.query', - 'orm.inheritance.manytomany', - 'orm.inheritance.single', - 'orm.inheritance.concrete', - 'orm.inheritance.polymorph', - 'orm.inheritance.polymorph2', - 'orm.inheritance.poly_linked_list', - 'orm.inheritance.abc_polymorphic', - 'orm.inheritance.abc_inheritance', - 'orm.inheritance.productspec', - 'orm.inheritance.magazine', - 'orm.inheritance.selects', - - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py deleted file mode 100644 index 150874477..000000000 --- a/test/orm/inheritance/basic.py +++ /dev/null @@ -1,1015 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy import exc as sa_exc, util -from sqlalchemy.orm import * -from sqlalchemy.orm 
import exc as orm_exc - -#from testlib import * -#from testlib import fixtures -from testlib import _function_named, testing, engines -from orm import _base, _fixtures - -class O2MTest(_base.MappedTest): - """deals with inheritance and one-to-many relationships""" - def define_tables(self, metadata): - global foo, bar, blub - foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), - Column('data', String(20))) - - bar = Table('bar', metadata, - Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - Column('data', String(20))) - - blub = Table('blub', metadata, - Column('id', Integer, ForeignKey('bar.id'), primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False), - Column('data', String(20))) - - def testbasic(self): - class Foo(object): - def __init__(self, data=None): - self.data = data - def __repr__(self): - return "Foo id %d, data %s" % (self.id, self.data) - mapper(Foo, foo) - - class Bar(Foo): - def __repr__(self): - return "Bar id %d, data %s" % (self.id, self.data) - - mapper(Bar, bar, inherits=Foo) - - class Blub(Bar): - def __repr__(self): - return "Blub id %d, data %s" % (self.id, self.data) - - mapper(Blub, blub, inherits=Bar, properties={ - 'parent_foo':relation(Foo) - }) - - sess = create_session() - b1 = Blub("blub #1") - b2 = Blub("blub #2") - f = Foo("foo #1") - sess.add(b1) - sess.add(b2) - sess.add(f) - b1.parent_foo = f - b2.parent_foo = f - sess.flush() - compare = ','.join([repr(b1), repr(b2), repr(b1.parent_foo), repr(b2.parent_foo)]) - sess.expunge_all() - l = sess.query(Blub).all() - result = ','.join([repr(l[0]), repr(l[1]), repr(l[0].parent_foo), repr(l[1].parent_foo)]) - print compare - print result - self.assert_(compare == result) - self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') - -class FalseDiscriminatorTest(_base.MappedTest): - def define_tables(self, metadata): - global t1 - t1 = Table('t1', metadata, 
Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False)) - - def test_false_discriminator(self): - class Foo(object):pass - class Bar(Foo):pass - mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=1) - mapper(Bar, inherits=Foo, polymorphic_identity=0) - sess = create_session() - f1 = Bar() - sess.add(f1) - sess.flush() - assert f1.type == 0 - sess.expunge_all() - assert isinstance(sess.query(Foo).one(), Bar) - -class PolymorphicSynonymTest(_base.MappedTest): - def define_tables(self, metadata): - global t1, t2 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(10), nullable=False), - Column('info', String(255))) - t2 = Table('t2', metadata, - Column('id', Integer, ForeignKey('t1.id'), primary_key=True), - Column('data', String(10), nullable=False)) - - def test_polymorphic_synonym(self): - class T1(_fixtures.Base): - def info(self): - return "THE INFO IS:" + self._info - def _set_info(self, x): - self._info = x - info = property(info, _set_info) - - class T2(T1):pass - - mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', properties={ - 'info':synonym('_info', map_column=True) - }) - mapper(T2, t2, inherits=T1, polymorphic_identity='t2') - sess = create_session() - at1 = T1(info='at1') - at2 = T2(info='at2', data='t2 data') - sess.add(at1) - sess.add(at2) - sess.flush() - sess.expunge_all() - self.assertEquals(sess.query(T2).filter(T2.info=='at2').one(), at2) - self.assertEquals(at2.info, "THE INFO IS:at2") - - -class CascadeTest(_base.MappedTest): - """that cascades on polymorphic relations continue - cascading along the path of the instance's mapper, not - the base mapper.""" - - def define_tables(self, metadata): - global t1, t2, t3, t4 - t1= Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)) - ) - - t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('t1id', Integer, ForeignKey('t1.id')), - 
Column('type', String(30)), - Column('data', String(30)) - ) - t3 = Table('t3', metadata, - Column('id', Integer, ForeignKey('t2.id'), primary_key=True), - Column('moredata', String(30))) - - t4 = Table('t4', metadata, - Column('id', Integer, primary_key=True), - Column('t3id', Integer, ForeignKey('t3.id')), - Column('data', String(30))) - - def test_cascade(self): - class T1(_fixtures.Base): - pass - class T2(_fixtures.Base): - pass - class T3(T2): - pass - class T4(_fixtures.Base): - pass - - mapper(T1, t1, properties={ - 't2s':relation(T2, cascade="all") - }) - mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') - mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={ - 't4s':relation(T4, cascade="all") - }) - mapper(T4, t4) - - sess = create_session() - t1_1 = T1(data='t1') - - t3_1 = T3(data ='t3', moredata='t3') - t2_1 = T2(data='t2') - - t1_1.t2s.append(t2_1) - t1_1.t2s.append(t3_1) - - t4_1 = T4(data='t4') - t3_1.t4s.append(t4_1) - - sess.add(t1_1) - - - assert t4_1 in sess.new - sess.flush() - - sess.delete(t1_1) - assert t4_1 in sess.deleted - sess.flush() - -class GetTest(_base.MappedTest): - def define_tables(self, metadata): - global foo, bar, blub - foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), - Column('type', String(30)), - Column('data', String(20))) - - bar = Table('bar', metadata, - Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - Column('data', String(20))) - - blub = Table('blub', metadata, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('foo.id')), - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('data', String(20))) - - def create_test(polymorphic, name): - def test_get(self): - class Foo(object): - pass - - class Bar(Foo): - pass - - class Blub(Bar): - pass - - if polymorphic: - mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo') - mapper(Bar, bar, inherits=Foo, 
polymorphic_identity='bar') - mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub') - else: - mapper(Foo, foo) - mapper(Bar, bar, inherits=Foo) - mapper(Blub, blub, inherits=Bar) - - sess = create_session() - f = Foo() - b = Bar() - bl = Blub() - sess.add(f) - sess.add(b) - sess.add(bl) - sess.flush() - - if polymorphic: - def go(): - assert sess.query(Foo).get(f.id) == f - assert sess.query(Foo).get(b.id) == b - assert sess.query(Foo).get(bl.id) == bl - assert sess.query(Bar).get(b.id) == b - assert sess.query(Bar).get(bl.id) == bl - assert sess.query(Blub).get(bl.id) == bl - - self.assert_sql_count(testing.db, go, 0) - else: - # this is testing the 'wrong' behavior of using get() - # polymorphically with mappers that are not configured to be - # polymorphic. the important part being that get() always - # returns an instance of the query's type. - def go(): - assert sess.query(Foo).get(f.id) == f - - bb = sess.query(Foo).get(b.id) - assert isinstance(b, Foo) and bb.id==b.id - - bll = sess.query(Foo).get(bl.id) - assert isinstance(bll, Foo) and bll.id==bl.id - - assert sess.query(Bar).get(b.id) == b - - bll = sess.query(Bar).get(bl.id) - assert isinstance(bll, Bar) and bll.id == bl.id - - assert sess.query(Blub).get(bl.id) == bl - - self.assert_sql_count(testing.db, go, 3) - - test_get = _function_named(test_get, name) - return test_get - - test_get_polymorphic = create_test(True, 'test_get_polymorphic') - test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic') - -class EagerLazyTest(_base.MappedTest): - """tests eager load/lazy load of child items off inheritance mappers, tests that - LazyLoader constructs the right query condition.""" - def define_tables(self, metadata): - global foo, bar, bar_foo - foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), - Column('data', String(30))) - bar = Table('bar', metadata, - Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - 
Column('data', String(30))) - - bar_foo = Table('bar_foo', metadata, - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('foo_id', Integer, ForeignKey('foo.id')) - ) - - @testing.fails_on('maxdb', 'FIXME: unknown') - def testbasic(self): - class Foo(object): pass - class Bar(Foo): pass - - foos = mapper(Foo, foo) - bars = mapper(Bar, bar, inherits=foos) - bars.add_property('lazy', relation(foos, bar_foo, lazy=True)) - bars.add_property('eager', relation(foos, bar_foo, lazy=False)) - - foo.insert().execute(data='foo1') - bar.insert().execute(id=1, data='bar1') - - foo.insert().execute(data='foo2') - bar.insert().execute(id=2, data='bar2') - - foo.insert().execute(data='foo3') #3 - foo.insert().execute(data='foo4') #4 - - bar_foo.insert().execute(bar_id=1, foo_id=3) - bar_foo.insert().execute(bar_id=2, foo_id=4) - - sess = create_session() - q = sess.query(Bar) - self.assert_(len(q.first().lazy) == 1) - self.assert_(len(q.first().eager) == 1) - - -class FlushTest(_base.MappedTest): - """test dependency sorting among inheriting mappers""" - def define_tables(self, metadata): - global users, roles, user_roles, admins - users = Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('email', String(128)), - Column('password', String(16)), - ) - - roles = Table('role', metadata, - Column('id', Integer, primary_key=True), - Column('description', String(32)) - ) - - user_roles = Table('user_role', metadata, - Column('user_id', Integer, ForeignKey('users.id'), primary_key=True), - Column('role_id', Integer, ForeignKey('role.id'), primary_key=True) - ) - - admins = Table('admin', metadata, - Column('admin_id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('users.id')) - ) - - def testone(self): - class User(object):pass - class Role(object):pass - class Admin(User):pass - role_mapper = mapper(Role, roles) - user_mapper = mapper(User, users, properties = { - 'roles' : relation(Role, secondary=user_roles, lazy=False) - } - ) - 
admin_mapper = mapper(Admin, admins, inherits=user_mapper) - sess = create_session() - adminrole = Role() - sess.add(adminrole) - sess.flush() - - # create an Admin, and append a Role. the dependency processors - # corresponding to the "roles" attribute for the Admin mapper and the User mapper - # have to ensure that two dependency processors dont fire off and insert the - # many to many row twice. - a = Admin() - a.roles.append(adminrole) - a.password = 'admin' - sess.add(a) - sess.flush() - - assert user_roles.count().scalar() == 1 - - def testtwo(self): - class User(object): - def __init__(self, email=None, password=None): - self.email = email - self.password = password - - class Role(object): - def __init__(self, description=None): - self.description = description - - class Admin(User):pass - - role_mapper = mapper(Role, roles) - user_mapper = mapper(User, users, properties = { - 'roles' : relation(Role, secondary=user_roles, lazy=False) - } - ) - - admin_mapper = mapper(Admin, admins, inherits=user_mapper) - - # create roles - adminrole = Role('admin') - - sess = create_session() - sess.add(adminrole) - sess.flush() - - # create admin user - a = Admin(email='tim', password='admin') - a.roles.append(adminrole) - sess.add(a) - sess.flush() - - a.password = 'sadmin' - sess.flush() - assert user_roles.count().scalar() == 1 - -class VersioningTest(_base.MappedTest): - def define_tables(self, metadata): - global base, subtable, stuff - base = Table('base', metadata, - Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), - Column('version_id', Integer, nullable=False), - Column('value', String(40)), - Column('discriminator', Integer, nullable=False) - ) - subtable = Table('subtable', metadata, - Column('id', None, ForeignKey('base.id'), primary_key=True), - Column('subdata', String(50)) - ) - stuff = Table('stuff', metadata, - Column('id', Integer, primary_key=True), - Column('parent', Integer, ForeignKey('base.id')) - ) - - 
@testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') - @engines.close_open_connections - def test_save_update(self): - class Base(_fixtures.Base): - pass - class Sub(Base): - pass - class Stuff(Base): - pass - mapper(Stuff, stuff) - mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={ - 'stuff':relation(Stuff) - }) - mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) - - sess = create_session() - - b1 = Base(value='b1') - s1 = Sub(value='sub1', subdata='some subdata') - sess.add(b1) - sess.add(s1) - - sess.flush() - - sess2 = create_session() - s2 = sess2.query(Base).get(s1.id) - s2.subdata = 'sess2 subdata' - - s1.subdata = 'sess1 subdata' - - sess.flush() - - try: - sess2.query(Base).with_lockmode('read').get(s1.id) - assert False - except orm_exc.ConcurrentModificationError, e: - assert True - - try: - sess2.flush() - assert False - except orm_exc.ConcurrentModificationError, e: - assert True - - sess2.refresh(s2) - assert s2.subdata == 'sess1 subdata' - s2.subdata = 'sess2 subdata' - sess2.flush() - - @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') - def test_delete(self): - class Base(_fixtures.Base): - pass - class Sub(Base): - pass - - mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1) - mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) - - sess = create_session() - - b1 = Base(value='b1') - s1 = Sub(value='sub1', subdata='some subdata') - s2 = Sub(value='sub2', subdata='some other subdata') - sess.add(b1) - sess.add(s1) - sess.add(s2) - - sess.flush() - - sess2 = create_session() - s3 = sess2.query(Base).get(s1.id) - sess2.delete(s3) - sess2.flush() - - s2.subdata = 'some new subdata' - sess.flush() - - try: - s1.subdata = 'some new subdata' - sess.flush() - assert False - except orm_exc.ConcurrentModificationError, e: - 
assert True - -class DistinctPKTest(_base.MappedTest): - """test the construction of mapper.primary_key when an inheriting relationship - joins on a column other than primary key column.""" - - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - global person_table, employee_table, Person, Employee - - person_table = Table("persons", metadata, - Column("id", Integer, primary_key=True), - Column("name", String(80)), - ) - - employee_table = Table("employees", metadata, - Column("id", Integer, primary_key=True), - Column("salary", Integer), - Column("person_id", Integer, ForeignKey("persons.id")), - ) - - class Person(object): - def __init__(self, name): - self.name = name - - class Employee(Person): pass - - def insert_data(self): - person_insert = person_table.insert() - person_insert.execute(id=1, name='alice') - person_insert.execute(id=2, name='bob') - - employee_insert = employee_table.insert() - employee_insert.execute(id=2, salary=250, person_id=1) # alice - employee_insert.execute(id=3, salary=200, person_id=2) # bob - - def test_implicit(self): - person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, inherits=person_mapper) - assert list(class_mapper(Employee).primary_key) == [person_table.c.id] - - def test_explicit_props(self): - person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id}) - self._do_test(True) - - def test_explicit_composite_pk(self): - person_mapper = mapper(Person, person_table) - try: - mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id]) - self._do_test(True) - assert False - except sa_exc.SAWarning, e: - assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. 
Use explicit properties to give each column its own mapped attribute name.", str(e) - - def test_explicit_pk(self): - person_mapper = mapper(Person, person_table) - mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id]) - self._do_test(False) - - def _do_test(self, composite): - session = create_session() - query = session.query(Employee) - - if composite: - alice1 = query.get([1,2]) - bob = query.get([2,3]) - alice2 = query.get([1,2]) - else: - alice1 = query.get(1) - bob = query.get(2) - alice2 = query.get(1) - - assert alice1.name == alice2.name == 'alice' - assert bob.name == 'bob' - -class SyncCompileTest(_base.MappedTest): - """test that syncrules compile properly on custom inherit conds""" - def define_tables(self, metadata): - global _a_table, _b_table, _c_table - - _a_table = Table('a', metadata, - Column('id', Integer, primary_key=True), - Column('data1', String(128)) - ) - - _b_table = Table('b', metadata, - Column('a_id', Integer, ForeignKey('a.id'), primary_key=True), - Column('data2', String(128)) - ) - - _c_table = Table('c', metadata, - # Column('a_id', Integer, ForeignKey('b.a_id'), primary_key=True), #works - Column('b_a_id', Integer, ForeignKey('b.a_id'), primary_key=True), - Column('data3', String(128)) - ) - - def test_joins(self): - for j1 in (None, _b_table.c.a_id==_a_table.c.id, _a_table.c.id==_b_table.c.a_id): - for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id, _c_table.c.b_a_id==_b_table.c.a_id): - self._do_test(j1, j2) - for t in reversed(_a_table.metadata.sorted_tables): - t.delete().execute().close() - - def _do_test(self, j1, j2): - class A(object): - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - class B(A): - pass - - class C(B): - pass - - mapper(A, _a_table) - mapper(B, _b_table, inherits=A, - inherit_condition=j1 - ) - mapper(C, _c_table, inherits=B, - inherit_condition=j2 - ) - - session = create_session() - - a = A(data1='a1') - 
session.add(a) - - b = B(data1='b1', data2='b2') - session.add(b) - - c = C(data1='c1', data2='c2', data3='c3') - session.add(c) - - session.flush() - session.expunge_all() - - assert len(session.query(A).all()) == 3 - assert len(session.query(B).all()) == 2 - assert len(session.query(C).all()) == 1 - -class OverrideColKeyTest(_base.MappedTest): - """test overriding of column attributes.""" - - def define_tables(self, metadata): - global base, subtable - - base = Table('base', metadata, - Column('base_id', Integer, primary_key=True), - Column('data', String(255)), - Column('sqlite_fixer', String(10)) - ) - - subtable = Table('subtable', metadata, - Column('base_id', Integer, ForeignKey('base.base_id'), primary_key=True), - Column('subdata', String(255)) - ) - - def test_plain(self): - # control case - class Base(object): - pass - class Sub(Base): - pass - - mapper(Base, base) - mapper(Sub, subtable, inherits=Base) - - # Sub gets a "base_id" property using the "base_id" - # column of both tables. - self.assertEquals( - class_mapper(Sub).get_property('base_id').columns, - [base.c.base_id, subtable.c.base_id] - ) - - def test_override_explicit(self): - # this pattern is what you see when using declarative - # in particular, here we do a "manual" version of - # what we'd like the mapper to do. 
- - class Base(object): - pass - class Sub(Base): - pass - - mapper(Base, base, properties={ - 'id':base.c.base_id - }) - mapper(Sub, subtable, inherits=Base, properties={ - # this is the manual way to do it, is not really - # possible in declarative - 'id':[base.c.base_id, subtable.c.base_id] - }) - - self.assertEquals( - class_mapper(Sub).get_property('id').columns, - [base.c.base_id, subtable.c.base_id] - ) - - s1 = Sub() - s1.id = 10 - sess = create_session() - sess.add(s1) - sess.flush() - assert sess.query(Sub).get(10) is s1 - - def test_override_onlyinparent(self): - class Base(object): - pass - class Sub(Base): - pass - - mapper(Base, base, properties={ - 'id':base.c.base_id - }) - mapper(Sub, subtable, inherits=Base) - - self.assertEquals( - class_mapper(Sub).get_property('id').columns, - [base.c.base_id] - ) - - self.assertEquals( - class_mapper(Sub).get_property('base_id').columns, - [subtable.c.base_id] - ) - - s1 = Sub() - s1.id = 10 - - s2 = Sub() - s2.base_id = 15 - - sess = create_session() - sess.add_all([s1, s2]) - sess.flush() - - # s1 gets '10' - assert sess.query(Sub).get(10) is s1 - - # s2 gets a new id, base_id is overwritten by the ultimate - # PK col - assert s2.id == s2.base_id != 15 - - def test_override_implicit(self): - # this is how the pattern looks intuitively when - # using declarative. - # fixed as part of [ticket:1111] - - class Base(object): - pass - class Sub(Base): - pass - - mapper(Base, base, properties={ - 'id':base.c.base_id - }) - mapper(Sub, subtable, inherits=Base, properties={ - 'id':subtable.c.base_id - }) - - # Sub mapper compilation needs to detect that "base.c.base_id" - # is renamed in the inherited mapper as "id", even though - # it has its own "id" property. Sub's "id" property - # gets joined normally with the extra column. 
- - self.assertEquals( - class_mapper(Sub).get_property('id').columns, - [base.c.base_id, subtable.c.base_id] - ) - - s1 = Sub() - s1.id = 10 - sess = create_session() - sess.add(s1) - sess.flush() - assert sess.query(Sub).get(10) is s1 - - def test_plain_descriptor(self): - """test that descriptors prevent inheritance from propigating properties to subclasses.""" - - class Base(object): - pass - class Sub(Base): - @property - def data(self): - return "im the data" - - mapper(Base, base) - mapper(Sub, subtable, inherits=Base) - - s1 = Sub() - sess = create_session() - sess.add(s1) - sess.flush() - assert sess.query(Sub).one().data == "im the data" - - def test_custom_descriptor(self): - """test that descriptors prevent inheritance from propigating properties to subclasses.""" - - class MyDesc(object): - def __get__(self, instance, owner): - if instance is None: - return self - return "im the data" - - class Base(object): - pass - class Sub(Base): - data = MyDesc() - - mapper(Base, base) - mapper(Sub, subtable, inherits=Base) - - s1 = Sub() - sess = create_session() - sess.add(s1) - sess.flush() - assert sess.query(Sub).one().data == "im the data" - - def test_sub_columns_over_base_descriptors(self): - class Base(object): - @property - def subdata(self): - return "this is base" - - class Sub(Base): - pass - - mapper(Base, base) - mapper(Sub, subtable, inherits=Base) - - sess = create_session() - b1 = Base() - assert b1.subdata == "this is base" - s1 = Sub() - s1.subdata = "this is sub" - assert s1.subdata == "this is sub" - - sess.add_all([s1, b1]) - sess.flush() - sess.expunge_all() - - assert sess.query(Base).get(b1.base_id).subdata == "this is base" - assert sess.query(Sub).get(s1.base_id).subdata == "this is sub" - - def test_base_descriptors_over_base_cols(self): - class Base(object): - @property - def data(self): - return "this is base" - - class Sub(Base): - pass - - mapper(Base, base) - mapper(Sub, subtable, inherits=Base) - - sess = create_session() - b1 = 
Base() - assert b1.data == "this is base" - s1 = Sub() - assert s1.data == "this is base" - - sess.add_all([s1, b1]) - sess.flush() - sess.expunge_all() - - assert sess.query(Base).get(b1.base_id).data == "this is base" - assert sess.query(Sub).get(s1.base_id).data == "this is base" - -class OptimizedLoadTest(_base.MappedTest): - """test that the 'optimized load' routine doesn't crash when - a column in the join condition is not available. - - """ - def define_tables(self, metadata): - global base, sub - base = Table('base', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('type', String(50)) - ) - sub = Table('sub', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('sub', String(50)) - ) - - def test_optimized_passes(self): - class Base(object): - pass - class Sub(Base): - pass - - mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') - - # redefine Sub's "id" to favor the "id" col in the subtable. - # "id" is also part of the primary join condition - mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id}) - sess = create_session() - s1 = Sub() - s1.data = 's1data' - s1.sub = 's1sub' - sess.add(s1) - sess.flush() - sess.expunge_all() - - # load s1 via Base. s1.id won't populate since it's relative to - # the "sub" table. The optimized load kicks in and tries to - # generate on the primary join, but cannot since "id" is itself unloaded. 
- # the optimized load needs to return "None" so regular full-row loading proceeds - s1 = sess.query(Base).get(s1.id) - assert s1.sub == 's1sub' - -class PKDiscriminatorTest(_base.MappedTest): - def define_tables(self, metadata): - parents = Table('parents', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(60))) - - children = Table('children', metadata, - Column('id', Integer, ForeignKey('parents.id'), primary_key=True), - Column('type', Integer,primary_key=True), - Column('name', String(60))) - - @testing.resolve_artifact_names - def test_pk_as_discriminator(self): - class Parent(object): - def __init__(self, name=None): - self.name = name - - class Child(object): - def __init__(self, name=None): - self.name = name - - class A(Child): - pass - - mapper(Parent, parents, properties={ - 'children': relation(Child, backref='parent'), - }) - mapper(Child, children, polymorphic_on=children.c.type, - polymorphic_identity=1) - - mapper(A, inherits=Child, polymorphic_identity=2) - - s = create_session() - p = Parent('p1') - a = A('a1') - p.children.append(a) - s.add(p) - s.flush() - - assert a.id - assert a.type == 2 - - -class DeleteOrphanTest(_base.MappedTest): - def define_tables(self, metadata): - global single, parent - single = Table('single', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(50), nullable=False), - Column('data', String(50)), - Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False), - ) - - parent = Table('parent', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) - - def test_orphan_message(self): - class Base(_fixtures.Base): - pass - - class SubClass(Base): - pass - - class Parent(_fixtures.Base): - pass - - mapper(Base, single, polymorphic_on=single.c.type, polymorphic_identity='base') - mapper(SubClass, inherits=Base, polymorphic_identity='sub') - mapper(Parent, parent, properties={ - 'related':relation(Base, cascade="all, delete-orphan") 
- }) - - sess = create_session() - s1 = SubClass(data='s1') - sess.add(s1) - self.assertRaisesMessage(orm_exc.FlushError, - "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py deleted file mode 100644 index 6cdaed7e6..000000000 --- a/test/orm/inheritance/concrete.py +++ /dev/null @@ -1,515 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy.orm import exc as orm_exc -from testlib import * -from testlib import sa, testing -from orm import _base -from sqlalchemy.orm import attributes -from testlib.testing import eq_ - -class Employee(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return self.__class__.__name__ + " " + self.name - -class Manager(Employee): - def __init__(self, name, manager_data): - self.name = name - self.manager_data = manager_data - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.manager_data - -class Engineer(Employee): - def __init__(self, name, engineer_info): - self.name = name - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " " + self.engineer_info - -class Hacker(Engineer): - def __init__(self, name, nickname, engineer_info): - self.name = name - self.nickname = nickname - self.engineer_info = engineer_info - def __repr__(self): - return self.__class__.__name__ + " " + self.name + " '" + \ - self.nickname + "' " + self.engineer_info - -class Company(object): - pass - - -class ConcreteTest(_base.MappedTest): - def define_tables(self, metadata): - global managers_table, engineers_table, hackers_table, companies, employees_table - - companies = Table('companies', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50))) - - employees_table 
= Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('company_id', Integer, ForeignKey('companies.id')) - ) - - managers_table = Table('managers', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('manager_data', String(50)), - Column('company_id', Integer, ForeignKey('companies.id')) - ) - - engineers_table = Table('engineers', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('engineer_info', String(50)), - Column('company_id', Integer, ForeignKey('companies.id')) - ) - - hackers_table = Table('hackers', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('engineer_info', String(50)), - Column('company_id', Integer, ForeignKey('companies.id')), - Column('nickname', String(50)) - ) - - - - def test_basic(self): - pjoin = polymorphic_union({ - 'manager':managers_table, - 'engineer':engineers_table - }, 'type', 'pjoin') - - employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, - concrete=True, polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, - concrete=True, polymorphic_identity='engineer') - - session = create_session() - session.add(Manager('Tom', 'knows how to manage things')) - session.add(Engineer('Kurt', 'knows how to hack')) - session.flush() - session.expunge_all() - - assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Kurt knows how to hack"]) - - manager = session.query(Manager).one() - session.expire(manager, 
['manager_data']) - self.assertEquals(manager.manager_data, "knows how to manage things") - - def test_multi_level_no_base(self): - pjoin = polymorphic_union({ - 'manager': managers_table, - 'engineer': engineers_table, - 'hacker': hackers_table - }, 'type', 'pjoin') - - pjoin2 = polymorphic_union({ - 'engineer': engineers_table, - 'hacker': hackers_table - }, 'type', 'pjoin2') - - employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, concrete=True, - polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, - with_polymorphic=('*', pjoin2), - polymorphic_on=pjoin2.c.type, - inherits=employee_mapper, concrete=True, - polymorphic_identity='engineer') - hacker_mapper = mapper(Hacker, hackers_table, - inherits=engineer_mapper, - concrete=True, polymorphic_identity='hacker') - - session = create_session() - tom = Manager('Tom', 'knows how to manage things') - jerry = Engineer('Jerry', 'knows how to program') - hacker = Hacker('Kurt', 'Badass', 'knows how to hack') - session.add_all((tom, jerry, hacker)) - session.flush() - - # ensure "readonly" on save logic didn't pollute the expired_attributes - # collection - assert 'nickname' not in attributes.instance_state(jerry).expired_attributes - assert 'name' not in attributes.instance_state(jerry).expired_attributes - assert 'name' not in attributes.instance_state(hacker).expired_attributes - assert 'nickname' not in attributes.instance_state(hacker).expired_attributes - def go(): - self.assertEquals(jerry.name, "Jerry") - self.assertEquals(hacker.nickname, "Badass") - self.assert_sql_count(testing.db, go, 0) - - session.expunge_all() - - assert repr(session.query(Employee).filter(Employee.name=='Tom').one()) == "Manager Tom knows how to manage things" - assert repr(session.query(Manager).filter(Manager.name=='Tom').one()) == "Manager Tom knows how to manage things" - - - assert set([repr(x) for x in 
session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"]) - - def test_multi_level_with_base(self): - pjoin = polymorphic_union({ - 'employee':employees_table, - 'manager': managers_table, - 'engineer': engineers_table, - 'hacker': hackers_table - }, 'type', 'pjoin') - - pjoin2 = polymorphic_union({ - 'engineer': engineers_table, - 'hacker': hackers_table - }, 'type', 'pjoin2') - - employee_mapper = mapper(Employee, employees_table, - with_polymorphic=('*', pjoin), polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, concrete=True, - polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, - with_polymorphic=('*', pjoin2), - polymorphic_on=pjoin2.c.type, - inherits=employee_mapper, concrete=True, - polymorphic_identity='engineer') - hacker_mapper = mapper(Hacker, hackers_table, - inherits=engineer_mapper, - concrete=True, polymorphic_identity='hacker') - - session = create_session() - tom = Manager('Tom', 'knows how to manage things') - jerry = Engineer('Jerry', 'knows how to program') - hacker = Hacker('Kurt', 'Badass', 'knows how to hack') - session.add_all((tom, jerry, hacker)) - session.flush() - - def go(): - self.assertEquals(jerry.name, "Jerry") - self.assertEquals(hacker.nickname, "Badass") - self.assert_sql_count(testing.db, go, 0) - - session.expunge_all() - - # check that we aren't getting a cartesian product in the raw SQL. 
- # this requires that Engineer's polymorphic discriminator is not rendered - # in the statement which is only against Employee's "pjoin" - assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3 - - assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"]) - assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) - assert set([repr(x) for x in session.query(Hacker)]) == set(["Hacker Kurt 'Badass' knows how to hack"]) - - - def test_without_default_polymorphic(self): - pjoin = polymorphic_union({ - 'employee':employees_table, - 'manager': managers_table, - 'engineer': engineers_table, - 'hacker': hackers_table - }, 'type', 'pjoin') - - pjoin2 = polymorphic_union({ - 'engineer': engineers_table, - 'hacker': hackers_table - }, 'type', 'pjoin2') - - employee_mapper = mapper(Employee, employees_table, - polymorphic_identity='employee') - manager_mapper = mapper(Manager, managers_table, - inherits=employee_mapper, concrete=True, - polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, - inherits=employee_mapper, concrete=True, - polymorphic_identity='engineer') - hacker_mapper = mapper(Hacker, hackers_table, - inherits=engineer_mapper, - concrete=True, polymorphic_identity='hacker') - - session = create_session() - jdoe = Employee('Jdoe') - tom = Manager('Tom', 'knows how to manage things') - jerry = Engineer('Jerry', 'knows how to program') - hacker = Hacker('Kurt', 'Badass', 'knows how to hack') - session.add_all((jdoe, tom, jerry, hacker)) - session.flush() - - eq_( - len(testing.db.execute(session.query(Employee).with_polymorphic('*', pjoin, 
pjoin.c.type).with_labels().statement).fetchall()), - 4 - ) - - eq_( - session.query(Employee).get(jdoe.employee_id), jdoe - ) - eq_( - session.query(Engineer).get(jerry.employee_id), jerry - ) - eq_( - set([repr(x) for x in session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type)]), - set(["Employee Jdoe", "Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) - ) - eq_( - set([repr(x) for x in session.query(Manager)]), - set(["Manager Tom knows how to manage things"]) - ) - eq_( - set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type)]), - set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) - ) - eq_( - set([repr(x) for x in session.query(Hacker)]), - set(["Hacker Kurt 'Badass' knows how to hack"]) - ) - # test adaption of the column by wrapping the query in a subquery - eq_( - len(testing.db.execute( - session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self().statement - ).fetchall()), - 2 - ) - eq_( - set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self()]), - set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) - ) - - def test_relation(self): - pjoin = polymorphic_union({ - 'manager':managers_table, - 'engineer':engineers_table - }, 'type', 'pjoin') - - mapper(Company, companies, properties={ - 'employees':relation(Employee) - }) - employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) - manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager') - engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer') - - session = create_session() - c = Company() - c.employees.append(Manager('Tom', 'knows how to manage things')) - c.employees.append(Engineer('Kurt', 'knows how 
to hack')) - session.add(c) - session.flush() - session.expunge_all() - - def go(): - c2 = session.query(Company).get(c.id) - assert set([repr(x) for x in c2.employees]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) - self.assert_sql_count(testing.db, go, 2) - session.expunge_all() - def go(): - c2 = session.query(Company).options(eagerload(Company.employees)).get(c.id) - assert set([repr(x) for x in c2.employees]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) - self.assert_sql_count(testing.db, go, 1) - -class PropertyInheritanceTest(_base.MappedTest): - def define_tables(self, metadata): - Table('a_table', metadata, - Column('id', Integer, primary_key=True), - Column('some_c_id', Integer, ForeignKey('c_table.id')), - Column('aname', String(50)), - ) - Table('b_table', metadata, - Column('id', Integer, primary_key=True), - Column('some_c_id', Integer, ForeignKey('c_table.id')), - Column('bname', String(50)), - ) - Table('c_table', metadata, - Column('id', Integer, primary_key=True), - Column('cname', String(50)), - - ) - - def setup_classes(self): - class A(_base.ComparableEntity): - pass - - class B(A): - pass - - class C(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_noninherited_warning(self): - mapper(A, a_table, properties={ - 'some_c':relation(C) - }) - mapper(B, b_table,inherits=A, concrete=True) - mapper(C, c_table) - - b = B() - c = C() - self.assertRaises(AttributeError, setattr, b, 'some_c', c) - - clear_mappers() - mapper(A, a_table, properties={ - 'a_id':a_table.c.id - }) - mapper(B, b_table,inherits=A, concrete=True) - mapper(C, c_table) - b = B() - self.assertRaises(AttributeError, setattr, b, 'a_id', 3) - - clear_mappers() - mapper(A, a_table, properties={ - 'a_id':a_table.c.id - }) - mapper(B, b_table,inherits=A, concrete=True) - mapper(C, c_table) - - @testing.resolve_artifact_names - def test_inheriting(self): - mapper(A, a_table, 
properties={ - 'some_c':relation(C, back_populates='many_a') - }) - mapper(B, b_table,inherits=A, concrete=True, properties={ - 'some_c':relation(C, back_populates='many_b') - }) - mapper(C, c_table, properties={ - 'many_a':relation(A, back_populates='some_c'), - 'many_b':relation(B, back_populates='some_c'), - }) - - sess = sessionmaker()() - - c1 = C(cname='c1') - c2 = C(cname='c2') - a1 = A(some_c=c1, aname='a1') - a2 = A(some_c=c2, aname='a2') - b1 = B(some_c=c1, bname='b1') - b2 = B(some_c=c1, bname='b2') - - self.assertRaises(AttributeError, setattr, b1, 'aname', 'foo') - self.assertRaises(AttributeError, getattr, A, 'bname') - - assert c2.many_a == [a2] - assert c1.many_a == [a1] - assert c1.many_b == [b1, b2] - - sess.add_all([c1, c2]) - sess.commit() - - assert sess.query(C).filter(C.many_a.contains(a2)).one() is c2 - assert c2.many_a == [a2] - assert c1.many_a == [a1] - assert c1.many_b == [b1, b2] - - assert sess.query(B).filter(B.bname=='b1').one() is b1 - - @testing.resolve_artifact_names - def test_polymorphic_backref(self): - """test multiple backrefs to the same polymorphically-loading attribute.""" - - ajoin = polymorphic_union( - {'a':a_table, - 'b':b_table - }, 'type', 'ajoin' - ) - mapper(A, a_table, with_polymorphic=('*', ajoin), - polymorphic_on=ajoin.c.type, polymorphic_identity='a', - properties={ - 'some_c':relation(C, back_populates='many_a') - }) - mapper(B, b_table,inherits=A, concrete=True, - polymorphic_identity='b', - properties={ - 'some_c':relation(C, back_populates='many_a') - }) - mapper(C, c_table, properties={ - 'many_a':relation(A, back_populates='some_c', order_by=ajoin.c.id), - }) - - sess = sessionmaker()() - - c1 = C(cname='c1') - c2 = C(cname='c2') - a1 = A(some_c=c1, aname='a1', id=1) - a2 = A(some_c=c2, aname='a2', id=2) - b1 = B(some_c=c1, bname='b1', id=3) - b2 = B(some_c=c1, bname='b2', id=4) - - eq_([a2], c2.many_a) - eq_([a1, b1, b2], c1.many_a) - - sess.add_all([c1, c2]) - sess.commit() - - assert 
sess.query(C).filter(C.many_a.contains(a2)).one() is c2 - assert sess.query(C).filter(C.many_a.contains(b1)).one() is c1 - eq_(c2.many_a, [a2]) - eq_(c1.many_a, [a1, b1, b2]) - - sess.expire_all() - - def go(): - eq_( - [C(many_a=[A(aname='a1'), B(bname='b1'), B(bname='b2')]), C(many_a=[A(aname='a2')])], - sess.query(C).options(eagerload(C.many_a)).order_by(C.id).all(), - ) - self.assert_sql_count(testing.db, go, 1) - - -class ColKeysTest(_base.MappedTest): - def define_tables(self, metadata): - global offices_table, refugees_table - refugees_table = Table('refugee', metadata, - Column('refugee_fid', Integer, primary_key=True), - Column('refugee_name', Unicode(30), key='name')) - - offices_table = Table('office', metadata, - Column('office_fid', Integer, primary_key=True), - Column('office_name', Unicode(30), key='name')) - - def insert_data(self): - refugees_table.insert().execute( - dict(refugee_fid=1, name=u"refugee1"), - dict(refugee_fid=2, name=u"refugee2") - ) - offices_table.insert().execute( - dict(office_fid=1, name=u"office1"), - dict(office_fid=2, name=u"office2") - ) - - def test_keys(self): - pjoin = polymorphic_union({ - 'refugee': refugees_table, - 'office': offices_table - }, 'type', 'pjoin') - class Location(object): - pass - - class Refugee(Location): - pass - - class Office(Location): - pass - - location_mapper = mapper(Location, pjoin, polymorphic_on=pjoin.c.type, - polymorphic_identity='location') - office_mapper = mapper(Office, offices_table, inherits=location_mapper, - concrete=True, polymorphic_identity='office') - refugee_mapper = mapper(Refugee, refugees_table, inherits=location_mapper, - concrete=True, polymorphic_identity='refugee') - - sess = create_session() - eq_(sess.query(Refugee).get(1).name, "refugee1") - eq_(sess.query(Refugee).get(2).name, "refugee2") - - eq_(sess.query(Office).get(1).name, "office1") - eq_(sess.query(Office).get(2).name, "office2") - -if __name__ == '__main__': - testenv.main() diff --git 
a/test/orm/inheritance/magazine.py b/test/orm/inheritance/magazine.py deleted file mode 100644 index 34374c887..000000000 --- a/test/orm/inheritance/magazine.py +++ /dev/null @@ -1,220 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * - -from testlib import testing, _function_named -from orm import _base - -class BaseObject(object): - def __init__(self, *args, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) -class Publication(BaseObject): - pass - -class Issue(BaseObject): - pass - -class Location(BaseObject): - def __repr__(self): - return "%s(%s, %s)" % (self.__class__.__name__, str(getattr(self, 'issue_id', None)), repr(str(self._name.name))) - - def _get_name(self): - return self._name - - def _set_name(self, name): - session = create_session() - s = session.query(LocationName).filter(LocationName.name==name).first() - session.expunge_all() - if s is not None: - self._name = s - - return - - found = False - - for i in session.new: - if isinstance(i, LocationName) and i.name == name: - self._name = i - found = True - - break - - if found == False: - self._name = LocationName(name=name) - - name = property(_get_name, _set_name) - -class LocationName(BaseObject): - def __repr__(self): - return "%s()" % (self.__class__.__name__) - -class PageSize(BaseObject): - def __repr__(self): - return "%s(%sx%s, %s)" % (self.__class__.__name__, self.width, self.height, self.name) - -class Magazine(BaseObject): - def __repr__(self): - return "%s(%s, %s)" % (self.__class__.__name__, repr(self.location), repr(self.size)) - -class Page(BaseObject): - def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, str(self.page_no)) - -class MagazinePage(Page): - def __repr__(self): - return "%s(%s, %s)" % (self.__class__.__name__, str(self.page_no), repr(self.magazine)) - -class ClassifiedPage(MagazinePage): - pass - - -class MagazineTest(_base.MappedTest): - def define_tables(self, 
metadata): - global publication_table, issue_table, location_table, location_name_table, magazine_table, \ - page_table, magazine_page_table, classified_page_table, page_size_table - - zerodefault = {} #{'default':0} - publication_table = Table('publication', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('name', String(45), default=''), - ) - issue_table = Table('issue', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('publication_id', Integer, ForeignKey('publication.id'), **zerodefault), - Column('issue', Integer, **zerodefault), - ) - location_table = Table('location', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('issue_id', Integer, ForeignKey('issue.id'), **zerodefault), - Column('ref', CHAR(3), default=''), - Column('location_name_id', Integer, ForeignKey('location_name.id'), **zerodefault), - ) - location_name_table = Table('location_name', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('name', String(45), default=''), - ) - magazine_table = Table('magazine', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('location_id', Integer, ForeignKey('location.id'), **zerodefault), - Column('page_size_id', Integer, ForeignKey('page_size.id'), **zerodefault), - ) - page_table = Table('page', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('page_no', Integer, **zerodefault), - Column('type', CHAR(1), default='p'), - ) - magazine_page_table = Table('magazine_page', metadata, - Column('page_id', Integer, ForeignKey('page.id'), primary_key=True, **zerodefault), - Column('magazine_id', Integer, ForeignKey('magazine.id'), **zerodefault), - Column('orders', TEXT, default=''), - ) - classified_page_table = Table('classified_page', metadata, - Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True, **zerodefault), - Column('titles', String(45), default=''), - ) - 
page_size_table = Table('page_size', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('width', Integer, **zerodefault), - Column('height', Integer, **zerodefault), - Column('name', String(45), default=''), - ) - -def generate_round_trip_test(use_unions=False, use_joins=False): - def test_roundtrip(self): - publication_mapper = mapper(Publication, publication_table) - - issue_mapper = mapper(Issue, issue_table, properties = { - 'publication': relation(Publication, backref=backref('issues', cascade="all, delete-orphan")), - }) - - location_name_mapper = mapper(LocationName, location_name_table) - - location_mapper = mapper(Location, location_table, properties = { - 'issue': relation(Issue, backref=backref('locations', lazy=False, cascade="all, delete-orphan")), - '_name': relation(LocationName), - }) - - page_size_mapper = mapper(PageSize, page_size_table) - - magazine_mapper = mapper(Magazine, magazine_table, properties = { - 'location': relation(Location, backref=backref('magazine', uselist=False)), - 'size': relation(PageSize), - }) - - if use_unions: - page_join = polymorphic_union( - { - 'm': page_table.join(magazine_page_table), - 'c': page_table.join(magazine_page_table).join(classified_page_table), - 'p': page_table.select(page_table.c.type=='p'), - }, None, 'page_join') - page_mapper = mapper(Page, page_table, with_polymorphic=('*', page_join), polymorphic_on=page_join.c.type, polymorphic_identity='p') - elif use_joins: - page_join = page_table.outerjoin(magazine_page_table).outerjoin(classified_page_table) - page_mapper = mapper(Page, page_table, with_polymorphic=('*', page_join), polymorphic_on=page_table.c.type, polymorphic_identity='p') - else: - page_mapper = mapper(Page, page_table, polymorphic_on=page_table.c.type, polymorphic_identity='p') - - if use_unions: - magazine_join = polymorphic_union( - { - 'm': page_table.join(magazine_page_table), - 'c': page_table.join(magazine_page_table).join(classified_page_table), - }, 
None, 'page_join') - magazine_page_mapper = mapper(MagazinePage, magazine_page_table, with_polymorphic=('*', magazine_join), inherits=page_mapper, polymorphic_identity='m', properties={ - 'magazine': relation(Magazine, backref=backref('pages', order_by=magazine_join.c.page_no)) - }) - elif use_joins: - magazine_join = page_table.join(magazine_page_table).outerjoin(classified_page_table) - magazine_page_mapper = mapper(MagazinePage, magazine_page_table, with_polymorphic=('*', magazine_join), inherits=page_mapper, polymorphic_identity='m', properties={ - 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no)) - }) - else: - magazine_page_mapper = mapper(MagazinePage, magazine_page_table, inherits=page_mapper, polymorphic_identity='m', properties={ - 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no)) - }) - - classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id]) - #compile_mappers() - #print [str(s) for s in classified_page_mapper.primary_key] - #print classified_page_mapper.columntoproperty[page_table.c.id] - - - session = create_session() - - pub = Publication(name='Test') - issue = Issue(issue=46,publication=pub) - - location = Location(ref='ABC',name='London',issue=issue) - - page_size = PageSize(name='A4',width=210,height=297) - - magazine = Magazine(location=location,size=page_size) - page = ClassifiedPage(magazine=magazine,page_no=1) - page2 = MagazinePage(magazine=magazine,page_no=2) - page3 = ClassifiedPage(magazine=magazine,page_no=3) - session.add(pub) - - session.flush() - print [x for x in session] - session.expunge_all() - - session.flush() - session.expunge_all() - p = session.query(Publication).filter(Publication.name=="Test").one() - - print p.issues[0].locations[0].magazine.pages - print [page, page2, page3] - assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, 
page2, page3]), repr(p.issues[0].locations[0].magazine.pages) - - test_roundtrip = _function_named( - test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions")) - setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip) - -for (use_union, use_join) in [(True, False), (False, True), (False, False)]: - generate_round_trip_test(use_union, use_join) - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py deleted file mode 100644 index 5dbf69ba5..000000000 --- a/test/orm/inheritance/manytomany.py +++ /dev/null @@ -1,248 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * - -from testlib import testing -from orm import _base - - -class InheritTest(_base.MappedTest): - """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): - global principals - global users - global groups - global user_group_map - - principals = Table('principals', metadata, - Column('principal_id', Integer, - Sequence('principal_id_seq', optional=False), - primary_key=True), - Column('name', String(50), nullable=False)) - - users = Table('prin_users', metadata, - Column('principal_id', Integer, - ForeignKey('principals.principal_id'), primary_key=True), - Column('password', String(50), nullable=False), - Column('email', String(50), nullable=False), - Column('login_id', String(50), nullable=False)) - - groups = Table('prin_groups', metadata, - Column('principal_id', Integer, - ForeignKey('principals.principal_id'), primary_key=True)) - - user_group_map = Table('prin_user_group_map', metadata, - Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), - primary_key=True ), - Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), - primary_key=True ), - ) - - def testbasic(self): - class Principal(object): - def __init__(self, **kwargs): - for key, value in 
kwargs.iteritems(): - setattr(self, key, value) - - class User(Principal): - pass - - class Group(Principal): - pass - - mapper(Principal, principals) - mapper(User, users, inherits=Principal) - - mapper(Group, groups, inherits=Principal, properties={ - 'users': relation(User, secondary=user_group_map, - lazy=True, backref="groups") - }) - - g = Group(name="group1") - g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1")) - sess = create_session() - sess.add(g) - sess.flush() - # TODO: put an assertion - -class InheritTest2(_base.MappedTest): - """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): - global foo, bar, foo_bar - foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_id_seq', optional=True), - primary_key=True), - Column('data', String(20)), - ) - - bar = Table('bar', metadata, - Column('bid', Integer, ForeignKey('foo.id'), primary_key=True), - #Column('fid', Integer, ForeignKey('foo.id'), ) - ) - - foo_bar = Table('foo_bar', metadata, - Column('foo_id', Integer, ForeignKey('foo.id')), - Column('bar_id', Integer, ForeignKey('bar.bid'))) - - def testget(self): - class Foo(object): - def __init__(self, data=None): - self.data = data - class Bar(Foo):pass - - mapper(Foo, foo) - mapper(Bar, bar, inherits=Foo) - print foo.join(bar).primary_key - print class_mapper(Bar).primary_key - b = Bar('somedata') - sess = create_session() - sess.add(b) - sess.flush() - sess.expunge_all() - - # test that "bar.bid" does not need to be referenced in a get - # (ticket 185) - assert sess.query(Bar).get(b.id).id == b.id - - def testbasic(self): - class Foo(object): - def __init__(self, data=None): - self.data = data - - mapper(Foo, foo) - class Bar(Foo): - pass - - mapper(Bar, bar, inherits=Foo, properties={ - 'foos': relation(Foo, secondary=foo_bar, lazy=False) - }) - - sess = create_session() - b = Bar('barfoo') - sess.add(b) - sess.flush() - - f1 = Foo('subfoo1') - f2 = Foo('subfoo2') 
- b.foos.append(f1) - b.foos.append(f2) - - sess.flush() - sess.expunge_all() - - l = sess.query(Bar).all() - print l[0] - print l[0].foos - self.assert_unordered_result(l, Bar, -# {'id':1, 'data':'barfoo', 'bid':1, 'foos':(Foo, [{'id':2,'data':'subfoo1'}, {'id':3,'data':'subfoo2'}])}, - {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])}, - ) - -class InheritTest3(_base.MappedTest): - """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): - global foo, bar, blub, bar_foo, blub_bar, blub_foo - - # the 'data' columns are to appease SQLite which cant handle a blank INSERT - foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), - Column('data', String(20))) - - bar = Table('bar', metadata, - Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - Column('data', String(20))) - - blub = Table('blub', metadata, - Column('id', Integer, ForeignKey('bar.id'), primary_key=True), - Column('data', String(20))) - - bar_foo = Table('bar_foo', metadata, - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('foo_id', Integer, ForeignKey('foo.id'))) - - blub_bar = Table('bar_blub', metadata, - Column('blub_id', Integer, ForeignKey('blub.id')), - Column('bar_id', Integer, ForeignKey('bar.id'))) - - blub_foo = Table('blub_foo', metadata, - Column('blub_id', Integer, ForeignKey('blub.id')), - Column('foo_id', Integer, ForeignKey('foo.id'))) - - def testbasic(self): - class Foo(object): - def __init__(self, data=None): - self.data = data - def __repr__(self): - return "Foo id %d, data %s" % (self.id, self.data) - mapper(Foo, foo) - - class Bar(Foo): - def __repr__(self): - return "Bar id %d, data %s" % (self.id, self.data) - - mapper(Bar, bar, inherits=Foo, properties={ - 'foos' :relation(Foo, secondary=bar_foo, lazy=True) - }) - - sess = create_session() - b = Bar('bar #1') - sess.add(b) - b.foos.append(Foo("foo #1")) - 
b.foos.append(Foo("foo #2")) - sess.flush() - compare = repr(b) + repr(sorted([repr(o) for o in b.foos])) - sess.expunge_all() - l = sess.query(Bar).all() - print repr(l[0]) + repr(l[0].foos) - found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos])) - self.assertEqual(found, compare) - - @testing.fails_on('maxdb', 'FIXME: unknown') - def testadvanced(self): - class Foo(object): - def __init__(self, data=None): - self.data = data - def __repr__(self): - return "Foo id %d, data %s" % (self.id, self.data) - mapper(Foo, foo) - - class Bar(Foo): - def __repr__(self): - return "Bar id %d, data %s" % (self.id, self.data) - mapper(Bar, bar, inherits=Foo) - - class Blub(Bar): - def __repr__(self): - return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos])) - - mapper(Blub, blub, inherits=Bar, properties={ - 'bars':relation(Bar, secondary=blub_bar, lazy=False), - 'foos':relation(Foo, secondary=blub_foo, lazy=False), - }) - - sess = create_session() - f1 = Foo("foo #1") - b1 = Bar("bar #1") - b2 = Bar("bar #2") - bl1 = Blub("blub #1") - for o in (f1, b1, b2, bl1): - sess.add(o) - bl1.foos.append(f1) - bl1.bars.append(b2) - sess.flush() - compare = repr(bl1) - blubid = bl1.id - sess.expunge_all() - - l = sess.query(Blub).all() - print l - self.assert_(repr(l[0]) == compare) - sess.expunge_all() - x = sess.query(Blub).filter_by(id=blubid).one() - print x - self.assert_(repr(x) == compare) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/poly_linked_list.py b/test/orm/inheritance/poly_linked_list.py deleted file mode 100644 index 2cf051949..000000000 --- a/test/orm/inheritance/poly_linked_list.py +++ /dev/null @@ -1,199 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * - -from orm import _base -from testlib import testing - - -class PolymorphicCircularTest(_base.MappedTest): - run_setup_mappers = 'once' - - 
def define_tables(self, metadata): - global Table1, Table1B, Table2, Table3, Data - table1 = Table('table1', metadata, - Column('id', Integer, primary_key=True), - Column('related_id', Integer, ForeignKey('table1.id'), nullable=True), - Column('type', String(30)), - Column('name', String(30)) - ) - - table2 = Table('table2', metadata, - Column('id', Integer, ForeignKey('table1.id'), primary_key=True), - ) - - table3 = Table('table3', metadata, - Column('id', Integer, ForeignKey('table1.id'), primary_key=True), - ) - - data = Table('data', metadata, - Column('id', Integer, primary_key=True), - Column('node_id', Integer, ForeignKey('table1.id')), - Column('data', String(30)) - ) - - #join = polymorphic_union( - # { - # 'table3' : table1.join(table3), - # 'table2' : table1.join(table2), - # 'table1' : table1.select(table1.c.type.in_(['table1', 'table1b'])), - # }, None, 'pjoin') - - join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin') - #join = None - - class Table1(object): - def __init__(self, name, data=None): - self.name = name - if data is not None: - self.data = data - def __repr__(self): - return "%s(%s, %s, %s)" % (self.__class__.__name__, self.id, repr(str(self.name)), repr(self.data)) - - class Table1B(Table1): - pass - - class Table2(Table1): - pass - - class Table3(Table1): - pass - - class Data(object): - def __init__(self, data): - self.data = data - def __repr__(self): - return "%s(%s, %s)" % (self.__class__.__name__, self.id, repr(str(self.data))) - - try: - # this is how the mapping used to work. 
ensure that this raises an error now - table1_mapper = mapper(Table1, table1, - select_table=join, - polymorphic_on=table1.c.type, - polymorphic_identity='table1', - properties={ - 'next': relation(Table1, - backref=backref('prev', foreignkey=join.c.id, uselist=False), - uselist=False, primaryjoin=join.c.id==join.c.related_id), - 'data':relation(mapper(Data, data)) - }, - order_by=table1.c.id) - table1_mapper.compile() - assert False - except: - assert True - clear_mappers() - - # currently, the "eager" relationships degrade to lazy relationships - # due to the polymorphic load. - # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" - # exception now. since eager loading would never work for that relation anyway, its better that the user - # gets an exception instead of it silently not eager loading. - table1_mapper = mapper(Table1, table1, - #select_table=join, - polymorphic_on=table1.c.type, - polymorphic_identity='table1', - properties={ - 'next': relation(Table1, - backref=backref('prev', remote_side=table1.c.id, uselist=False), - uselist=False, primaryjoin=table1.c.id==table1.c.related_id), - 'data':relation(mapper(Data, data), lazy=False, order_by=data.c.id) - }, - order_by=table1.c.id - ) - - table1b_mapper = mapper(Table1B, inherits=table1_mapper, polymorphic_identity='table1b') - - table2_mapper = mapper(Table2, table2, - inherits=table1_mapper, - polymorphic_identity='table2') - - table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3') - - table1_mapper.compile() - assert table1_mapper.primary_key == [table1.c.id], table1_mapper.primary_key - - @testing.fails_on('maxdb', 'FIXME: unknown') - def testone(self): - self.do_testlist([Table1, Table2, Table1, Table2]) - - @testing.fails_on('maxdb', 'FIXME: unknown') - def testtwo(self): - self.do_testlist([Table3]) - - @testing.fails_on('maxdb', 'FIXME: unknown') - def testthree(self): - self.do_testlist([Table2, Table1, 
Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1]) - - @testing.fails_on('maxdb', 'FIXME: unknown') - def testfour(self): - self.do_testlist([ - Table2('t2', [Data('data1'), Data('data2')]), - Table1('t1', []), - Table3('t3', [Data('data3')]), - Table1B('t1b', [Data('data4'), Data('data5')]) - ]) - - def do_testlist(self, classes): - sess = create_session( ) - - # create objects in a linked list - count = 1 - obj = None - for c in classes: - if isinstance(c, type): - newobj = c('item %d' % count) - count += 1 - else: - newobj = c - if obj is not None: - obj.next = newobj - else: - t = newobj - obj = newobj - - # save to DB - sess.add(t) - sess.flush() - - # string version of the saved list - assertlist = [] - node = t - while (node): - assertlist.append(node) - n = node.next - if n is not None: - assert n.prev is node - node = n - original = repr(assertlist) - - - # clear and query forwards - sess.expunge_all() - node = sess.query(Table1).filter(Table1.id==t.id).first() - assertlist = [] - while (node): - assertlist.append(node) - n = node.next - if n is not None: - assert n.prev is node - node = n - forwards = repr(assertlist) - - # clear and query backwards - sess.expunge_all() - node = sess.query(Table1).filter(Table1.id==obj.id).first() - assertlist = [] - while (node): - assertlist.insert(0, node) - n = node.prev - if n is not None: - assert n.next is node - node = n - backwards = repr(assertlist) - - # everything should match ! 
- assert original == forwards == backwards - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py deleted file mode 100644 index 81f6c82a1..000000000 --- a/test/orm/inheritance/polymorph.py +++ /dev/null @@ -1,304 +0,0 @@ -"""tests basic polymorphic mapper loading/saving, minimal relations""" - -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy.orm import exc as orm_exc -from testlib import _function_named, Column, testing -from orm import _fixtures, _base - -class Person(_fixtures.Base): - pass -class Engineer(Person): - pass -class Manager(Person): - pass -class Boss(Manager): - pass -class Company(_fixtures.Base): - pass - -class PolymorphTest(_base.MappedTest): - def define_tables(self, metadata): - global companies, people, engineers, managers, boss - - companies = Table('companies', metadata, - Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('name', String(50))) - - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('company_id', Integer, ForeignKey('companies.company_id')), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30)), - Column('engineer_name', String(50)), - Column('primary_language', String(50)), - ) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50)) - ) - - boss = Table('boss', metadata, - Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True), - Column('golf_swing', String(30)), - ) - - metadata.create_all() - -class InsertOrderTest(PolymorphTest): - def 
test_insert_order(self): - """test that classes of multiple types mix up mapper inserts - so that insert order of individual tables is maintained""" - person_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - 'person':people.select(people.c.type=='person'), - }, None, 'pjoin') - - person_mapper = mapper(Person, people, with_polymorphic=('*', person_join), polymorphic_on=person_join.c.type, polymorphic_identity='person') - - mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - mapper(Company, companies, properties={ - 'employees': relation(Person, - backref='company', - order_by=person_join.c.person_id) - }) - - session = create_session() - c = Company(name='company1') - c.employees.append(Manager(status='AAB', manager_name='manager1', name='pointy haired boss')) - c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', name='dilbert')) - c.employees.append(Person(status='HHH', name='joesmith')) - c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', name='wally')) - c.employees.append(Manager(status='ABA', manager_name='manager2', name='jsmith')) - session.add(c) - session.flush() - session.expunge_all() - self.assertEquals(session.query(Company).get(c.company_id), c) - -class RelationToSubclassTest(PolymorphTest): - def test_basic(self): - """test a relation to an inheriting mapper where the relation is to a subclass - but the join condition is expressed by the parent table. - - also test that backrefs work in this case. 
- - this test touches upon a lot of the join/foreign key determination code in properties.py - and creates the need for properties.py to search for conditions individually within - the mapper's local table as well as the mapper's 'mapped' table, so that relations - requiring lots of specificity (like self-referential joins) as well as relations requiring - more generalization (like the example here) both come up with proper results.""" - - mapper(Person, people) - - mapper(Engineer, engineers, inherits=Person) - mapper(Manager, managers, inherits=Person) - - mapper(Company, companies, properties={ - 'managers': relation(Manager, backref="company") - }) - - sess = create_session() - - c = Company(name='company1') - c.managers.append(Manager(status='AAB', manager_name='manager1', name='pointy haired boss')) - sess.add(c) - sess.flush() - sess.expunge_all() - - self.assertEquals(sess.query(Company).filter_by(company_id=c.company_id).one(), c) - assert c.managers[0].company is c - -class RoundTripTest(PolymorphTest): - pass - -def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic): - """generates a round trip test. - - include_base - whether or not to include the base 'person' type in the union. - lazy_relation - whether or not the Company relation to People is lazy or eager. 
- redefine_colprop - if we redefine the 'name' column to be 'people_name' on the base Person class - use_literal_join - primary join condition is explicitly specified - """ - def test_roundtrip(self): - if with_polymorphic == 'unions': - if include_base: - person_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - 'person':people.select(people.c.type=='person'), - }, None, 'pjoin') - else: - person_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - }, None, 'pjoin') - - manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ['*', person_join] - manager_with_polymorphic = ['*', manager_join] - elif with_polymorphic == 'joins': - person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss) - manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ['*', person_join] - manager_with_polymorphic = ['*', manager_join] - elif with_polymorphic == 'auto': - person_with_polymorphic = '*' - manager_with_polymorphic = '*' - else: - person_with_polymorphic = None - manager_with_polymorphic = None - - if redefine_colprop: - person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name}) - else: - person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person') - - mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, with_polymorphic=manager_with_polymorphic, polymorphic_identity='manager') - - mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') - - mapper(Company, companies, properties={ - 'employees': relation(Person, lazy=lazy_relation, - cascade="all, delete-orphan", - backref="company", order_by=people.c.person_id - 
) - }) - - if redefine_colprop: - person_attribute_name = 'person_name' - else: - person_attribute_name = 'name' - - employees = [ - Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'}), - Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'}), - ] - if include_base: - employees.append(Person(**{person_attribute_name:'joesmith'})) - employees += [ - Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'}), - Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'}) - ] - - pointy = employees[0] - jsmith = employees[-1] - dilbert = employees[1] - - session = create_session() - c = Company(name='company1') - c.employees = employees - session.add(c) - - session.flush() - session.expunge_all() - - self.assertEquals(session.query(Person).get(dilbert.person_id), dilbert) - session.expunge_all() - - self.assertEquals(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert) - session.expunge_all() - - def go(): - cc = session.query(Company).get(c.company_id) - self.assertEquals(cc.employees, employees) - - if not lazy_relation: - if with_polymorphic != 'none': - self.assert_sql_count(testing.db, go, 1) - else: - self.assert_sql_count(testing.db, go, 5) - - else: - if with_polymorphic != 'none': - self.assert_sql_count(testing.db, go, 2) - else: - self.assert_sql_count(testing.db, go, 6) - - # test selecting from the query, using the base mapped table (people) as the selection criterion. 
- # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" - self.assertEquals( - session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first(), - dilbert - ) - - assert session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first().person_id - - self.assertEquals( - session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first(), - dilbert - ) - - # test selecting from the query, joining against an alias of the base "people" table. test that - # the "palias" alias does *not* get sucked up into the "person_join" conversion. - palias = people.alias("palias") - dilbert = session.query(Person).get(dilbert.person_id) - assert dilbert is session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() - assert dilbert is session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() - assert dilbert is session.query(Person).filter((Engineer.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)).first() - assert dilbert is session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0] - - dilbert.engineer_name = 'hes dibert!' 
- - session.flush() - session.expunge_all() - - def go(): - session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() - self.assert_sql_count(testing.db, go, 1) - session.expunge_all() - dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() - def go(): - # assert that only primary table is queried for already-present-in-session - d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() - self.assert_sql_count(testing.db, go, 1) - - # test standalone orphans - daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'}) - session.add(daboss) - self.assertRaises(orm_exc.FlushError, session.flush) - c = session.query(Company).first() - daboss.company = c - manager_list = [e for e in c.employees if isinstance(e, Manager)] - session.flush() - session.expunge_all() - - self.assertEquals(session.query(Manager).order_by(Manager.person_id).all(), manager_list) - c = session.query(Company).first() - - session.delete(c) - session.flush() - - self.assertEquals(people.count().scalar(), 0) - - test_roundtrip = _function_named( - test_roundtrip, "test_%s%s%s_%s" % ( - (lazy_relation and "lazy" or "eager"), - (include_base and "_inclbase" or ""), - (redefine_colprop and "_redefcol" or ""), - with_polymorphic)) - setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip) - -for lazy_relation in [True, False]: - for redefine_colprop in [True, False]: - for with_polymorphic in ['unions', 'joins', 'auto', 'none']: - if with_polymorphic == 'unions': - for include_base in [True, False]: - generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic) - else: - generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic) - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py deleted file mode 100644 index 
aec162b75..000000000 --- a/test/orm/inheritance/polymorph2.py +++ /dev/null @@ -1,1107 +0,0 @@ -"""this is a test suite consisting mainly of end-user test cases, testing all kinds of painful -inheritance setups for which we maintain compatibility. -""" - -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy import util -from sqlalchemy.orm import * - -from testlib import _function_named, TestBase, AssertsExecutionResults, testing -from orm import _base, _fixtures -from testlib.testing import eq_ - -class AttrSettable(object): - def __init__(self, **kwargs): - [setattr(self, k, v) for k, v in kwargs.iteritems()] - def __repr__(self): - return self.__class__.__name__ + "(%s)" % (hex(id(self))) - - -class RelationTest1(_base.MappedTest): - """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): - global people, managers - - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('manager_id', Integer, ForeignKey('managers.person_id', use_alter=True, name="mpid_fq")), - Column('name', String(50)), - Column('type', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50)) - ) - - def tearDown(self): - people.update(values={people.c.manager_id:None}).execute() - super(RelationTest1, self).tearDown() - - def testparentrefsdescendant(self): - class Person(AttrSettable): - pass - class Manager(Person): - pass - - # note that up until recently (0.4.4), we had to specify "foreign_keys" here - # for this primary join. 
- mapper(Person, people, properties={ - 'manager':relation(Manager, primaryjoin=(people.c.manager_id == - managers.c.person_id), - uselist=False, post_update=True) - }) - mapper(Manager, managers, inherits=Person, - inherit_condition=people.c.person_id==managers.c.person_id) - - self.assertEquals(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)]) - - session = create_session() - p = Person(name='some person') - m = Manager(name='some manager') - p.manager = m - session.add(p) - session.flush() - session.expunge_all() - - p = session.query(Person).get(p.person_id) - m = session.query(Manager).get(m.person_id) - print p, m, p.manager - assert p.manager is m - - def testdescendantrefsparent(self): - class Person(AttrSettable): - pass - class Manager(Person): - pass - - mapper(Person, people) - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, properties={ - 'employee':relation(Person, primaryjoin=(people.c.manager_id == - managers.c.person_id), - foreign_keys=[people.c.manager_id], - uselist=False, post_update=True) - }) - - session = create_session() - p = Person(name='some person') - m = Manager(name='some manager') - m.employee = p - session.add(m) - session.flush() - session.expunge_all() - - p = session.query(Person).get(p.person_id) - m = session.query(Manager).get(m.person_id) - print p, m, m.employee - assert m.employee is p - -class RelationTest2(_base.MappedTest): - """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): - global people, managers, data - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('name', String(50)), - Column('type', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('manager_id', Integer, 
ForeignKey('people.person_id')), - Column('status', String(30)), - ) - - data = Table('data', metadata, - Column('person_id', Integer, ForeignKey('managers.person_id'), primary_key=True), - Column('data', String(30)) - ) - - def testrelationonsubclass_j1_nodata(self): - self.do_test("join1", False) - def testrelationonsubclass_j2_nodata(self): - self.do_test("join2", False) - def testrelationonsubclass_j1_data(self): - self.do_test("join1", True) - def testrelationonsubclass_j2_data(self): - self.do_test("join2", True) - def testrelationonsubclass_j3_nodata(self): - self.do_test("join3", False) - def testrelationonsubclass_j3_data(self): - self.do_test("join3", True) - - def do_test(self, jointype="join1", usedata=False): - class Person(AttrSettable): - pass - class Manager(Person): - pass - - if jointype == "join1": - poly_union = polymorphic_union({ - 'person':people.select(people.c.type=='person'), - 'manager':join(people, managers, people.c.person_id==managers.c.person_id) - }, None) - polymorphic_on=poly_union.c.type - elif jointype == "join2": - poly_union = polymorphic_union({ - 'person':people.select(people.c.type=='person'), - 'manager':managers.join(people, people.c.person_id==managers.c.person_id) - }, None) - polymorphic_on=poly_union.c.type - elif jointype == "join3": - poly_union = None - polymorphic_on = people.c.type - - if usedata: - class Data(object): - def __init__(self, data): - self.data = data - mapper(Data, data) - - mapper(Person, people, with_polymorphic=('*', poly_union), polymorphic_identity='person', polymorphic_on=polymorphic_on) - - if usedata: - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager', - properties={ - 'colleague':relation(Person, primaryjoin=managers.c.manager_id==people.c.person_id, lazy=True, uselist=False), - 'data':relation(Data, uselist=False) - } - ) - else: - mapper(Manager, managers, inherits=Person, 
inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager', - properties={ - 'colleague':relation(Person, primaryjoin=managers.c.manager_id==people.c.person_id, lazy=True, uselist=False) - } - ) - - sess = create_session() - p = Person(name='person1') - m = Manager(name='manager1') - m.colleague = p - if usedata: - m.data = Data('ms data') - sess.add(m) - sess.flush() - - sess.expunge_all() - p = sess.query(Person).get(p.person_id) - m = sess.query(Manager).get(m.person_id) - print p - print m - assert m.colleague is p - if usedata: - assert m.data.data == 'ms data' - -class RelationTest3(_base.MappedTest): - """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): - global people, managers, data - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('colleague_id', Integer, ForeignKey('people.person_id')), - Column('name', String(50)), - Column('type', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30)), - ) - - data = Table('data', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('data', String(30)) - ) - -def generate_test(jointype="join1", usedata=False): - def do_test(self): - class Person(AttrSettable): - pass - class Manager(Person): - pass - - if usedata: - class Data(object): - def __init__(self, data): - self.data = data - - if jointype == "join1": - poly_union = polymorphic_union({ - 'manager':managers.join(people, people.c.person_id==managers.c.person_id), - 'person':people.select(people.c.type=='person') - }, None) - elif jointype =="join2": - poly_union = polymorphic_union({ - 'manager':join(people, managers, people.c.person_id==managers.c.person_id), - 'person':people.select(people.c.type=='person') - }, None) - elif jointype 
== 'join3': - poly_union = people.outerjoin(managers) - elif jointype == "join4": - poly_union=None - - if usedata: - mapper(Data, data) - - if usedata: - mapper(Person, people, with_polymorphic=('*', poly_union), polymorphic_identity='person', polymorphic_on=people.c.type, - properties={ - 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, remote_side=people.c.colleague_id, uselist=True), - 'data':relation(Data, uselist=False) - } - ) - else: - mapper(Person, people, with_polymorphic=('*', poly_union), polymorphic_identity='person', polymorphic_on=people.c.type, - properties={ - 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, - remote_side=people.c.colleague_id, uselist=True) - } - ) - - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager') - - sess = create_session() - p = Person(name='person1') - p2 = Person(name='person2') - p3 = Person(name='person3') - m = Manager(name='manager1') - p.colleagues.append(p2) - m.colleagues.append(p3) - if usedata: - p.data = Data('ps data') - m.data = Data('ms data') - - sess.add(m) - sess.add(p) - sess.flush() - - sess.expunge_all() - p = sess.query(Person).get(p.person_id) - p2 = sess.query(Person).get(p2.person_id) - p3 = sess.query(Person).get(p3.person_id) - m = sess.query(Person).get(m.person_id) - print p, p2, p.colleagues, m.colleagues - assert len(p.colleagues) == 1 - assert p.colleagues == [p2] - assert m.colleagues == [p3] - if usedata: - assert p.data.data == 'ps data' - assert m.data.data == 'ms data' - - do_test = _function_named( - do_test, 'test_relationonbaseclass_%s_%s' % ( - jointype, data and "nodata" or "data")) - return do_test - -for jointype in ["join1", "join2", "join3", "join4"]: - for data in (True, False): - func = generate_test(jointype, data) - setattr(RelationTest3, func.__name__, func) - - -class RelationTest4(_base.MappedTest): - def 
define_tables(self, metadata): - global people, engineers, managers, cars - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), - Column('name', String(50))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('longer_status', String(70))) - - cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), - Column('owner', Integer, ForeignKey('people.person_id'))) - - def testmanytoonepolymorphic(self): - """in this test, the polymorphic union is between two subclasses, but does not include the base table by itself - in the union. however, the primaryjoin condition is going to be against the base table, and its a many-to-one - relationship (unlike the test in polymorph.py) so the column in the base table is explicit. Can the ClauseAdapter - figure out how to alias the primaryjoin to the polymorphic union ?""" - - # class definitions - class Person(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - def __repr__(self): - return "Ordinary person %s" % self.name - class Engineer(Person): - def __repr__(self): - return "Engineer %s, status %s" % (self.name, self.status) - class Manager(Person): - def __repr__(self): - return "Manager %s, status %s" % (self.name, self.longer_status) - class Car(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - def __repr__(self): - return "Car number %d" % self.car_id - - # create a union that represents both types of joins. 
- employee_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - }, "type", 'employee_join') - - person_mapper = mapper(Person, people, with_polymorphic=('*', employee_join), polymorphic_on=employee_join.c.type, polymorphic_identity='person') - engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper)}) - - session = create_session() - - # creating 5 managers named from M1 to E5 - for i in range(1,5): - session.add(Manager(name="M%d" % i,longer_status="YYYYYYYYY")) - # creating 5 engineers named from E1 to E5 - for i in range(1,5): - session.add(Engineer(name="E%d" % i,status="X")) - - session.flush() - - engineer4 = session.query(Engineer).filter(Engineer.name=="E4").first() - manager3 = session.query(Manager).filter(Manager.name=="M3").first() - - car1 = Car(employee=engineer4) - session.add(car1) - car2 = Car(employee=manager3) - session.add(car2) - session.flush() - - session.expunge_all() - - def go(): - testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id) - assert str(testcar.employee) == "Engineer E4, status X" - self.assert_sql_count(testing.db, go, 1) - - print "----------------------------" - car1 = session.query(Car).get(car1.car_id) - print "----------------------------" - usingGet = session.query(person_mapper).get(car1.owner) - print "----------------------------" - usingProperty = car1.employee - print "----------------------------" - - # All print should output the same person (engineer E4) - assert str(engineer4) == "Engineer E4, status X" - print str(usingGet) - assert str(usingGet) == "Engineer E4, status X" - assert str(usingProperty) == "Engineer E4, status X" - - session.expunge_all() - print 
"-----------------------------------------------------------------" - # and now for the lightning round, eager ! - - def go(): - testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id) - assert str(testcar.employee) == "Engineer E4, status X" - self.assert_sql_count(testing.db, go, 1) - - session.expunge_all() - s = session.query(Car) - c = s.join("employee").filter(Person.name=="E4")[0] - assert c.car_id==car1.car_id - -class RelationTest5(_base.MappedTest): - def define_tables(self, metadata): - global people, engineers, managers, cars - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), - Column('name', String(50)), - Column('type', String(50))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30))) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('longer_status', String(70))) - - cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), - Column('owner', Integer, ForeignKey('people.person_id'))) - - def testeagerempty(self): - """an easy one...test parent object with child relation to an inheriting mapper, using eager loads, - works when there are no child objects present""" - class Person(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - def __repr__(self): - return "Ordinary person %s" % self.name - class Engineer(Person): - def __repr__(self): - return "Engineer %s, status %s" % (self.name, self.status) - class Manager(Person): - def __repr__(self): - return "Manager %s, status %s" % (self.name, self.longer_status) - class Car(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - def __repr__(self): - return "Car number %d" % self.car_id - - person_mapper = mapper(Person, people, 
polymorphic_on=people.c.type, polymorphic_identity='person') - engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - car_mapper = mapper(Car, cars, properties= {'manager':relation(manager_mapper, lazy=False)}) - - sess = create_session() - car1 = Car() - car2 = Car() - car2.manager = Manager() - sess.add(car1) - sess.add(car2) - sess.flush() - sess.expunge_all() - - carlist = sess.query(Car).all() - assert carlist[0].manager is None - assert carlist[1].manager.person_id == car2.manager.person_id - -class RelationTest6(_base.MappedTest): - """test self-referential relationships on a single joined-table inheritance mapper""" - def define_tables(self, metadata): - global people, managers, data - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('name', String(50)), - ) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('colleague_id', Integer, ForeignKey('managers.person_id')), - Column('status', String(30)), - ) - - def testbasic(self): - class Person(AttrSettable): - pass - class Manager(Person): - pass - - mapper(Person, people) - # relationship is from people.join(managers) -> people.join(managers). 
self referential logic - # needs to be used to figure out the lazy clause, meaning create_lazy_clause must go from parent.mapped_table - # to parent.mapped_table - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, - properties={ - 'colleague':relation(Manager, primaryjoin=managers.c.colleague_id==managers.c.person_id, lazy=True, uselist=False) - } - ) - - sess = create_session() - m = Manager(name='manager1') - m2 =Manager(name='manager2') - m.colleague = m2 - sess.add(m) - sess.flush() - - sess.expunge_all() - m = sess.query(Manager).get(m.person_id) - m2 = sess.query(Manager).get(m2.person_id) - assert m.colleague is m2 - -class RelationTest7(_base.MappedTest): - def define_tables(self, metadata): - global people, engineers, managers, cars, offroad_cars - cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), - Column('name', String(30))) - - offroad_cars = Table('offroad_cars', metadata, - Column('car_id',Integer, ForeignKey('cars.car_id'),nullable=False,primary_key=True)) - - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), - Column('car_id', Integer, ForeignKey('cars.car_id'), nullable=False), - Column('name', String(50))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('field', String(30))) - - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('category', String(70))) - - def test_manytoone_lazyload(self): - """test that lazy load clause to a polymorphic child mapper generates correctly [ticket:493]""" - class PersistentObject(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - - class Status(PersistentObject): - def __repr__(self): - return "Status %s" % self.name - - class Person(PersistentObject): - def __repr__(self): - 
return "Ordinary person %s" % self.name - - class Engineer(Person): - def __repr__(self): - return "Engineer %s, field %s" % (self.name, self.field) - - class Manager(Person): - def __repr__(self): - return "Manager %s, category %s" % (self.name, self.category) - - class Car(PersistentObject): - def __repr__(self): - return "Car number %d, name %s" % (self.car_id, self.name) - - class Offraod_Car(Car): - def __repr__(self): - return "Offroad Car number %d, name %s" % (self.car_id,self.name) - - employee_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - }, "type", 'employee_join') - - car_join = polymorphic_union( - { - 'car' : cars.outerjoin(offroad_cars).select(offroad_cars.c.car_id == None, fold_equivalents=True), - 'offroad' : cars.join(offroad_cars) - }, "type", 'car_join') - - car_mapper = mapper(Car, cars, - with_polymorphic=('*', car_join) ,polymorphic_on=car_join.c.type, - polymorphic_identity='car', - ) - offroad_car_mapper = mapper(Offraod_Car, offroad_cars, inherits=car_mapper, polymorphic_identity='offroad') - person_mapper = mapper(Person, people, - with_polymorphic=('*', employee_join), polymorphic_on=employee_join.c.type, - polymorphic_identity='person', - properties={ - 'car':relation(car_mapper) - }) - engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - - session = create_session() - basic_car=Car(name="basic") - offroad_car=Offraod_Car(name="offroad") - - for i in range(1,4): - if i%2: - car=Car() - else: - car=Offraod_Car() - session.add(Manager(name="M%d" % i,category="YYYYYYYYY",car=car)) - session.add(Engineer(name="E%d" % i,field="X",car=car)) - session.flush() - session.expunge_all() - - r = session.query(Person).all() - for p in r: - assert p.car_id == p.car.car_id - -class RelationTest8(_base.MappedTest): - def define_tables(self, 
metadata): - global taggable, users - taggable = Table('taggable', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(30)), - Column('owner_id', Integer, ForeignKey('taggable.id')), - ) - users = Table ('users', metadata, - Column('id', Integer, ForeignKey('taggable.id'), primary_key=True), - Column('data', String(50)), - ) - - def test_selfref_onjoined(self): - class Taggable(_base.ComparableEntity): - pass - - class User(Taggable): - pass - - mapper( Taggable, taggable, polymorphic_on=taggable.c.type, polymorphic_identity='taggable', properties = { - 'owner' : relation (User, - primaryjoin=taggable.c.owner_id ==taggable.c.id, - remote_side=taggable.c.id - ), - }) - - - mapper(User, users, inherits=Taggable, polymorphic_identity='user', - inherit_condition=users.c.id == taggable.c.id, - ) - - - u1 = User(data='u1') - t1 = Taggable(owner=u1) - sess = create_session() - sess.add(t1) - sess.flush() - - sess.expunge_all() - eq_( - sess.query(Taggable).order_by(Taggable.id).all(), - [User(data='u1'), Taggable(owner=User(data='u1'))] - ) - -class GenerativeTest(TestBase, AssertsExecutionResults): - def setUpAll(self): - # cars---owned by--- people (abstract) --- has a --- status - # | ^ ^ | - # | | | | - # | engineers managers | - # | | - # +--------------------------------------- has a ------+ - - global metadata, status, people, engineers, managers, cars - metadata = MetaData(testing.db) - # table definitions - status = Table('status', metadata, - Column('status_id', Integer, primary_key=True), - Column('name', String(20))) - - people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), - Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), - Column('name', String(50))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('field', String(30))) - - managers = Table('managers', metadata, - Column('person_id', 
Integer, ForeignKey('people.person_id'), primary_key=True), - Column('category', String(70))) - - cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), - Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), - Column('owner', Integer, ForeignKey('people.person_id'), nullable=False)) - - metadata.create_all() - - def tearDownAll(self): - metadata.drop_all() - def tearDown(self): - clear_mappers() - for t in reversed(metadata.sorted_tables): - t.delete().execute() - - def testjointo(self): - # class definitions - class PersistentObject(object): - def __init__(self, **kwargs): - for key, value in kwargs.iteritems(): - setattr(self, key, value) - class Status(PersistentObject): - def __repr__(self): - return "Status %s" % self.name - class Person(PersistentObject): - def __repr__(self): - return "Ordinary person %s" % self.name - class Engineer(Person): - def __repr__(self): - return "Engineer %s, field %s, status %s" % (self.name, self.field, self.status) - class Manager(Person): - def __repr__(self): - return "Manager %s, category %s, status %s" % (self.name, self.category, self.status) - class Car(PersistentObject): - def __repr__(self): - return "Car number %d" % self.car_id - - # create a union that represents both types of joins. 
- employee_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - }, "type", 'employee_join') - - status_mapper = mapper(Status, status) - person_mapper = mapper(Person, people, - with_polymorphic=('*', employee_join), polymorphic_on=employee_join.c.type, - polymorphic_identity='person', properties={'status':relation(status_mapper)}) - engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper), 'status':relation(status_mapper)}) - - session = create_session() - - active = Status(name="active") - dead = Status(name="dead") - - session.add(active) - session.add(dead) - session.flush() - - # TODO: we haven't created assertions for all the data combinations created here - - # creating 5 managers named from M1 to M5 and 5 engineers named from E1 to E5 - # M4, M5, E4 and E5 are dead - for i in range(1,5): - if i<4: - st=active - else: - st=dead - session.add(Manager(name="M%d" % i,category="YYYYYYYYY",status=st)) - session.add(Engineer(name="E%d" % i,field="X",status=st)) - - session.flush() - - # get E4 - engineer4 = session.query(engineer_mapper).filter_by(name="E4").one() - - # create 2 cars for E4, one active and one dead - car1 = Car(employee=engineer4,status=active) - car2 = Car(employee=engineer4,status=dead) - session.add(car1) - session.add(car2) - session.flush() - - # this particular adapt used to cause a recursion overflow; - # added here for testing - e = exists([Car.owner], Car.owner==employee_join.c.person_id) - Query(Person)._adapt_clause(employee_join, False, False) - - r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active") - assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status 
active]" - r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active")).order_by(Person.name) - assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]" - - r = session.query(Person).filter(exists([1], Car.owner==Person.person_id)) - assert str(list(r)) == "[Engineer E4, field X, status Status dead]" - -class MultiLevelTest(_base.MappedTest): - def define_tables(self, metadata): - global table_Employee, table_Engineer, table_Manager - table_Employee = Table( 'Employee', metadata, - Column( 'name', type_= String(100), ), - Column( 'id', primary_key= True, type_= Integer, ), - Column( 'atype', type_= String(100), ), - ) - - table_Engineer = Table( 'Engineer', metadata, - Column( 'machine', type_= String(100), ), - Column( 'id', Integer, ForeignKey( 'Employee.id', ), primary_key= True, ), - ) - - table_Manager = Table( 'Manager', metadata, - Column( 'duties', type_= String(100), ), - Column( 'id', Integer, ForeignKey( 'Engineer.id', ), primary_key= True, ), - ) - def test_threelevels(self): - class Employee( object): - def set( me, **kargs): - for k,v in kargs.iteritems(): setattr( me, k, v) - return me - def __str__(me): return str(me.__class__.__name__)+':'+str(me.name) - __repr__ = __str__ - class Engineer( Employee): pass - class Manager( Engineer): pass - - pu_Employee = polymorphic_union( { - 'Manager': table_Employee.join( table_Engineer).join( table_Manager), - 'Engineer': select([table_Employee, table_Engineer.c.machine], table_Employee.c.atype == 'Engineer', from_obj=[table_Employee.join(table_Engineer)]), - 'Employee': table_Employee.select( table_Employee.c.atype == 'Employee'), - }, None, 'pu_employee', ) - -# pu_Employee = polymorphic_union( { -# 'Manager': table_Employee.join( table_Engineer).join( table_Manager), -# 'Engineer': table_Employee.join(table_Engineer).select(table_Employee.c.atype == 'Engineer'), -# 'Employee': 
table_Employee.select( table_Employee.c.atype == 'Employee'), -# }, None, 'pu_employee', ) - - mapper_Employee = mapper( Employee, table_Employee, - polymorphic_identity= 'Employee', - polymorphic_on= pu_Employee.c.atype, - with_polymorphic=('*', pu_Employee), - ) - - pu_Engineer = polymorphic_union( { - 'Manager': table_Employee.join( table_Engineer).join( table_Manager), - 'Engineer': select([table_Employee, table_Engineer.c.machine], table_Employee.c.atype == 'Engineer', from_obj=[table_Employee.join(table_Engineer)]), - }, None, 'pu_engineer', ) - mapper_Engineer = mapper( Engineer, table_Engineer, - inherit_condition= table_Engineer.c.id == table_Employee.c.id, - inherits= mapper_Employee, - polymorphic_identity= 'Engineer', - polymorphic_on= pu_Engineer.c.atype, - with_polymorphic=('*', pu_Engineer), - ) - - mapper_Manager = mapper( Manager, table_Manager, - inherit_condition= table_Manager.c.id == table_Engineer.c.id, - inherits= mapper_Engineer, - polymorphic_identity= 'Manager', - ) - - a = Employee().set( name= 'one') - b = Engineer().set( egn= 'two', machine= 'any') - c = Manager().set( name= 'head', machine= 'fast', duties= 'many') - - session = create_session() - session.add(a) - session.add(b) - session.add(c) - session.flush() - assert set(session.query(Employee).all()) == set([a,b,c]) - assert set(session.query( Engineer).all()) == set([b,c]) - assert session.query( Manager).all() == [c] - -class ManyToManyPolyTest(_base.MappedTest): - def define_tables(self, metadata): - global base_item_table, item_table, base_item_collection_table, collection_table - base_item_table = Table( - 'base_item', metadata, - Column('id', Integer, primary_key=True), - Column('child_name', String(255), default=None)) - - item_table = Table( - 'item', metadata, - Column('id', Integer, ForeignKey('base_item.id'), primary_key=True), - Column('dummy', Integer, default=0)) # Dummy column to avoid weird insert problems - - base_item_collection_table = Table( - 
'base_item_collection', metadata, - Column('item_id', Integer, ForeignKey('base_item.id')), - Column('collection_id', Integer, ForeignKey('collection.id'))) - - collection_table = Table( - 'collection', metadata, - Column('id', Integer, primary_key=True), - Column('name', Unicode(255))) - - def test_pjoin_compile(self): - """test that remote_side columns in the secondary join table arent attempted to be - matched to the target polymorphic selectable""" - class BaseItem(object): pass - class Item(BaseItem): pass - class Collection(object): pass - item_join = polymorphic_union( { - 'BaseItem':base_item_table.select(base_item_table.c.child_name=='BaseItem'), - 'Item':base_item_table.join(item_table), - }, None, 'item_join') - - mapper( - BaseItem, base_item_table, - with_polymorphic=('*', item_join), - polymorphic_on=base_item_table.c.child_name, - polymorphic_identity='BaseItem', - properties=dict(collections=relation(Collection, secondary=base_item_collection_table, backref="items"))) - - mapper( - Item, item_table, - inherits=BaseItem, - polymorphic_identity='Item') - - mapper(Collection, collection_table) - - class_mapper(BaseItem) - -class CustomPKTest(_base.MappedTest): - def define_tables(self, metadata): - global t1, t2 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('type', String(30), nullable=False), - Column('data', String(30))) - # note that the primary key column in t2 is named differently - t2 = Table('t2', metadata, - Column('t2id', Integer, ForeignKey('t1.id'), primary_key=True), - Column('t2data', String(30))) - - def test_custompk(self): - """test that the primary_key attribute is propagated to the polymorphic mapper""" - - class T1(object):pass - class T2(T1):pass - - # create a polymorphic union with the select against the base table first. - # with the join being second, the alias of the union will - # pick up two "primary key" columns. 
technically the alias should have a - # 2-col pk in any case but the leading select has a NULL for the "t2id" column - d = util.OrderedDict() - d['t1'] = t1.select(t1.c.type=='t1') - d['t2'] = t1.join(t2) - pjoin = polymorphic_union(d, None, 'pjoin') - - mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', with_polymorphic=('*', pjoin), primary_key=[pjoin.c.id]) - mapper(T2, t2, inherits=T1, polymorphic_identity='t2') - print [str(c) for c in class_mapper(T1).primary_key] - ot1 = T1() - ot2 = T2() - sess = create_session() - sess.add(ot1) - sess.add(ot2) - sess.flush() - sess.expunge_all() - - # query using get(), using only one value. this requires the select_table mapper - # has the same single-col primary key. - assert sess.query(T1).get(ot1.id).id == ot1.id - - ot1 = sess.query(T1).get(ot1.id) - ot1.data = 'hi' - sess.flush() - - def test_pk_collapses(self): - """test that a composite primary key attribute formed by a join is "collapsed" into its - minimal columns""" - - class T1(object):pass - class T2(T1):pass - - # create a polymorphic union with the select against the base table first. - # with the join being second, the alias of the union will - # pick up two "primary key" columns. technically the alias should have a - # 2-col pk in any case but the leading select has a NULL for the "t2id" column - d = util.OrderedDict() - d['t1'] = t1.select(t1.c.type=='t1') - d['t2'] = t1.join(t2) - pjoin = polymorphic_union(d, None, 'pjoin') - - mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', with_polymorphic=('*', pjoin)) - mapper(T2, t2, inherits=T1, polymorphic_identity='t2') - assert len(class_mapper(T1).primary_key) == 1 - - print [str(c) for c in class_mapper(T1).primary_key] - ot1 = T1() - ot2 = T2() - sess = create_session() - sess.add(ot1) - sess.add(ot2) - sess.flush() - sess.expunge_all() - - # query using get(), using only one value. this requires the select_table mapper - # has the same single-col primary key. 
- assert sess.query(T1).get(ot1.id).id == ot1.id - - ot1 = sess.query(T1).get(ot1.id) - ot1.data = 'hi' - sess.flush() - -class InheritingEagerTest(_base.MappedTest): - def define_tables(self, metadata): - global people, employees, tags, peopleTags - - people = Table('people', metadata, - Column('id', Integer, primary_key=True), - Column('_type', String(30), nullable=False), - ) - - - employees = Table('employees', metadata, - Column('id', Integer, ForeignKey('people.id'),primary_key=True), - ) - - tags = Table('tags', metadata, - Column('id', Integer, primary_key=True), - Column('label', String(50), nullable=False), - ) - - peopleTags = Table('peopleTags', metadata, - Column('person_id', Integer,ForeignKey('people.id')), - Column('tag_id', Integer,ForeignKey('tags.id')), - ) - - def test_basic(self): - """test that Query uses the full set of mapper._eager_loaders when generating SQL""" - - class Person(_fixtures.Base): - pass - - class Employee(Person): - def __init__(self, name='bob'): - self.name = name - - class Tag(_fixtures.Base): - def __init__(self, label): - self.label = label - - mapper(Person, people, polymorphic_on=people.c._type,polymorphic_identity='person', properties={ - 'tags': relation(Tag, secondary=peopleTags,backref='people', lazy=False) - }) - mapper(Employee, employees, inherits=Person,polymorphic_identity='employee') - mapper(Tag, tags) - - session = create_session() - - bob = Employee() - session.add(bob) - - tag = Tag('crazy') - bob.tags.append(tag) - - tag = Tag('funny') - bob.tags.append(tag) - session.flush() - - session.expunge_all() - # query from Employee with limit, query needs to apply eager limiting subquery - instance = session.query(Employee).filter_by(id=1).limit(1).first() - assert len(instance.tags) == 2 - -class MissingPolymorphicOnTest(_base.MappedTest): - def define_tables(self, metadata): - global tablea, tableb, tablec, tabled - tablea = Table('tablea', metadata, - Column('id', Integer, primary_key=True), - 
Column('adata', String(50)), - ) - tableb = Table('tableb', metadata, - Column('id', Integer, primary_key=True), - Column('aid', Integer, ForeignKey('tablea.id')), - Column('data', String(50)), - ) - tablec = Table('tablec', metadata, - Column('id', Integer, ForeignKey('tablea.id'), primary_key=True), - Column('cdata', String(50)), - ) - tabled = Table('tabled', metadata, - Column('id', Integer, ForeignKey('tablec.id'), primary_key=True), - Column('ddata', String(50)), - ) - - def test_polyon_col_setsup(self): - class A(_fixtures.Base): - pass - class B(_fixtures.Base): - pass - class C(A): - pass - class D(C): - pass - - poly_select = select([tablea, tableb.c.data.label('discriminator')], from_obj=tablea.join(tableb)).alias('poly') - - mapper(B, tableb) - mapper(A, tablea, with_polymorphic=('*', poly_select), polymorphic_on=poly_select.c.discriminator, properties={ - 'b':relation(B, uselist=False) - }) - mapper(C, tablec, inherits=A,polymorphic_identity='c') - mapper(D, tabled, inherits=C, polymorphic_identity='d') - - c = C(cdata='c1', adata='a1', b=B(data='c')) - d = D(cdata='c2', adata='a2', ddata='d2', b=B(data='d')) - sess = create_session() - sess.add(c) - sess.add(d) - sess.flush() - sess.expunge_all() - self.assertEquals(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')]) - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/productspec.py b/test/orm/inheritance/productspec.py deleted file mode 100644 index b6a8c5146..000000000 --- a/test/orm/inheritance/productspec.py +++ /dev/null @@ -1,320 +0,0 @@ -import testenv; testenv.configure_for_tests() -from datetime import datetime -from sqlalchemy import * -from sqlalchemy.orm import * - - -from testlib import testing -from orm import _base - - -class InheritTest(_base.MappedTest): - """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" - def define_tables(self, metadata): - global 
products_table, specification_table, documents_table - global Product, Detail, Assembly, SpecLine, Document, RasterDocument - - products_table = Table('products', metadata, - Column('product_id', Integer, primary_key=True), - Column('product_type', String(128)), - Column('name', String(128)), - Column('mark', String(128)), - ) - - specification_table = Table('specification', metadata, - Column('spec_line_id', Integer, primary_key=True), - Column('master_id', Integer, ForeignKey("products.product_id"), - nullable=True), - Column('slave_id', Integer, ForeignKey("products.product_id"), - nullable=True), - Column('quantity', Float, default=1.), - ) - - documents_table = Table('documents', metadata, - Column('document_id', Integer, primary_key=True), - Column('document_type', String(128)), - Column('product_id', Integer, ForeignKey('products.product_id')), - Column('create_date', DateTime, default=lambda:datetime.now()), - Column('last_updated', DateTime, default=lambda:datetime.now(), - onupdate=lambda:datetime.now()), - Column('name', String(128)), - Column('data', Binary), - Column('size', Integer, default=0), - ) - - class Product(object): - def __init__(self, name, mark=''): - self.name = name - self.mark = mark - def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, self.name) - - class Detail(Product): - def __init__(self, name): - self.name = name - - class Assembly(Product): - def __repr__(self): - return Product.__repr__(self) + " " + " ".join([x + "=" + repr(getattr(self, x, None)) for x in ['specification', 'documents']]) - - class SpecLine(object): - def __init__(self, master=None, slave=None, quantity=1): - self.master = master - self.slave = slave - self.quantity = quantity - - def __repr__(self): - return '<%s %.01f %s>' % ( - self.__class__.__name__, - self.quantity or 0., - repr(self.slave) - ) - - class Document(object): - def __init__(self, name, data=None): - self.name = name - self.data = data - def __repr__(self): - return '<%s %s>' % 
(self.__class__.__name__, self.name) - - class RasterDocument(Document): - pass - - def testone(self): - product_mapper = mapper(Product, products_table, - polymorphic_on=products_table.c.product_type, - polymorphic_identity='product') - - detail_mapper = mapper(Detail, inherits=product_mapper, - polymorphic_identity='detail') - - assembly_mapper = mapper(Assembly, inherits=product_mapper, - polymorphic_identity='assembly') - - specification_mapper = mapper(SpecLine, specification_table, - properties=dict( - master=relation(Assembly, - foreign_keys=[specification_table.c.master_id], - primaryjoin=specification_table.c.master_id==products_table.c.product_id, - lazy=True, backref=backref('specification'), - uselist=False), - slave=relation(Product, - foreign_keys=[specification_table.c.slave_id], - primaryjoin=specification_table.c.slave_id==products_table.c.product_id, - lazy=True, uselist=False), - quantity=specification_table.c.quantity, - ) - ) - - session = create_session( ) - - a1 = Assembly(name='a1') - - p1 = Product(name='p1') - a1.specification.append(SpecLine(slave=p1)) - - d1 = Detail(name='d1') - a1.specification.append(SpecLine(slave=d1)) - - session.add(a1) - orig = repr(a1) - session.flush() - session.expunge_all() - - a1 = session.query(Product).filter_by(name='a1').one() - new = repr(a1) - print orig - print new - assert orig == new == ' specification=[>, >] documents=None' - - def testtwo(self): - product_mapper = mapper(Product, products_table, - polymorphic_on=products_table.c.product_type, - polymorphic_identity='product') - - detail_mapper = mapper(Detail, inherits=product_mapper, - polymorphic_identity='detail') - - specification_mapper = mapper(SpecLine, specification_table, - properties=dict( - slave=relation(Product, - foreign_keys=[specification_table.c.slave_id], - primaryjoin=specification_table.c.slave_id==products_table.c.product_id, - lazy=True, uselist=False), - ) - ) - - session = create_session( ) - - s = 
SpecLine(slave=Product(name='p1')) - s2 = SpecLine(slave=Detail(name='d1')) - session.add(s) - session.add(s2) - orig = repr([s, s2]) - session.flush() - session.expunge_all() - new = repr(session.query(SpecLine).all()) - print orig - print new - assert orig == new == '[>, >]' - - def testthree(self): - product_mapper = mapper(Product, products_table, - polymorphic_on=products_table.c.product_type, - polymorphic_identity='product') - detail_mapper = mapper(Detail, inherits=product_mapper, - polymorphic_identity='detail') - assembly_mapper = mapper(Assembly, inherits=product_mapper, - polymorphic_identity='assembly') - - specification_mapper = mapper(SpecLine, specification_table, - properties=dict( - master=relation(Assembly, lazy=False, uselist=False, - foreign_keys=[specification_table.c.master_id], - primaryjoin=specification_table.c.master_id==products_table.c.product_id, - backref=backref('specification', cascade="all, delete-orphan"), - ), - slave=relation(Product, lazy=False, uselist=False, - foreign_keys=[specification_table.c.slave_id], - primaryjoin=specification_table.c.slave_id==products_table.c.product_id, - ), - quantity=specification_table.c.quantity, - ) - ) - - document_mapper = mapper(Document, documents_table, - polymorphic_on=documents_table.c.document_type, - polymorphic_identity='document', - properties=dict( - name=documents_table.c.name, - data=deferred(documents_table.c.data), - product=relation(Product, lazy=True, backref=backref('documents', cascade="all, delete-orphan")), - ), - ) - raster_document_mapper = mapper(RasterDocument, inherits=document_mapper, - polymorphic_identity='raster_document') - - session = create_session() - - a1 = Assembly(name='a1') - a1.specification.append(SpecLine(slave=Detail(name='d1'))) - a1.documents.append(Document('doc1')) - a1.documents.append(RasterDocument('doc2')) - session.add(a1) - orig = repr(a1) - session.flush() - session.expunge_all() - - a1 = session.query(Product).filter_by(name='a1').one() - 
new = repr(a1) - print orig - print new - assert orig == new == ' specification=[>] documents=[, ]' - - def testfour(self): - """this tests the RasterDocument being attached to the Assembly, but *not* the Document. this means only - a "sub-class" task, i.e. corresponding to an inheriting mapper but not the base mapper, is created. """ - product_mapper = mapper(Product, products_table, - polymorphic_on=products_table.c.product_type, - polymorphic_identity='product') - detail_mapper = mapper(Detail, inherits=product_mapper, - polymorphic_identity='detail') - assembly_mapper = mapper(Assembly, inherits=product_mapper, - polymorphic_identity='assembly') - - document_mapper = mapper(Document, documents_table, - polymorphic_on=documents_table.c.document_type, - polymorphic_identity='document', - properties=dict( - name=documents_table.c.name, - data=deferred(documents_table.c.data), - product=relation(Product, lazy=True, backref=backref('documents', cascade="all, delete-orphan")), - ), - ) - raster_document_mapper = mapper(RasterDocument, inherits=document_mapper, - polymorphic_identity='raster_document') - - session = create_session( ) - - a1 = Assembly(name='a1') - a1.documents.append(RasterDocument('doc2')) - session.add(a1) - orig = repr(a1) - session.flush() - session.expunge_all() - - a1 = session.query(Product).filter_by(name='a1').one() - new = repr(a1) - print orig - print new - assert orig == new == ' specification=None documents=[]' - - del a1.documents[0] - session.flush() - session.expunge_all() - - a1 = session.query(Product).filter_by(name='a1').one() - assert len(session.query(Document).all()) == 0 - - def testfive(self): - """tests the late compilation of mappers""" - - specification_mapper = mapper(SpecLine, specification_table, - properties=dict( - master=relation(Assembly, lazy=False, uselist=False, - foreign_keys=[specification_table.c.master_id], - primaryjoin=specification_table.c.master_id==products_table.c.product_id, - 
backref=backref('specification'), - ), - slave=relation(Product, lazy=False, uselist=False, - foreign_keys=[specification_table.c.slave_id], - primaryjoin=specification_table.c.slave_id==products_table.c.product_id, - ), - quantity=specification_table.c.quantity, - ) - ) - - product_mapper = mapper(Product, products_table, - polymorphic_on=products_table.c.product_type, - polymorphic_identity='product', properties={ - 'documents' : relation(Document, lazy=True, - backref='product', cascade='all, delete-orphan'), - }) - - detail_mapper = mapper(Detail, inherits=Product, - polymorphic_identity='detail') - - document_mapper = mapper(Document, documents_table, - polymorphic_on=documents_table.c.document_type, - polymorphic_identity='document', - properties=dict( - name=documents_table.c.name, - data=deferred(documents_table.c.data), - ), - ) - - raster_document_mapper = mapper(RasterDocument, inherits=Document, - polymorphic_identity='raster_document') - - assembly_mapper = mapper(Assembly, inherits=Product, - polymorphic_identity='assembly') - - session = create_session() - - a1 = Assembly(name='a1') - a1.specification.append(SpecLine(slave=Detail(name='d1'))) - a1.documents.append(Document('doc1')) - a1.documents.append(RasterDocument('doc2')) - session.add(a1) - orig = repr(a1) - session.flush() - session.expunge_all() - - a1 = session.query(Product).filter_by(name='a1').one() - new = repr(a1) - print orig - print new - assert orig == new == ' specification=[>] documents=[, ]' - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py deleted file mode 100644 index 58d205455..000000000 --- a/test/orm/inheritance/query.py +++ /dev/null @@ -1,1105 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy import exc as sa_exc -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.engine import default - -from testlib 
import AssertsCompiledSQL, testing -from orm import _base, _fixtures -from testlib.testing import eq_ - -class Company(_fixtures.Base): - pass - -class Person(_fixtures.Base): - pass -class Engineer(Person): - pass -class Manager(Person): - pass -class Boss(Manager): - pass - -class Machine(_fixtures.Base): - pass - -class Paperwork(_fixtures.Base): - pass - -def make_test(select_type): - class PolymorphicQueryTest(_base.MappedTest, AssertsCompiledSQL): - run_inserts = 'once' - run_setup_mappers = 'once' - run_deletes = None - - def define_tables(self, metadata): - global companies, people, engineers, managers, boss, paperwork, machines - - companies = Table('companies', metadata, - Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True), - Column('name', String(50))) - - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('company_id', Integer, ForeignKey('companies.company_id')), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30)), - Column('engineer_name', String(50)), - Column('primary_language', String(50)), - ) - - machines = Table('machines', metadata, - Column('machine_id', Integer, primary_key=True), - Column('name', String(50)), - Column('engineer_id', Integer, ForeignKey('engineers.person_id'))) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('status', String(30)), - Column('manager_name', String(50)) - ) - - boss = Table('boss', metadata, - Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True), - Column('golf_swing', String(30)), - ) - - paperwork = Table('paperwork', metadata, - Column('paperwork_id', Integer, primary_key=True), - Column('description', 
String(50)), - Column('person_id', Integer, ForeignKey('people.person_id'))) - - clear_mappers() - - mapper(Company, companies, properties={ - 'employees':relation(Person, order_by=people.c.person_id) - }) - - mapper(Machine, machines) - - if select_type == '': - person_join = manager_join = None - person_with_polymorphic = None - manager_with_polymorphic = None - elif select_type == 'Polymorphic': - person_join = manager_join = None - person_with_polymorphic = '*' - manager_with_polymorphic = '*' - elif select_type == 'Unions': - person_join = polymorphic_union( - { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - }, None, 'pjoin') - - manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ([Person, Manager, Engineer], person_join) - manager_with_polymorphic = ('*', manager_join) - elif select_type == 'AliasedJoins': - person_join = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin') - manager_join = people.join(managers).outerjoin(boss).select(use_labels=True).alias('mjoin') - person_with_polymorphic = ([Person, Manager, Engineer], person_join) - manager_with_polymorphic = ('*', manager_join) - elif select_type == 'Joins': - person_join = people.outerjoin(engineers).outerjoin(managers) - manager_join = people.join(managers).outerjoin(boss) - person_with_polymorphic = ([Person, Manager, Engineer], person_join) - manager_with_polymorphic = ('*', manager_join) - - - # testing a order_by here as well; the surrogate mapper has to adapt it - mapper(Person, people, - with_polymorphic=person_with_polymorphic, - polymorphic_on=people.c.type, polymorphic_identity='person', order_by=people.c.person_id, - properties={ - 'paperwork':relation(Paperwork, order_by=paperwork.c.paperwork_id) - }) - mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer', properties={ - 'machines':relation(Machine, order_by=machines.c.machine_id) - }) - mapper(Manager, managers, 
with_polymorphic=manager_with_polymorphic, - inherits=Person, polymorphic_identity='manager') - mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') - mapper(Paperwork, paperwork) - - - def insert_data(self): - global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2 - - c1 = Company(name="MegaCorp, Inc.") - c2 = Company(name="Elbonia, Inc.") - e1 = Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", paperwork=[ - Paperwork(description="tps report #1"), - Paperwork(description="tps report #2") - ], machines=[ - Machine(name='IBM ThinkPad'), - Machine(name='IPhone'), - ]) - e2 = Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer", paperwork=[ - Paperwork(description="tps report #3"), - Paperwork(description="tps report #4") - ], machines=[ - Machine(name="Commodore 64") - ]) - b1 = Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss", paperwork=[ - Paperwork(description="review #1"), - ]) - m1 = Manager(name="dogbert", manager_name="dogbert", status="regular manager", paperwork=[ - Paperwork(description="review #2"), - Paperwork(description="review #3") - ]) - c1.employees = [e1, e2, b1, m1] - - e3 = Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer", paperwork=[ - Paperwork(description='elbonian missive #3') - ], machines=[ - Machine(name="Commodore 64"), - Machine(name="IBM 3270") - ]) - - c2.employees = [e3] - sess = create_session() - sess.add(c1) - sess.add(c2) - sess.flush() - sess.expunge_all() - - all_employees = [e1, e2, b1, m1, e3] - c1_employees = [e1, e2, b1, m1] - c2_employees = [e3] - - def test_loads_at_once(self): - """test that all objects load from the full query, when with_polymorphic is used""" - - sess = create_session() - def go(): - self.assertEquals(sess.query(Person).all(), all_employees) - self.assert_sql_count(testing.db, go, 
{'':14, 'Polymorphic':9}.get(select_type, 10)) - - def test_primary_eager_aliasing(self): - sess = create_session() - - def go(): - self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) - self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4)) - - sess = create_session() - - # assert the JOINs dont over JOIN - assert sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().subquery().count().scalar() == 2 - - def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) - self.assert_sql_count(testing.db, go, 3) - - - def test_get(self): - sess = create_session() - - # for all mappers, ensure the primary key has been calculated as just the "person_id" - # column - self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) - self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) - self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore")) - - def test_multi_join(self): - sess = create_session() - - e = aliased(Person) - c = aliased(Company) - - q = sess.query(Company, Person, c, e).join((Person, Company.employees)).join((e, c.employees)).\ - filter(Person.name=='dilbert').filter(e.name=='wally') - - self.assertEquals(q.count(), 1) - self.assertEquals(q.all(), [ - ( - Company(company_id=1,name=u'MegaCorp, Inc.'), - Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), - Company(company_id=1,name=u'MegaCorp, Inc.'), - Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer') - ) - ]) - - def test_filter_on_subclass(self): - sess = create_session() - 
self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert")) - - self.assertEquals(sess.query(Engineer).first(), Engineer(name="dilbert")) - - self.assertEquals(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert")) - - self.assertEquals(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert")) - - self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) - - self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) - - def test_join_from_polymorphic(self): - sess = create_session() - - for aliased in (True, False): - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) - - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) - - self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) - - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) - - def test_join_from_with_polymorphic(self): - sess = create_session() - - for aliased in (True, False): - sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) - - sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) - - sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', 
aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) - - def test_join_to_polymorphic(self): - sess = create_session() - self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) - - self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) - - def test_polymorphic_any(self): - sess = create_session() - - self.assertEquals( - sess.query(Company).\ - filter(Company.employees.any(Person.name=='vlad')).all(), [c2] - ) - - # test that the aliasing on "Person" does not bleed into the - # EXISTS clause generated by any() - self.assertEquals( - sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ - filter(Company.employees.any(Person.name=='wally')).all(), [c1] - ) - - self.assertEquals( - sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ - filter(Company.employees.any(Person.name=='vlad')).all(), [] - ) - - self.assertEquals( - sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), - c2 - ) - - calias = aliased(Company) - self.assertEquals( - sess.query(calias).filter(calias.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), - c2 - ) - - self.assertEquals( - sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(), - c1 - ) - self.assertEquals( - sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(), - c1 - ) - - if select_type != '': - self.assertEquals( - sess.query(Person).filter(Engineer.machines.any(Machine.name=="Commodore 64")).all(), [e2, e3] - ) - - self.assertEquals( - sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1] - ) - - self.assertEquals( - 
sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(), - c2 - ) - - def test_join_from_columns_or_subclass(self): - sess = create_session() - - self.assertEquals( - sess.query(Manager.name).order_by(Manager.name).all(), - [(u'dogbert',), (u'pointy haired boss',)] - ) - - self.assertEquals( - sess.query(Manager.name).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), - [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] - ) - - self.assertEquals( - sess.query(Person.name).join((Paperwork, Person.paperwork)).order_by(Person.name).all(), - [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)] - ) - - self.assertEquals( - sess.query(Person.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Person.name).all(), - [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)] - ) - - self.assertEquals( - sess.query(Manager).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), - [m1, b1] - ) - - self.assertEquals( - sess.query(Manager.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), - [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] - ) - - self.assertEquals( - sess.query(Manager.person_id).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), - [(4,), (4,), (3,)] - ) - - self.assertEquals( - sess.query(Manager.name, Paperwork.description).join((Paperwork, Manager.person_id==Paperwork.person_id)).all(), - [(u'pointy haired boss', u'review #1'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3')] - ) - - malias = aliased(Manager) - self.assertEquals( - sess.query(malias.name).join((paperwork, malias.person_id==paperwork.c.person_id)).all(), - [(u'pointy haired boss',), (u'dogbert',), (u'dogbert',)] - ) - - def test_expire(self): - 
"""test that individual column refresh doesn't get tripped up by the select_table mapper""" - - sess = create_session() - m1 = sess.query(Manager).filter(Manager.name=='dogbert').one() - sess.expire(m1) - assert m1.status == 'regular manager' - - m2 = sess.query(Manager).filter(Manager.name=='pointy haired boss').one() - sess.expire(m2, ['manager_name', 'golf_swing']) - assert m2.golf_swing=='fore' - - def test_with_polymorphic(self): - - sess = create_session() - - - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork) - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss) - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person) - - # compare to entities without related collections to prevent additional lazy SQL from firing on - # loaded entities - emps_without_relations = [ - Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer"), - Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"), - Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), - Manager(name="dogbert", manager_name="dogbert", status="regular manager"), - Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") - ] - self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) - - - def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) - 
self.assert_sql_count(testing.db, go, 3) - - sess.expunge_all() - def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) - self.assert_sql_count(testing.db, go, 3) - - sess.expunge_all() - def go(): - # limit the polymorphic join down to just "Person", overriding select_table - self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) - self.assert_sql_count(testing.db, go, 6) - - def test_relation_to_polymorphic(self): - assert_result = [ - Company(name="MegaCorp, Inc.", employees=[ - Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")]), - Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"), - Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), - Manager(name="dogbert", manager_name="dogbert", status="regular manager"), - ]), - Company(name="Elbonia, Inc.", employees=[ - Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") - ]) - ] - - sess = create_session() - - def go(): - # test load Companies with lazy load to 'employees' - self.assertEquals(sess.query(Company).all(), assert_result) - self.assert_sql_count(testing.db, go, {'':9, 'Polymorphic':4}.get(select_type, 5)) - - sess = create_session() - def go(): - # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer). eagerloader doesn't - # pick up on the "of_type()" as of yet. 
- self.assertEquals(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result) - - # in the case of select_type='', the eagerload doesn't take in this case; - # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines" - self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2)) - - def test_eagerload_on_subclass(self): - sess = create_session() - def go(): - # test load People with eagerload to engineers + machines - self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), - [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])] - ) - self.assert_sql_count(testing.db, go, 1) - - def test_join_to_subclass(self): - sess = create_session() - self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) - - if select_type == '': - self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) - - ealias = aliased(Engineer) - self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1]) - - self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) - self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2]) - 
self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) - else: - self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) - self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2]) - self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) - - # non-polymorphic - self.assertEquals(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) - - # here's the new way - self.assertEquals(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) - - def test_join_through_polymorphic(self): - - sess = create_session() - - for aliased in (True, False): - self.assertEquals( - sess.query(Company).\ - join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), - [c1] - ) - - self.assertEquals( - sess.query(Company).\ - join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(), - [c1, c2] - ) - - self.assertEquals( - sess.query(Company).\ - join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 
'vlad'])).filter(Paperwork.description.like('%#2%')).all(), - [c1] - ) - - self.assertEquals( - sess.query(Company).\ - join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(), - [c1, c2] - ) - - self.assertEquals( - sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ - join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), - [c1] - ) - - self.assertEquals( - sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ - join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(), - [c1, c2] - ) - def test_explicit_polymorphic_join(self): - sess = create_session() - - # join from Company to Engineer; join condition formulated by - # ORMJoin using regular table foreign key connections. Engineer - # is expressed as "(select * people join engineers) as anon_1" - # so the join is contained. - self.assertEquals( - sess.query(Company).join(Engineer).filter(Engineer.engineer_name=='vlad').one(), - c2 - ) - - # same, using explicit join condition. Query.join() must adapt the on clause - # here to match the subquery wrapped around "people join engineers". 
- self.assertEquals( - sess.query(Company).join((Engineer, Company.company_id==Engineer.company_id)).filter(Engineer.engineer_name=='vlad').one(), - c2 - ) - - - def test_filter_on_baseclass(self): - sess = create_session() - - self.assertEquals(sess.query(Person).all(), all_employees) - - self.assertEquals(sess.query(Person).first(), all_employees[0]) - - self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) - - def test_from_alias(self): - sess = create_session() - - palias = aliased(Person) - self.assertEquals( - sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(), - [e1, e2] - ) - - def test_self_referential(self): - sess = create_session() - - c1_employees = [e1, e2, b1, m1] - - palias = aliased(Person) - self.assertEquals( - sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ - filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), - [ - (m1, e1), - (m1, e2), - (m1, b1), - ] - ) - - self.assertEquals( - sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ - filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(), - [ - (m1, e1), - (m1, e2), - (m1, b1), - ] - ) - - def test_nesting_queries(self): - sess = create_session() - - # query.statement places a flag "no_adapt" on the returned statement. This prevents - # the polymorphic adaptation in the second "filter" from hitting it, which would pollute - # the subquery and usually results in recursion overflow errors within the adaption. 
- subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar() - - self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1) - - def test_mixed_entities(self): - sess = create_session() - - self.assertEquals( - sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), - [(u'Elbonia, Inc.', - Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))] - ) - - self.assertEquals( - sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), - [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), - u'Elbonia, Inc.')] - ) - - - self.assertEquals( - sess.query(Manager.name).all(), - [('pointy haired boss', ), ('dogbert',)] - ) - - self.assertEquals( - sess.query(Manager.name + " foo").all(), - [('pointy haired boss foo', ), ('dogbert foo',)] - ) - - row = sess.query(Engineer.name, Engineer.primary_language).filter(Engineer.name=='dilbert').first() - assert row.name == 'dilbert' - assert row.primary_language == 'java' - - - self.assertEquals( - sess.query(Engineer.name, Engineer.primary_language).all(), - [(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')] - ) - - self.assertEquals( - sess.query(Boss.name, Boss.golf_swing).all(), - [(u'pointy haired boss', u'fore')] - ) - - # TODO: I think raise error on these for now. 
different inheritance/loading schemes have different - # results here, all incorrect - # - # self.assertEquals( - # sess.query(Person.name, Engineer.primary_language).all(), - # [] - # ) - - # self.assertEquals( - # sess.query(Person.name, Engineer.primary_language, Manager.manager_name).all(), - # [] - # ) - - self.assertEquals( - sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), - [(u'vlad',u'Elbonia, Inc.')] - ) - - self.assertEquals( - sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(), - [(u'java',), (u'c++',), (u'cobol',)] - ) - - if select_type != '': - self.assertEquals( - sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(), - [ - (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'), - (Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer'), u'MegaCorp, Inc.'), - (Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',company_id=2,primary_language=u'cobol',person_id=5,type=u'engineer'), u'Elbonia, Inc.') - ] - ) - - self.assertEquals( - sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(), - [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')] - ) - - palias = aliased(Person) - self.assertEquals( - sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), - [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), - u'Elbonia, Inc.', - Engineer(status=u'regular 
engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))] - ) - - self.assertEquals( - sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), - [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), - u'Elbonia, Inc.', - Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),) - ] - ) - - self.assertEquals( - sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), - [(u'vlad', u'Elbonia, Inc.', u'dilbert')] - ) - - palias = aliased(Person) - self.assertEquals( - sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ - filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), - [(u'manager', u'dogbert', u'engineer', u'dilbert'), - (u'manager', u'dogbert', u'engineer', u'wally'), - (u'manager', u'dogbert', u'boss', u'pointy haired boss')] - ) - - self.assertEquals( - sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(), - [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'), - (u'dogbert', u'review #3'), - (u'pointy haired boss', u'review #1'), - (u'vlad', u'elbonian missive #3'), - (u'wally', u'tps report #3'), - (u'wally', u'tps report #4'), - ] - ) - - if select_type != '': - self.assertEquals( - sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(), - [(1, )] - ) - - self.assertEquals( - sess.query(Company.name, 
func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(), - [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)] - ) - - self.assertEquals( - sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(), - [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)] - ) - - - PolymorphicQueryTest.__name__ = "Polymorphic%sTest" % select_type - return PolymorphicQueryTest - -for select_type in ('', 'Polymorphic', 'Unions', 'AliasedJoins', 'Joins'): - testclass = make_test(select_type) - exec("%s = testclass" % testclass.__name__) - -del testclass - -class SelfReferentialTestJoinedToBase(_base.MappedTest): - run_setup_mappers = 'once' - - def define_tables(self, metadata): - global people, engineers - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('primary_language', String(50)), - Column('reports_to_id', Integer, ForeignKey('people.person_id')) - ) - - def setup_mappers(self): - mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') - mapper(Engineer, engineers, inherits=Person, - inherit_condition=engineers.c.person_id==people.c.person_id, - polymorphic_identity='engineer', properties={ - 'reports_to':relation(Person, primaryjoin=people.c.person_id==engineers.c.reports_to_id) - }) - - def test_has(self): - - p1 = Person(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1) - sess = create_session() - sess.add(p1) - sess.add(e1) - sess.flush() - sess.expunge_all() - - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert')) - - def 
test_oftype_aliases_in_exists(self): - e1 = Engineer(name='dilbert', primary_language='java') - e2 = Engineer(name='wally', primary_language='c++', reports_to=e1) - sess = create_session() - sess.add_all([e1, e2]) - sess.flush() - - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2) - - def test_join(self): - p1 = Person(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1) - sess = create_session() - sess.add(p1) - sess.add(e1) - sess.flush() - sess.expunge_all() - - self.assertEquals( - sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), - Engineer(name='dilbert')) - -class SelfReferentialJ2JTest(_base.MappedTest): - run_setup_mappers = 'once' - - def define_tables(self, metadata): - global people, engineers, managers - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('primary_language', String(50)), - Column('reports_to_id', Integer, ForeignKey('managers.person_id')) - ) - - managers = Table('managers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - ) - - def setup_mappers(self): - mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') - mapper(Manager, managers, inherits=Person, polymorphic_identity='manager') - - mapper(Engineer, engineers, inherits=Person, - polymorphic_identity='engineer', properties={ - 'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id, backref='engineers') - }) - - def test_has(self): - - m1 = Manager(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) - sess = 
create_session() - sess.add(m1) - sess.add(e1) - sess.flush() - sess.expunge_all() - - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert')) - - def test_join(self): - m1 = Manager(name='dogbert') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) - sess = create_session() - sess.add(m1) - sess.add(e1) - sess.flush() - sess.expunge_all() - - self.assertEquals( - sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), - Engineer(name='dilbert')) - - def test_filter_aliasing(self): - m1 = Manager(name='dogbert') - m2 = Manager(name='foo') - e1 = Engineer(name='wally', primary_language='java', reports_to=m1) - e2 = Engineer(name='dilbert', primary_language='c++', reports_to=m2) - e3 = Engineer(name='etc', primary_language='c++') - sess = create_session() - sess.add_all([m1, m2, e1, e2, e3]) - sess.flush() - sess.expunge_all() - - # filter aliasing applied to Engineer doesn't whack Manager - self.assertEquals( - sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(), - [m1] - ) - - self.assertEquals( - sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(), - [m2] - ) - - self.assertEquals( - sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(), - [ - (m2, e2), - (m1, e1), - ] - ) - - def test_relation_compare(self): - m1 = Manager(name='dogbert') - m2 = Manager(name='foo') - e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) - e2 = Engineer(name='wally', primary_language='c++', reports_to=m2) - e3 = Engineer(name='etc', primary_language='c++') - sess = create_session() - sess.add(m1) - sess.add(m2) - sess.add(e1) - sess.add(e2) - sess.add(e3) - sess.flush() - sess.expunge_all() - - self.assertEquals( - sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), - [] - ) - - self.assertEquals( - 
sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), - [m1] - ) - - - -class M2MFilterTest(_base.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - global people, engineers, organizations, engineers_to_org - - organizations = Table('organizations', metadata, - Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True), - Column('name', String(50)), - ) - engineers_to_org = Table('engineers_org', metadata, - Column('org_id', Integer, ForeignKey('organizations.id')), - Column('engineer_id', Integer, ForeignKey('engineers.person_id')), - ) - - people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), - Column('name', String(50)), - Column('type', String(30))) - - engineers = Table('engineers', metadata, - Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), - Column('primary_language', String(50)), - ) - - def setup_mappers(self): - global Organization - class Organization(_fixtures.Base): - pass - - mapper(Organization, organizations, properties={ - 'engineers':relation(Engineer, secondary=engineers_to_org, backref='organizations') - }) - - mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') - mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') - - def insert_data(self): - e1 = Engineer(name='e1') - e2 = Engineer(name='e2') - e3 = Engineer(name='e3') - e4 = Engineer(name='e4') - org1 = Organization(name='org1', engineers=[e1, e2]) - org2 = Organization(name='org2', engineers=[e3, e4]) - - sess = create_session() - sess.add(org1) - sess.add(org2) - sess.flush() - - def test_not_contains(self): - sess = create_session() - - e1 = sess.query(Person).filter(Engineer.name=='e1').one() - - # this works - 
self.assertEquals(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')]) - - # this had a bug - self.assertEquals(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')]) - - def test_any(self): - sess = create_session() - self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')]) - self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')]) - -class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): - run_setup_mappers = 'once' - - def define_tables(self, metadata): - global Parent, Child1, Child2 - - Base = declarative_base(metadata=metadata) - - secondary_table = Table('secondary', Base.metadata, - Column('left_id', Integer, ForeignKey('parent.id'), nullable=False), - Column('right_id', Integer, ForeignKey('parent.id'), nullable=False)) - - class Parent(Base): - __tablename__ = 'parent' - id = Column(Integer, primary_key=True) - cls = Column(String(50)) - __mapper_args__ = dict(polymorphic_on = cls ) - - class Child1(Parent): - __tablename__ = 'child1' - id = Column(Integer, ForeignKey('parent.id'), primary_key=True) - __mapper_args__ = dict(polymorphic_identity = 'child1') - - class Child2(Parent): - __tablename__ = 'child2' - id = Column(Integer, ForeignKey('parent.id'), primary_key=True) - __mapper_args__ = dict(polymorphic_identity = 'child2') - - Child1.left_child2 = relation(Child2, secondary = secondary_table, - primaryjoin = Parent.id == secondary_table.c.right_id, - secondaryjoin = Parent.id == secondary_table.c.left_id, - uselist = False, backref="right_children" - ) - - - def test_query_crit(self): - session = create_session() - c11, c12, c13 = Child1(), Child1(), Child1() - c21, c22, c23 = Child2(), Child2(), Child2() - - c11.left_child2 = c22 - 
c12.left_child2 = c22 - c13.left_child2 = c23 - - session.add_all([c11, c12, c13, c21, c22, c23]) - session.flush() - - # test that the join to Child2 doesn't alias Child1 in the select - eq_( - set(session.query(Child1).join(Child1.left_child2)), - set([c11, c12, c13]) - ) - - eq_( - set(session.query(Child1, Child2).join(Child1.left_child2)), - set([(c11, c22), (c12, c22), (c13, c23)]) - ) - - # test __eq__() on property is annotating correctly - eq_( - set(session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22)), - set([c22]) - ) - - # test the same again - self.assert_compile( - session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22).with_labels().statement, - "SELECT parent.id AS parent_id, child2.id AS child2_id, parent.cls AS parent_cls FROM " - "secondary AS secondary_1, parent JOIN child2 ON parent.id = child2.id JOIN secondary AS secondary_2 " - "ON parent.id = secondary_2.left_id JOIN (SELECT parent.id AS parent_id, parent.cls AS parent_cls, " - "child1.id AS child1_id FROM parent JOIN child1 ON parent.id = child1.id) AS anon_1 ON " - "anon_1.parent_id = secondary_2.right_id WHERE anon_1.parent_id = secondary_1.right_id AND :param_1 = secondary_1.left_id", - dialect=default.DefaultDialect() - ) - - def test_eager_join(self): - session = create_session() - - c1 = Child1() - c1.left_child2 = Child2() - session.add(c1) - session.flush() - - q = session.query(Child1).options(eagerload('left_child2')) - - # test that the splicing of the join works here, doesnt break in the middle of "parent join child1" - self.assert_compile(q.limit(1).with_labels().statement, - "SELECT anon_1.parent_id AS anon_1_parent_id, anon_1.child1_id AS anon_1_child1_id, "\ - "anon_1.parent_cls AS anon_1_parent_cls, anon_2.parent_id AS anon_2_parent_id, "\ - "anon_2.child2_id AS anon_2_child2_id, anon_2.parent_cls AS anon_2_parent_cls FROM "\ - "(SELECT parent.id AS parent_id, child1.id AS child1_id, parent.cls AS parent_cls FROM 
parent "\ - "JOIN child1 ON parent.id = child1.id LIMIT 1) AS anon_1 LEFT OUTER JOIN secondary AS secondary_1 "\ - "ON anon_1.parent_id = secondary_1.right_id LEFT OUTER JOIN (SELECT parent.id AS parent_id, "\ - "parent.cls AS parent_cls, child2.id AS child2_id FROM parent JOIN child2 ON parent.id = child2.id) "\ - "AS anon_2 ON anon_2.parent_id = secondary_1.left_id" - , dialect=default.DefaultDialect()) - - # another way to check - assert q.limit(1).with_labels().subquery().count().scalar() == 1 - - assert q.first() is c1 - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/selects.py b/test/orm/inheritance/selects.py deleted file mode 100644 index e54a0ad13..000000000 --- a/test/orm/inheritance/selects.py +++ /dev/null @@ -1,53 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * - -from testlib import testing -from orm._fixtures import Base -from orm._base import MappedTest - - -class InheritingSelectablesTest(MappedTest): - def define_tables(self, metadata): - global foo, bar, baz - foo = Table('foo', metadata, - Column('a', String(30), primary_key=1), - Column('b', String(30), nullable=0)) - - bar = foo.select(foo.c.b == 'bar').alias('bar') - baz = foo.select(foo.c.b == 'baz').alias('baz') - - def test_load(self): - # TODO: add persistence test also - testing.db.execute(foo.insert(), a='not bar', b='baz') - testing.db.execute(foo.insert(), a='also not bar', b='baz') - testing.db.execute(foo.insert(), a='i am bar', b='bar') - testing.db.execute(foo.insert(), a='also bar', b='bar') - - class Foo(Base): pass - class Bar(Foo): pass - class Baz(Foo): pass - - mapper(Foo, foo, polymorphic_on=foo.c.b) - - mapper(Baz, baz, - with_polymorphic=('*', foo.join(baz, foo.c.b=='baz').alias('baz')), - inherits=Foo, - inherit_condition=(foo.c.a==baz.c.a), - inherit_foreign_keys=[baz.c.a], - polymorphic_identity='baz') - - mapper(Bar, bar, - with_polymorphic=('*', foo.join(bar, 
foo.c.b=='bar').alias('bar')), - inherits=Foo, - inherit_condition=(foo.c.a==bar.c.a), - inherit_foreign_keys=[bar.c.a], - polymorphic_identity='bar') - - s = sessionmaker(bind=testing.db)() - - assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all() - assert [Bar(), Bar()] == s.query(Bar).all() - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/single.py deleted file mode 100644 index 7aee25031..000000000 --- a/test/orm/inheritance/single.py +++ /dev/null @@ -1,396 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * - -from testlib import testing -from orm import _fixtures -from orm._base import MappedTest, ComparableEntity - - -class SingleInheritanceTest(MappedTest): - def define_tables(self, metadata): - Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('manager_data', String(50)), - Column('engineer_info', String(50)), - Column('type', String(20))) - - Table('reports', metadata, - Column('report_id', Integer, primary_key=True), - Column('employee_id', ForeignKey('employees.employee_id')), - Column('name', String(50)), - ) - - def setup_classes(self): - class Employee(ComparableEntity): - pass - class Manager(Employee): - pass - class Engineer(Employee): - pass - class JuniorEngineer(Engineer): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Manager, inherits=Employee, polymorphic_identity='manager') - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') - mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') - - @testing.resolve_artifact_names - def test_single_inheritance(self): - - session = create_session() - - m1 = Manager(name='Tom', manager_data='knows how to manage things') - e1 = Engineer(name='Kurt', engineer_info='knows how 
to hack') - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') - session.add_all([m1, e1, e2]) - session.flush() - - assert session.query(Employee).all() == [m1, e1, e2] - assert session.query(Engineer).all() == [e1, e2] - assert session.query(Manager).all() == [m1] - assert session.query(JuniorEngineer).all() == [e2] - - m1 = session.query(Manager).one() - session.expire(m1, ['manager_data']) - self.assertEquals(m1.manager_data, "knows how to manage things") - - row = session.query(Engineer.name, Engineer.employee_id).filter(Engineer.name=='Kurt').first() - assert row.name == 'Kurt' - assert row.employee_id == e1.employee_id - - @testing.resolve_artifact_names - def test_multi_qualification(self): - session = create_session() - - m1 = Manager(name='Tom', manager_data='knows how to manage things') - e1 = Engineer(name='Kurt', engineer_info='knows how to hack') - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') - - session.add_all([m1, e1, e2]) - session.flush() - - ealias = aliased(Engineer) - self.assertEquals( - session.query(Manager, ealias).all(), - [(m1, e1), (m1, e2)] - ) - - self.assertEquals( - session.query(Manager.name).all(), - [("Tom",)] - ) - - self.assertEquals( - session.query(Manager.name, ealias.name).all(), - [("Tom", "Kurt"), ("Tom", "Ed")] - ) - - self.assertEquals( - session.query(func.upper(Manager.name), func.upper(ealias.name)).all(), - [("TOM", "KURT"), ("TOM", "ED")] - ) - - self.assertEquals( - session.query(Manager).add_entity(ealias).all(), - [(m1, e1), (m1, e2)] - ) - - self.assertEquals( - session.query(Manager.name).add_column(ealias.name).all(), - [("Tom", "Kurt"), ("Tom", "Ed")] - ) - - # TODO: I think raise error on this for now - # self.assertEquals( - # session.query(Employee.name, Manager.manager_data, Engineer.engineer_info).all(), - # [] - # ) - - @testing.resolve_artifact_names - def test_select_from(self): - sess = create_session() - m1 = Manager(name='Tom', manager_data='data1') - m2 = 
Manager(name='Tom2', manager_data='data2') - e1 = Engineer(name='Kurt', engineer_info='knows how to hack') - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') - sess.add_all([m1, m2, e1, e2]) - sess.flush() - - self.assertEquals( - sess.query(Manager).select_from(employees.select().limit(10)).all(), - [m1, m2] - ) - - @testing.resolve_artifact_names - def test_count(self): - sess = create_session() - m1 = Manager(name='Tom', manager_data='data1') - m2 = Manager(name='Tom2', manager_data='data2') - e1 = Engineer(name='Kurt', engineer_info='data3') - e2 = JuniorEngineer(name='marvin', engineer_info='data4') - sess.add_all([m1, m2, e1, e2]) - sess.flush() - - self.assertEquals(sess.query(Manager).count(), 2) - self.assertEquals(sess.query(Engineer).count(), 2) - self.assertEquals(sess.query(Employee).count(), 4) - - self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) - self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) - - @testing.resolve_artifact_names - def test_type_filtering(self): - class Report(ComparableEntity): pass - - mapper(Report, reports, properties={ - 'employee': relation(Employee, backref='reports')}) - sess = create_session() - - m1 = Manager(name='Tom', manager_data='data1') - r1 = Report(employee=m1) - sess.add_all([m1, r1]) - sess.flush() - rq = sess.query(Report) - - assert len(rq.filter(Report.employee.of_type(Manager).has()).all()) == 1 - assert len(rq.filter(Report.employee.of_type(Engineer).has()).all()) == 0 - - @testing.resolve_artifact_names - def test_type_joins(self): - class Report(ComparableEntity): pass - - mapper(Report, reports, properties={ - 'employee': relation(Employee, backref='reports')}) - sess = create_session() - - m1 = Manager(name='Tom', manager_data='data1') - r1 = Report(employee=m1) - sess.add_all([m1, r1]) - sess.flush() - - rq = sess.query(Report) - - assert len(rq.join(Report.employee.of_type(Manager)).all()) == 1 - assert 
len(rq.join(Report.employee.of_type(Engineer)).all()) == 0 - - -class RelationToSingleTest(MappedTest): - def define_tables(self, metadata): - Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), - Column('name', String(50)), - Column('manager_data', String(50)), - Column('engineer_info', String(50)), - Column('type', String(20)), - Column('company_id', Integer, ForeignKey('companies.company_id')) - ) - - Table('companies', metadata, - Column('company_id', Integer, primary_key=True), - Column('name', String(50)), - ) - - def setup_classes(self): - class Company(ComparableEntity): - pass - - class Employee(ComparableEntity): - pass - class Manager(Employee): - pass - class Engineer(Employee): - pass - class JuniorEngineer(Engineer): - pass - - @testing.resolve_artifact_names - def test_of_type(self): - mapper(Company, companies, properties={ - 'employees':relation(Employee, backref='company') - }) - mapper(Employee, employees, polymorphic_on=employees.c.type) - mapper(Manager, inherits=Employee, polymorphic_identity='manager') - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') - mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') - sess = sessionmaker()() - - c1 = Company(name='c1') - c2 = Company(name='c2') - - m1 = Manager(name='Tom', manager_data='data1', company=c1) - m2 = Manager(name='Tom2', manager_data='data2', company=c2) - e1 = Engineer(name='Kurt', engineer_info='knows how to hack', company=c2) - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1) - sess.add_all([c1, c2, m1, m2, e1, e2]) - sess.commit() - sess.expunge_all() - self.assertEquals( - sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(), - [ - Company(name='c1'), - ] - ) - - self.assertEquals( - sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(), - [ - Company(name='c1'), - ] - ) - - - @testing.resolve_artifact_names - def 
test_relation_to_subclass(self): - mapper(Company, companies, properties={ - 'engineers':relation(Engineer) - }) - mapper(Employee, employees, polymorphic_on=employees.c.type, properties={ - 'company':relation(Company) - }) - mapper(Manager, inherits=Employee, polymorphic_identity='manager') - mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') - mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') - sess = sessionmaker()() - - c1 = Company(name='c1') - c2 = Company(name='c2') - - m1 = Manager(name='Tom', manager_data='data1', company=c1) - m2 = Manager(name='Tom2', manager_data='data2', company=c2) - e1 = Engineer(name='Kurt', engineer_info='knows how to hack', company=c2) - e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1) - sess.add_all([c1, c2, m1, m2, e1, e2]) - sess.commit() - - self.assertEquals(c1.engineers, [e2]) - self.assertEquals(c2.engineers, [e1]) - - sess.expunge_all() - self.assertEquals(sess.query(Company).order_by(Company.name).all(), - [ - Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), - Company(name='c2', engineers=[Engineer(name='Kurt')]) - ] - ) - - # eager load join should limit to only "Engineer" - sess.expunge_all() - self.assertEquals(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), - [ - Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), - Company(name='c2', engineers=[Engineer(name='Kurt')]) - ] - ) - - # join() to Company.engineers, Employee as the requested entity - sess.expunge_all() - self.assertEquals(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(), - [ - (Company(name='c1'), JuniorEngineer(name='Ed')), - (Company(name='c2'), Engineer(name='Kurt')) - ] - ) - - # join() to Company.engineers, Engineer as the requested entity. - # this actually applies the IN criterion twice which is less than ideal. 
- sess.expunge_all() - self.assertEquals(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(), - [ - (Company(name='c1'), JuniorEngineer(name='Ed')), - (Company(name='c2'), Engineer(name='Kurt')) - ] - ) - - # join() to Company.engineers without any Employee/Engineer entity - sess.expunge_all() - self.assertEquals(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), - [ - Company(name='c2') - ] - ) - - # this however fails as it does not limit the subtypes to just "Engineer". - # with joins constructed by filter(), we seem to be following a policy where - # we don't try to make decisions on how to join to the target class, whereas when using join() we - # seem to have a lot more capabilities. - # we might want to document "advantages of join() vs. straight filtering", or add a large - # section to "inheritance" laying out all the various behaviors Query has. - @testing.fails_on_everything_except() - def go(): - sess.expunge_all() - self.assertEquals(sess.query(Company).\ - filter(Company.company_id==Engineer.company_id).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), - [ - Company(name='c2') - ] - ) - go() - -class SingleOnJoinedTest(MappedTest): - def define_tables(self, metadata): - global persons_table, employees_table - - persons_table = Table('persons', metadata, - Column('person_id', Integer, primary_key=True), - Column('name', String(50)), - Column('type', String(20), nullable=False) - ) - - employees_table = Table('employees', metadata, - Column('person_id', Integer, ForeignKey('persons.person_id'),primary_key=True), - Column('employee_data', String(50)), - Column('manager_data', String(50)), - ) - - def test_single_on_joined(self): - class Person(_fixtures.Base): - pass - class Employee(Person): - pass - class Manager(Employee): - pass - - mapper(Person, persons_table, polymorphic_on=persons_table.c.type, polymorphic_identity='person') - mapper(Employee, employees_table, 
inherits=Person,polymorphic_identity='engineer') - mapper(Manager, inherits=Employee,polymorphic_identity='manager') - - sess = create_session() - sess.add(Person(name='p1')) - sess.add(Employee(name='e1', employee_data='ed1')) - sess.add(Manager(name='m1', employee_data='ed2', manager_data='md1')) - sess.flush() - sess.expunge_all() - - self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [ - Person(name='p1'), - Employee(name='e1', employee_data='ed1'), - Manager(name='m1', employee_data='ed2', manager_data='md1') - ]) - sess.expunge_all() - - self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [ - Employee(name='e1', employee_data='ed1'), - Manager(name='m1', employee_data='ed2', manager_data='md1') - ]) - sess.expunge_all() - - self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [ - Manager(name='m1', employee_data='ed2', manager_data='md1') - ]) - sess.expunge_all() - - def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ - Person(name='p1'), - Employee(name='e1', employee_data='ed1'), - Manager(name='m1', employee_data='ed2', manager_data='md1') - ]) - self.assert_sql_count(testing.db, go, 1) - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/test_abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py new file mode 100644 index 000000000..4e55cf70e --- /dev/null +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -0,0 +1,170 @@ +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE + +from sqlalchemy.test import testing +from test.orm import _base + + +def produce_test(parent, child, direction): + """produce a testcase for A->B->C inheritance with a self-referential + relationship between two of the classes, using either one-to-many or + many-to-one.""" + class ABCTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global 
ta, tb, tc + ta = ["a", metadata] + ta.append(Column('id', Integer, primary_key=True)), + ta.append(Column('a_data', String(30))) + if "a"== parent and direction == MANYTOONE: + ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) + elif "a" == child and direction == ONETOMANY: + ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) + ta = Table(*ta) + + tb = ["b", metadata] + tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, )) + + tb.append(Column('b_data', String(30))) + + if "b"== parent and direction == MANYTOONE: + tb.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) + elif "b" == child and direction == ONETOMANY: + tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) + tb = Table(*tb) + + tc = ["c", metadata] + tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, )) + + tc.append(Column('c_data', String(30))) + + if "c"== parent and direction == MANYTOONE: + tc.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) + elif "c" == child and direction == ONETOMANY: + tc.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) + tc = Table(*tc) + + def teardown(self): + if direction == MANYTOONE: + parent_table = {"a":ta, "b":tb, "c": tc}[parent] + parent_table.update(values={parent_table.c.child_id:None}).execute() + elif direction == ONETOMANY: + child_table = {"a":ta, "b":tb, "c": tc}[child] + child_table.update(values={child_table.c.parent_id:None}).execute() + super(ABCTest, self).teardown() + + def test_roundtrip(self): + parent_table = {"a":ta, "b":tb, "c": tc}[parent] + child_table = {"a":ta, "b":tb, "c": tc}[child] + + remote_side = None + + if direction == MANYTOONE: + foreign_keys = [parent_table.c.child_id] + elif direction == ONETOMANY: + foreign_keys = 
[child_table.c.parent_id] + + atob = ta.c.id==tb.c.id + btoc = tc.c.id==tb.c.id + + if direction == ONETOMANY: + relationjoin = parent_table.c.id==child_table.c.parent_id + elif direction == MANYTOONE: + relationjoin = parent_table.c.child_id==child_table.c.id + if parent is child: + remote_side = [child_table.c.id] + + abcjoin = polymorphic_union( + {"a":ta.select(tb.c.id==None, from_obj=[ta.outerjoin(tb, onclause=atob)]), + "b":ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc).select(tc.c.id==None, fold_equivalents=True), + "c":tc.join(tb, onclause=btoc).join(ta, onclause=atob) + },"type", "abcjoin" + ) + + bcjoin = polymorphic_union( + { + "b":ta.join(tb, onclause=atob).outerjoin(tc, onclause=btoc).select(tc.c.id==None, fold_equivalents=True), + "c":tc.join(tb, onclause=btoc).join(ta, onclause=atob) + },"type", "bcjoin" + ) + class A(object): + def __init__(self, name): + self.a_data = name + class B(A):pass + class C(B):pass + + mapper(A, ta, polymorphic_on=abcjoin.c.type, with_polymorphic=('*', abcjoin), polymorphic_identity="a") + mapper(B, tb, polymorphic_on=bcjoin.c.type, with_polymorphic=('*', bcjoin), polymorphic_identity="b", inherits=A, inherit_condition=atob) + mapper(C, tc, polymorphic_identity="c", inherits=B, inherit_condition=btoc) + + parent_mapper = class_mapper({ta:A, tb:B, tc:C}[parent_table]) + child_mapper = class_mapper({ta:A, tb:B, tc:C}[child_table]) + + parent_class = parent_mapper.class_ + child_class = child_mapper.class_ + + parent_mapper.add_property("collection", relation(child_mapper, primaryjoin=relationjoin, foreign_keys=foreign_keys, remote_side=remote_side, uselist=True)) + + sess = create_session() + + parent_obj = parent_class('parent1') + child_obj = child_class('child1') + somea = A('somea') + someb = B('someb') + somec = C('somec') + + #print "APPENDING", parent.__class__.__name__ , "TO", child.__class__.__name__ + + sess.add(parent_obj) + parent_obj.collection.append(child_obj) + if direction == ONETOMANY: + child2 = 
child_class('child2') + parent_obj.collection.append(child2) + sess.add(child2) + elif direction == MANYTOONE: + parent2 = parent_class('parent2') + parent2.collection.append(child_obj) + sess.add(parent2) + sess.add(somea) + sess.add(someb) + sess.add(somec) + sess.flush() + sess.expunge_all() + + # assert result via direct get() of parent object + result = sess.query(parent_class).get(parent_obj.id) + assert result.id == parent_obj.id + assert result.collection[0].id == child_obj.id + if direction == ONETOMANY: + assert result.collection[1].id == child2.id + elif direction == MANYTOONE: + result2 = sess.query(parent_class).get(parent2.id) + assert result2.id == parent2.id + assert result2.collection[0].id == child_obj.id + + sess.expunge_all() + + # assert result via polymorphic load of parent object + result = sess.query(A).filter_by(id=parent_obj.id).one() + assert result.id == parent_obj.id + assert result.collection[0].id == child_obj.id + if direction == ONETOMANY: + assert result.collection[1].id == child2.id + elif direction == MANYTOONE: + result2 = sess.query(A).filter_by(id=parent2.id).one() + assert result2.id == parent2.id + assert result2.collection[0].id == child_obj.id + + ABCTest.__name__ = "Test%sTo%s%s" % (parent, child, (direction is ONETOMANY and "O2M" or "M2O")) + return ABCTest + +# test all combinations of polymorphic a/b/c related to another of a/b/c +for parent in ["a", "b", "c"]: + for child in ["a", "b", "c"]: + for direction in [ONETOMANY, MANYTOONE]: + testclass = produce_test(parent, child, direction) + exec("%s = testclass" % testclass.__name__) + del testclass + +del produce_test \ No newline at end of file diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py new file mode 100644 index 000000000..8cad8ed78 --- /dev/null +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -0,0 +1,88 @@ +from sqlalchemy import * +from sqlalchemy import util +from sqlalchemy.orm import * + +from 
sqlalchemy.util import function_named +from test.orm import _base, _fixtures + +class ABCTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global a, b, c + a = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('adata', String(30)), + Column('type', String(30)), + ) + b = Table('b', metadata, + Column('id', Integer, ForeignKey('a.id'), primary_key=True), + Column('bdata', String(30))) + c = Table('c', metadata, + Column('id', Integer, ForeignKey('b.id'), primary_key=True), + Column('cdata', String(30))) + + def make_test(fetchtype): + def test_roundtrip(self): + class A(_fixtures.Base):pass + class B(A):pass + class C(B):pass + + if fetchtype == 'union': + abc = a.outerjoin(b).outerjoin(c) + bc = a.join(b).outerjoin(c) + else: + abc = bc = None + + mapper(A, a, with_polymorphic=('*', abc), polymorphic_on=a.c.type, polymorphic_identity='a') + mapper(B, b, with_polymorphic=('*', bc), inherits=A, polymorphic_identity='b') + mapper(C, c, inherits=B, polymorphic_identity='c') + + a1 = A(adata='a1') + b1 = B(bdata='b1', adata='b1') + b2 = B(bdata='b2', adata='b2') + b3 = B(bdata='b3', adata='b3') + c1 = C(cdata='c1', bdata='c1', adata='c1') + c2 = C(cdata='c2', bdata='c2', adata='c2') + c3 = C(cdata='c2', bdata='c2', adata='c2') + + sess = create_session() + for x in (a1, b1, b2, b3, c1, c2, c3): + sess.add(x) + sess.flush() + sess.expunge_all() + + #for obj in sess.query(A).all(): + # print obj + assert [ + A(adata='a1'), + B(bdata='b1', adata='b1'), + B(bdata='b2', adata='b2'), + B(bdata='b3', adata='b3'), + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(A).all() + + assert [ + B(bdata='b1', adata='b1'), + B(bdata='b2', adata='b2'), + B(bdata='b3', adata='b3'), + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(B).all() + + assert [ + 
C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(C).all() + + test_roundtrip = function_named( + test_roundtrip, 'test_%s' % fetchtype) + return test_roundtrip + + test_union = make_test('union') + test_none = make_test('none') + + diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py new file mode 100644 index 000000000..fc4aae17d --- /dev/null +++ b/test/orm/inheritance/test_basic.py @@ -0,0 +1,1027 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy import * +from sqlalchemy import exc as sa_exc, util +from sqlalchemy.orm import * +from sqlalchemy.orm import exc as orm_exc + +from sqlalchemy.test import testing, engines +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures + +class O2MTest(_base.MappedTest): + """deals with inheritance and one-to-many relationships""" + @classmethod + def define_tables(cls, metadata): + global foo, bar, blub + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), + Column('data', String(20))) + + bar = Table('bar', metadata, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(20))) + + blub = Table('blub', metadata, + Column('id', Integer, ForeignKey('bar.id'), primary_key=True), + Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False), + Column('data', String(20))) + + def testbasic(self): + class Foo(object): + def __init__(self, data=None): + self.data = data + def __repr__(self): + return "Foo id %d, data %s" % (self.id, self.data) + mapper(Foo, foo) + + class Bar(Foo): + def __repr__(self): + return "Bar id %d, data %s" % (self.id, self.data) + + mapper(Bar, bar, inherits=Foo) + + class Blub(Bar): + def __repr__(self): + return "Blub id %d, data %s" % (self.id, self.data) + + mapper(Blub, blub, inherits=Bar, properties={ + 
'parent_foo':relation(Foo) + }) + + sess = create_session() + b1 = Blub("blub #1") + b2 = Blub("blub #2") + f = Foo("foo #1") + sess.add(b1) + sess.add(b2) + sess.add(f) + b1.parent_foo = f + b2.parent_foo = f + sess.flush() + compare = ','.join([repr(b1), repr(b2), repr(b1.parent_foo), repr(b2.parent_foo)]) + sess.expunge_all() + l = sess.query(Blub).all() + result = ','.join([repr(l[0]), repr(l[1]), repr(l[0].parent_foo), repr(l[1].parent_foo)]) + print compare + print result + self.assert_(compare == result) + self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') + +class FalseDiscriminatorTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global t1 + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False)) + + def test_false_discriminator(self): + class Foo(object):pass + class Bar(Foo):pass + mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=1) + mapper(Bar, inherits=Foo, polymorphic_identity=0) + sess = create_session() + f1 = Bar() + sess.add(f1) + sess.flush() + assert f1.type == 0 + sess.expunge_all() + assert isinstance(sess.query(Foo).one(), Bar) + +class PolymorphicSynonymTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global t1, t2 + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('type', String(10), nullable=False), + Column('info', String(255))) + t2 = Table('t2', metadata, + Column('id', Integer, ForeignKey('t1.id'), primary_key=True), + Column('data', String(10), nullable=False)) + + def test_polymorphic_synonym(self): + class T1(_fixtures.Base): + def info(self): + return "THE INFO IS:" + self._info + def _set_info(self, x): + self._info = x + info = property(info, _set_info) + + class T2(T1):pass + + mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', properties={ + 'info':synonym('_info', map_column=True) + }) + mapper(T2, t2, inherits=T1, 
polymorphic_identity='t2') + sess = create_session() + at1 = T1(info='at1') + at2 = T2(info='at2', data='t2 data') + sess.add(at1) + sess.add(at2) + sess.flush() + sess.expunge_all() + eq_(sess.query(T2).filter(T2.info=='at2').one(), at2) + eq_(at2.info, "THE INFO IS:at2") + + +class CascadeTest(_base.MappedTest): + """that cascades on polymorphic relations continue + cascading along the path of the instance's mapper, not + the base mapper.""" + + @classmethod + def define_tables(cls, metadata): + global t1, t2, t3, t4 + t1= Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)) + ) + + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('t1id', Integer, ForeignKey('t1.id')), + Column('type', String(30)), + Column('data', String(30)) + ) + t3 = Table('t3', metadata, + Column('id', Integer, ForeignKey('t2.id'), primary_key=True), + Column('moredata', String(30))) + + t4 = Table('t4', metadata, + Column('id', Integer, primary_key=True), + Column('t3id', Integer, ForeignKey('t3.id')), + Column('data', String(30))) + + def test_cascade(self): + class T1(_fixtures.Base): + pass + class T2(_fixtures.Base): + pass + class T3(T2): + pass + class T4(_fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, cascade="all") + }) + mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') + mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={ + 't4s':relation(T4, cascade="all") + }) + mapper(T4, t4) + + sess = create_session() + t1_1 = T1(data='t1') + + t3_1 = T3(data ='t3', moredata='t3') + t2_1 = T2(data='t2') + + t1_1.t2s.append(t2_1) + t1_1.t2s.append(t3_1) + + t4_1 = T4(data='t4') + t3_1.t4s.append(t4_1) + + sess.add(t1_1) + + + assert t4_1 in sess.new + sess.flush() + + sess.delete(t1_1) + assert t4_1 in sess.deleted + sess.flush() + +class GetTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global foo, bar, blub + foo = Table('foo', 
metadata, + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), + Column('type', String(30)), + Column('data', String(20))) + + bar = Table('bar', metadata, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(20))) + + blub = Table('blub', metadata, + Column('id', Integer, primary_key=True), + Column('foo_id', Integer, ForeignKey('foo.id')), + Column('bar_id', Integer, ForeignKey('bar.id')), + Column('data', String(20))) + + def _create_test(polymorphic, name): + def test_get(self): + class Foo(object): + pass + + class Bar(Foo): + pass + + class Blub(Bar): + pass + + if polymorphic: + mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo') + mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar') + mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub') + else: + mapper(Foo, foo) + mapper(Bar, bar, inherits=Foo) + mapper(Blub, blub, inherits=Bar) + + sess = create_session() + f = Foo() + b = Bar() + bl = Blub() + sess.add(f) + sess.add(b) + sess.add(bl) + sess.flush() + + if polymorphic: + def go(): + assert sess.query(Foo).get(f.id) == f + assert sess.query(Foo).get(b.id) == b + assert sess.query(Foo).get(bl.id) == bl + assert sess.query(Bar).get(b.id) == b + assert sess.query(Bar).get(bl.id) == bl + assert sess.query(Blub).get(bl.id) == bl + + self.assert_sql_count(testing.db, go, 0) + else: + # this is testing the 'wrong' behavior of using get() + # polymorphically with mappers that are not configured to be + # polymorphic. the important part being that get() always + # returns an instance of the query's type. 
+ def go(): + assert sess.query(Foo).get(f.id) == f + + bb = sess.query(Foo).get(b.id) + assert isinstance(b, Foo) and bb.id==b.id + + bll = sess.query(Foo).get(bl.id) + assert isinstance(bll, Foo) and bll.id==bl.id + + assert sess.query(Bar).get(b.id) == b + + bll = sess.query(Bar).get(bl.id) + assert isinstance(bll, Bar) and bll.id == bl.id + + assert sess.query(Blub).get(bl.id) == bl + + self.assert_sql_count(testing.db, go, 3) + + test_get = function_named(test_get, name) + return test_get + + test_get_polymorphic = _create_test(True, 'test_get_polymorphic') + test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic') + +class EagerLazyTest(_base.MappedTest): + """tests eager load/lazy load of child items off inheritance mappers, tests that + LazyLoader constructs the right query condition.""" + @classmethod + def define_tables(cls, metadata): + global foo, bar, bar_foo + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), + Column('data', String(30))) + bar = Table('bar', metadata, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(30))) + + bar_foo = Table('bar_foo', metadata, + Column('bar_id', Integer, ForeignKey('bar.id')), + Column('foo_id', Integer, ForeignKey('foo.id')) + ) + + @testing.fails_on('maxdb', 'FIXME: unknown') + def testbasic(self): + class Foo(object): pass + class Bar(Foo): pass + + foos = mapper(Foo, foo) + bars = mapper(Bar, bar, inherits=foos) + bars.add_property('lazy', relation(foos, bar_foo, lazy=True)) + bars.add_property('eager', relation(foos, bar_foo, lazy=False)) + + foo.insert().execute(data='foo1') + bar.insert().execute(id=1, data='bar1') + + foo.insert().execute(data='foo2') + bar.insert().execute(id=2, data='bar2') + + foo.insert().execute(data='foo3') #3 + foo.insert().execute(data='foo4') #4 + + bar_foo.insert().execute(bar_id=1, foo_id=3) + bar_foo.insert().execute(bar_id=2, foo_id=4) + + sess = create_session() 
+ q = sess.query(Bar) + self.assert_(len(q.first().lazy) == 1) + self.assert_(len(q.first().eager) == 1) + + +class FlushTest(_base.MappedTest): + """test dependency sorting among inheriting mappers""" + @classmethod + def define_tables(cls, metadata): + global users, roles, user_roles, admins + users = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('email', String(128)), + Column('password', String(16)), + ) + + roles = Table('role', metadata, + Column('id', Integer, primary_key=True), + Column('description', String(32)) + ) + + user_roles = Table('user_role', metadata, + Column('user_id', Integer, ForeignKey('users.id'), primary_key=True), + Column('role_id', Integer, ForeignKey('role.id'), primary_key=True) + ) + + admins = Table('admin', metadata, + Column('admin_id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('users.id')) + ) + + def testone(self): + class User(object):pass + class Role(object):pass + class Admin(User):pass + role_mapper = mapper(Role, roles) + user_mapper = mapper(User, users, properties = { + 'roles' : relation(Role, secondary=user_roles, lazy=False) + } + ) + admin_mapper = mapper(Admin, admins, inherits=user_mapper) + sess = create_session() + adminrole = Role() + sess.add(adminrole) + sess.flush() + + # create an Admin, and append a Role. the dependency processors + # corresponding to the "roles" attribute for the Admin mapper and the User mapper + # have to ensure that two dependency processors dont fire off and insert the + # many to many row twice. 
+ a = Admin() + a.roles.append(adminrole) + a.password = 'admin' + sess.add(a) + sess.flush() + + assert user_roles.count().scalar() == 1 + + def testtwo(self): + class User(object): + def __init__(self, email=None, password=None): + self.email = email + self.password = password + + class Role(object): + def __init__(self, description=None): + self.description = description + + class Admin(User):pass + + role_mapper = mapper(Role, roles) + user_mapper = mapper(User, users, properties = { + 'roles' : relation(Role, secondary=user_roles, lazy=False) + } + ) + + admin_mapper = mapper(Admin, admins, inherits=user_mapper) + + # create roles + adminrole = Role('admin') + + sess = create_session() + sess.add(adminrole) + sess.flush() + + # create admin user + a = Admin(email='tim', password='admin') + a.roles.append(adminrole) + sess.add(a) + sess.flush() + + a.password = 'sadmin' + sess.flush() + assert user_roles.count().scalar() == 1 + +class VersioningTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global base, subtable, stuff + base = Table('base', metadata, + Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), + Column('version_id', Integer, nullable=False), + Column('value', String(40)), + Column('discriminator', Integer, nullable=False) + ) + subtable = Table('subtable', metadata, + Column('id', None, ForeignKey('base.id'), primary_key=True), + Column('subdata', String(50)) + ) + stuff = Table('stuff', metadata, + Column('id', Integer, primary_key=True), + Column('parent', Integer, ForeignKey('base.id')) + ) + + @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') + @engines.close_open_connections + def test_save_update(self): + class Base(_fixtures.Base): + pass + class Sub(Base): + pass + class Stuff(Base): + pass + mapper(Stuff, stuff) + mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={ 
+ 'stuff':relation(Stuff) + }) + mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) + + sess = create_session() + + b1 = Base(value='b1') + s1 = Sub(value='sub1', subdata='some subdata') + sess.add(b1) + sess.add(s1) + + sess.flush() + + sess2 = create_session() + s2 = sess2.query(Base).get(s1.id) + s2.subdata = 'sess2 subdata' + + s1.subdata = 'sess1 subdata' + + sess.flush() + + try: + sess2.query(Base).with_lockmode('read').get(s1.id) + assert False + except orm_exc.ConcurrentModificationError, e: + assert True + + try: + sess2.flush() + assert False + except orm_exc.ConcurrentModificationError, e: + assert True + + sess2.refresh(s2) + assert s2.subdata == 'sess1 subdata' + s2.subdata = 'sess2 subdata' + sess2.flush() + + @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') + def test_delete(self): + class Base(_fixtures.Base): + pass + class Sub(Base): + pass + + mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1) + mapper(Sub, subtable, inherits=Base, polymorphic_identity=2) + + sess = create_session() + + b1 = Base(value='b1') + s1 = Sub(value='sub1', subdata='some subdata') + s2 = Sub(value='sub2', subdata='some other subdata') + sess.add(b1) + sess.add(s1) + sess.add(s2) + + sess.flush() + + sess2 = create_session() + s3 = sess2.query(Base).get(s1.id) + sess2.delete(s3) + sess2.flush() + + s2.subdata = 'some new subdata' + sess.flush() + + try: + s1.subdata = 'some new subdata' + sess.flush() + assert False + except orm_exc.ConcurrentModificationError, e: + assert True + +class DistinctPKTest(_base.MappedTest): + """test the construction of mapper.primary_key when an inheriting relationship + joins on a column other than primary key column.""" + + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + global person_table, employee_table, Person, Employee + + person_table = Table("persons", metadata, + 
Column("id", Integer, primary_key=True), + Column("name", String(80)), + ) + + employee_table = Table("employees", metadata, + Column("id", Integer, primary_key=True), + Column("salary", Integer), + Column("person_id", Integer, ForeignKey("persons.id")), + ) + + class Person(object): + def __init__(self, name): + self.name = name + + class Employee(Person): pass + + @classmethod + def insert_data(cls): + person_insert = person_table.insert() + person_insert.execute(id=1, name='alice') + person_insert.execute(id=2, name='bob') + + employee_insert = employee_table.insert() + employee_insert.execute(id=2, salary=250, person_id=1) # alice + employee_insert.execute(id=3, salary=200, person_id=2) # bob + + def test_implicit(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper) + assert list(class_mapper(Employee).primary_key) == [person_table.c.id] + + def test_explicit_props(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id}) + self._do_test(True) + + def test_explicit_composite_pk(self): + person_mapper = mapper(Person, person_table) + try: + mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id]) + self._do_test(True) + assert False + except sa_exc.SAWarning, e: + assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. 
Use explicit properties to give each column its own mapped attribute name.", str(e) + + def test_explicit_pk(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id]) + self._do_test(False) + + def _do_test(self, composite): + session = create_session() + query = session.query(Employee) + + if composite: + alice1 = query.get([1,2]) + bob = query.get([2,3]) + alice2 = query.get([1,2]) + else: + alice1 = query.get(1) + bob = query.get(2) + alice2 = query.get(1) + + assert alice1.name == alice2.name == 'alice' + assert bob.name == 'bob' + +class SyncCompileTest(_base.MappedTest): + """test that syncrules compile properly on custom inherit conds""" + @classmethod + def define_tables(cls, metadata): + global _a_table, _b_table, _c_table + + _a_table = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('data1', String(128)) + ) + + _b_table = Table('b', metadata, + Column('a_id', Integer, ForeignKey('a.id'), primary_key=True), + Column('data2', String(128)) + ) + + _c_table = Table('c', metadata, + # Column('a_id', Integer, ForeignKey('b.a_id'), primary_key=True), #works + Column('b_a_id', Integer, ForeignKey('b.a_id'), primary_key=True), + Column('data3', String(128)) + ) + + def test_joins(self): + for j1 in (None, _b_table.c.a_id==_a_table.c.id, _a_table.c.id==_b_table.c.a_id): + for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id, _c_table.c.b_a_id==_b_table.c.a_id): + self._do_test(j1, j2) + for t in reversed(_a_table.metadata.sorted_tables): + t.delete().execute().close() + + def _do_test(self, j1, j2): + class A(object): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + class B(A): + pass + + class C(B): + pass + + mapper(A, _a_table) + mapper(B, _b_table, inherits=A, + inherit_condition=j1 + ) + mapper(C, _c_table, inherits=B, + inherit_condition=j2 + ) + + session = create_session() + + a = A(data1='a1') 
+ session.add(a) + + b = B(data1='b1', data2='b2') + session.add(b) + + c = C(data1='c1', data2='c2', data3='c3') + session.add(c) + + session.flush() + session.expunge_all() + + assert len(session.query(A).all()) == 3 + assert len(session.query(B).all()) == 2 + assert len(session.query(C).all()) == 1 + +class OverrideColKeyTest(_base.MappedTest): + """test overriding of column attributes.""" + + @classmethod + def define_tables(cls, metadata): + global base, subtable + + base = Table('base', metadata, + Column('base_id', Integer, primary_key=True), + Column('data', String(255)), + Column('sqlite_fixer', String(10)) + ) + + subtable = Table('subtable', metadata, + Column('base_id', Integer, ForeignKey('base.base_id'), primary_key=True), + Column('subdata', String(255)) + ) + + def test_plain(self): + # control case + class Base(object): + pass + class Sub(Base): + pass + + mapper(Base, base) + mapper(Sub, subtable, inherits=Base) + + # Sub gets a "base_id" property using the "base_id" + # column of both tables. + eq_( + class_mapper(Sub).get_property('base_id').columns, + [base.c.base_id, subtable.c.base_id] + ) + + def test_override_explicit(self): + # this pattern is what you see when using declarative + # in particular, here we do a "manual" version of + # what we'd like the mapper to do. 
+ + class Base(object): + pass + class Sub(Base): + pass + + mapper(Base, base, properties={ + 'id':base.c.base_id + }) + mapper(Sub, subtable, inherits=Base, properties={ + # this is the manual way to do it, is not really + # possible in declarative + 'id':[base.c.base_id, subtable.c.base_id] + }) + + eq_( + class_mapper(Sub).get_property('id').columns, + [base.c.base_id, subtable.c.base_id] + ) + + s1 = Sub() + s1.id = 10 + sess = create_session() + sess.add(s1) + sess.flush() + assert sess.query(Sub).get(10) is s1 + + def test_override_onlyinparent(self): + class Base(object): + pass + class Sub(Base): + pass + + mapper(Base, base, properties={ + 'id':base.c.base_id + }) + mapper(Sub, subtable, inherits=Base) + + eq_( + class_mapper(Sub).get_property('id').columns, + [base.c.base_id] + ) + + eq_( + class_mapper(Sub).get_property('base_id').columns, + [subtable.c.base_id] + ) + + s1 = Sub() + s1.id = 10 + + s2 = Sub() + s2.base_id = 15 + + sess = create_session() + sess.add_all([s1, s2]) + sess.flush() + + # s1 gets '10' + assert sess.query(Sub).get(10) is s1 + + # s2 gets a new id, base_id is overwritten by the ultimate + # PK col + assert s2.id == s2.base_id != 15 + + def test_override_implicit(self): + # this is how the pattern looks intuitively when + # using declarative. + # fixed as part of [ticket:1111] + + class Base(object): + pass + class Sub(Base): + pass + + mapper(Base, base, properties={ + 'id':base.c.base_id + }) + mapper(Sub, subtable, inherits=Base, properties={ + 'id':subtable.c.base_id + }) + + # Sub mapper compilation needs to detect that "base.c.base_id" + # is renamed in the inherited mapper as "id", even though + # it has its own "id" property. Sub's "id" property + # gets joined normally with the extra column. 
+ + eq_( + class_mapper(Sub).get_property('id').columns, + [base.c.base_id, subtable.c.base_id] + ) + + s1 = Sub() + s1.id = 10 + sess = create_session() + sess.add(s1) + sess.flush() + assert sess.query(Sub).get(10) is s1 + + def test_plain_descriptor(self): + """test that descriptors prevent inheritance from propigating properties to subclasses.""" + + class Base(object): + pass + class Sub(Base): + @property + def data(self): + return "im the data" + + mapper(Base, base) + mapper(Sub, subtable, inherits=Base) + + s1 = Sub() + sess = create_session() + sess.add(s1) + sess.flush() + assert sess.query(Sub).one().data == "im the data" + + def test_custom_descriptor(self): + """test that descriptors prevent inheritance from propigating properties to subclasses.""" + + class MyDesc(object): + def __get__(self, instance, owner): + if instance is None: + return self + return "im the data" + + class Base(object): + pass + class Sub(Base): + data = MyDesc() + + mapper(Base, base) + mapper(Sub, subtable, inherits=Base) + + s1 = Sub() + sess = create_session() + sess.add(s1) + sess.flush() + assert sess.query(Sub).one().data == "im the data" + + def test_sub_columns_over_base_descriptors(self): + class Base(object): + @property + def subdata(self): + return "this is base" + + class Sub(Base): + pass + + mapper(Base, base) + mapper(Sub, subtable, inherits=Base) + + sess = create_session() + b1 = Base() + assert b1.subdata == "this is base" + s1 = Sub() + s1.subdata = "this is sub" + assert s1.subdata == "this is sub" + + sess.add_all([s1, b1]) + sess.flush() + sess.expunge_all() + + assert sess.query(Base).get(b1.base_id).subdata == "this is base" + assert sess.query(Sub).get(s1.base_id).subdata == "this is sub" + + def test_base_descriptors_over_base_cols(self): + class Base(object): + @property + def data(self): + return "this is base" + + class Sub(Base): + pass + + mapper(Base, base) + mapper(Sub, subtable, inherits=Base) + + sess = create_session() + b1 = Base() + 
assert b1.data == "this is base" + s1 = Sub() + assert s1.data == "this is base" + + sess.add_all([s1, b1]) + sess.flush() + sess.expunge_all() + + assert sess.query(Base).get(b1.base_id).data == "this is base" + assert sess.query(Sub).get(s1.base_id).data == "this is base" + +class OptimizedLoadTest(_base.MappedTest): + """test that the 'optimized load' routine doesn't crash when + a column in the join condition is not available. + + """ + @classmethod + def define_tables(cls, metadata): + global base, sub + base = Table('base', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('type', String(50)) + ) + sub = Table('sub', metadata, + Column('id', Integer, ForeignKey('base.id'), primary_key=True), + Column('sub', String(50)) + ) + + def test_optimized_passes(self): + class Base(object): + pass + class Sub(Base): + pass + + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') + + # redefine Sub's "id" to favor the "id" col in the subtable. + # "id" is also part of the primary join condition + mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id}) + sess = create_session() + s1 = Sub() + s1.data = 's1data' + s1.sub = 's1sub' + sess.add(s1) + sess.flush() + sess.expunge_all() + + # load s1 via Base. s1.id won't populate since it's relative to + # the "sub" table. The optimized load kicks in and tries to + # generate on the primary join, but cannot since "id" is itself unloaded. 
+ # the optimized load needs to return "None" so regular full-row loading proceeds + s1 = sess.query(Base).get(s1.id) + assert s1.sub == 's1sub' + +class PKDiscriminatorTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + parents = Table('parents', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(60))) + + children = Table('children', metadata, + Column('id', Integer, ForeignKey('parents.id'), primary_key=True), + Column('type', Integer,primary_key=True), + Column('name', String(60))) + + @testing.resolve_artifact_names + def test_pk_as_discriminator(self): + class Parent(object): + def __init__(self, name=None): + self.name = name + + class Child(object): + def __init__(self, name=None): + self.name = name + + class A(Child): + pass + + mapper(Parent, parents, properties={ + 'children': relation(Child, backref='parent'), + }) + mapper(Child, children, polymorphic_on=children.c.type, + polymorphic_identity=1) + + mapper(A, inherits=Child, polymorphic_identity=2) + + s = create_session() + p = Parent('p1') + a = A('a1') + p.children.append(a) + s.add(p) + s.flush() + + assert a.id + assert a.type == 2 + + +class DeleteOrphanTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global single, parent + single = Table('single', metadata, + Column('id', Integer, primary_key=True), + Column('type', String(50), nullable=False), + Column('data', String(50)), + Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False), + ) + + parent = Table('parent', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + def test_orphan_message(self): + class Base(_fixtures.Base): + pass + + class SubClass(Base): + pass + + class Parent(_fixtures.Base): + pass + + mapper(Base, single, polymorphic_on=single.c.type, polymorphic_identity='base') + mapper(SubClass, inherits=Base, polymorphic_identity='sub') + mapper(Parent, parent, properties={ + 'related':relation(Base, 
cascade="all, delete-orphan") + }) + + sess = create_session() + s1 = SubClass(data='s1') + sess.add(s1) + assert_raises_message(orm_exc.FlushError, + "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush) + + diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py new file mode 100644 index 000000000..4a884cb86 --- /dev/null +++ b/test/orm/inheritance/test_concrete.py @@ -0,0 +1,519 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.test import * +import sqlalchemy as sa +from sqlalchemy.test import testing +from test.orm import _base +from sqlalchemy.orm import attributes +from sqlalchemy.test.testing import eq_ + +class Employee(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return self.__class__.__name__ + " " + self.name + +class Manager(Employee): + def __init__(self, name, manager_data): + self.name = name + self.manager_data = manager_data + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.manager_data + +class Engineer(Employee): + def __init__(self, name, engineer_info): + self.name = name + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " " + self.engineer_info + +class Hacker(Engineer): + def __init__(self, name, nickname, engineer_info): + self.name = name + self.nickname = nickname + self.engineer_info = engineer_info + def __repr__(self): + return self.__class__.__name__ + " " + self.name + " '" + \ + self.nickname + "' " + self.engineer_info + +class Company(object): + pass + + +class ConcreteTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global managers_table, engineers_table, hackers_table, companies, employees_table + + companies = Table('companies', metadata, + 
Column('id', Integer, primary_key=True), + Column('name', String(50))) + + employees_table = Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) + ) + + managers_table = Table('managers', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('manager_data', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) + ) + + engineers_table = Table('engineers', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('engineer_info', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')) + ) + + hackers_table = Table('hackers', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('engineer_info', String(50)), + Column('company_id', Integer, ForeignKey('companies.id')), + Column('nickname', String(50)) + ) + + + + def test_basic(self): + pjoin = polymorphic_union({ + 'manager':managers_table, + 'engineer':engineers_table + }, 'type', 'pjoin') + + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, + concrete=True, polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, + concrete=True, polymorphic_identity='engineer') + + session = create_session() + session.add(Manager('Tom', 'knows how to manage things')) + session.add(Engineer('Kurt', 'knows how to hack')) + session.flush() + session.expunge_all() + + assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Kurt knows 
how to hack"]) + + manager = session.query(Manager).one() + session.expire(manager, ['manager_data']) + eq_(manager.manager_data, "knows how to manage things") + + def test_multi_level_no_base(self): + pjoin = polymorphic_union({ + 'manager': managers_table, + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin') + + pjoin2 = polymorphic_union({ + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin2') + + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, + with_polymorphic=('*', pjoin2), + polymorphic_on=pjoin2.c.type, + inherits=employee_mapper, concrete=True, + polymorphic_identity='engineer') + hacker_mapper = mapper(Hacker, hackers_table, + inherits=engineer_mapper, + concrete=True, polymorphic_identity='hacker') + + session = create_session() + tom = Manager('Tom', 'knows how to manage things') + jerry = Engineer('Jerry', 'knows how to program') + hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + session.add_all((tom, jerry, hacker)) + session.flush() + + # ensure "readonly" on save logic didn't pollute the expired_attributes + # collection + assert 'nickname' not in attributes.instance_state(jerry).expired_attributes + assert 'name' not in attributes.instance_state(jerry).expired_attributes + assert 'name' not in attributes.instance_state(hacker).expired_attributes + assert 'nickname' not in attributes.instance_state(hacker).expired_attributes + def go(): + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") + self.assert_sql_count(testing.db, go, 0) + + session.expunge_all() + + assert repr(session.query(Employee).filter(Employee.name=='Tom').one()) == "Manager Tom knows how to manage things" + assert repr(session.query(Manager).filter(Manager.name=='Tom').one()) == "Manager Tom knows how to manage 
things" + + + assert set([repr(x) for x in session.query(Employee).all()]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Hacker).all()]) == set(["Hacker Kurt 'Badass' knows how to hack"]) + + def test_multi_level_with_base(self): + pjoin = polymorphic_union({ + 'employee':employees_table, + 'manager': managers_table, + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin') + + pjoin2 = polymorphic_union({ + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin2') + + employee_mapper = mapper(Employee, employees_table, + with_polymorphic=('*', pjoin), polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, + with_polymorphic=('*', pjoin2), + polymorphic_on=pjoin2.c.type, + inherits=employee_mapper, concrete=True, + polymorphic_identity='engineer') + hacker_mapper = mapper(Hacker, hackers_table, + inherits=engineer_mapper, + concrete=True, polymorphic_identity='hacker') + + session = create_session() + tom = Manager('Tom', 'knows how to manage things') + jerry = Engineer('Jerry', 'knows how to program') + hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + session.add_all((tom, jerry, hacker)) + session.flush() + + def go(): + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") + self.assert_sql_count(testing.db, go, 0) + + session.expunge_all() + + # check that we aren't getting a cartesian product in the raw SQL. 
+ # this requires that Engineer's polymorphic discriminator is not rendered + # in the statement which is only against Employee's "pjoin" + assert len(testing.db.execute(session.query(Employee).with_labels().statement).fetchall()) == 3 + + assert set([repr(x) for x in session.query(Employee)]) == set(["Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Manager)]) == set(["Manager Tom knows how to manage things"]) + assert set([repr(x) for x in session.query(Engineer)]) == set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + assert set([repr(x) for x in session.query(Hacker)]) == set(["Hacker Kurt 'Badass' knows how to hack"]) + + + def test_without_default_polymorphic(self): + pjoin = polymorphic_union({ + 'employee':employees_table, + 'manager': managers_table, + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin') + + pjoin2 = polymorphic_union({ + 'engineer': engineers_table, + 'hacker': hackers_table + }, 'type', 'pjoin2') + + employee_mapper = mapper(Employee, employees_table, + polymorphic_identity='employee') + manager_mapper = mapper(Manager, managers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, + inherits=employee_mapper, concrete=True, + polymorphic_identity='engineer') + hacker_mapper = mapper(Hacker, hackers_table, + inherits=engineer_mapper, + concrete=True, polymorphic_identity='hacker') + + session = create_session() + jdoe = Employee('Jdoe') + tom = Manager('Tom', 'knows how to manage things') + jerry = Engineer('Jerry', 'knows how to program') + hacker = Hacker('Kurt', 'Badass', 'knows how to hack') + session.add_all((jdoe, tom, jerry, hacker)) + session.flush() + + eq_( + len(testing.db.execute(session.query(Employee).with_polymorphic('*', pjoin, 
pjoin.c.type).with_labels().statement).fetchall()), + 4 + ) + + eq_( + session.query(Employee).get(jdoe.employee_id), jdoe + ) + eq_( + session.query(Engineer).get(jerry.employee_id), jerry + ) + eq_( + set([repr(x) for x in session.query(Employee).with_polymorphic('*', pjoin, pjoin.c.type)]), + set(["Employee Jdoe", "Engineer Jerry knows how to program", "Manager Tom knows how to manage things", "Hacker Kurt 'Badass' knows how to hack"]) + ) + eq_( + set([repr(x) for x in session.query(Manager)]), + set(["Manager Tom knows how to manage things"]) + ) + eq_( + set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type)]), + set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + ) + eq_( + set([repr(x) for x in session.query(Hacker)]), + set(["Hacker Kurt 'Badass' knows how to hack"]) + ) + # test adaption of the column by wrapping the query in a subquery + eq_( + len(testing.db.execute( + session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self().statement + ).fetchall()), + 2 + ) + eq_( + set([repr(x) for x in session.query(Engineer).with_polymorphic('*', pjoin2, pjoin2.c.type).from_self()]), + set(["Engineer Jerry knows how to program", "Hacker Kurt 'Badass' knows how to hack"]) + ) + + def test_relation(self): + pjoin = polymorphic_union({ + 'manager':managers_table, + 'engineer':engineers_table + }, 'type', 'pjoin') + + mapper(Company, companies, properties={ + 'employees':relation(Employee) + }) + employee_mapper = mapper(Employee, pjoin, polymorphic_on=pjoin.c.type) + manager_mapper = mapper(Manager, managers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='manager') + engineer_mapper = mapper(Engineer, engineers_table, inherits=employee_mapper, concrete=True, polymorphic_identity='engineer') + + session = create_session() + c = Company() + c.employees.append(Manager('Tom', 'knows how to manage things')) + c.employees.append(Engineer('Kurt', 'knows how 
to hack')) + session.add(c) + session.flush() + session.expunge_all() + + def go(): + c2 = session.query(Company).get(c.id) + assert set([repr(x) for x in c2.employees]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) + self.assert_sql_count(testing.db, go, 2) + session.expunge_all() + def go(): + c2 = session.query(Company).options(eagerload(Company.employees)).get(c.id) + assert set([repr(x) for x in c2.employees]) == set(["Engineer Kurt knows how to hack", "Manager Tom knows how to manage things"]) + self.assert_sql_count(testing.db, go, 1) + +class PropertyInheritanceTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('a_table', metadata, + Column('id', Integer, primary_key=True), + Column('some_c_id', Integer, ForeignKey('c_table.id')), + Column('aname', String(50)), + ) + Table('b_table', metadata, + Column('id', Integer, primary_key=True), + Column('some_c_id', Integer, ForeignKey('c_table.id')), + Column('bname', String(50)), + ) + Table('c_table', metadata, + Column('id', Integer, primary_key=True), + Column('cname', String(50)), + + ) + + @classmethod + def setup_classes(cls): + class A(_base.ComparableEntity): + pass + + class B(A): + pass + + class C(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_noninherited_warning(self): + mapper(A, a_table, properties={ + 'some_c':relation(C) + }) + mapper(B, b_table,inherits=A, concrete=True) + mapper(C, c_table) + + b = B() + c = C() + assert_raises(AttributeError, setattr, b, 'some_c', c) + + clear_mappers() + mapper(A, a_table, properties={ + 'a_id':a_table.c.id + }) + mapper(B, b_table,inherits=A, concrete=True) + mapper(C, c_table) + b = B() + assert_raises(AttributeError, setattr, b, 'a_id', 3) + + clear_mappers() + mapper(A, a_table, properties={ + 'a_id':a_table.c.id + }) + mapper(B, b_table,inherits=A, concrete=True) + mapper(C, c_table) + + @testing.resolve_artifact_names + def test_inheriting(self): + 
mapper(A, a_table, properties={ + 'some_c':relation(C, back_populates='many_a') + }) + mapper(B, b_table,inherits=A, concrete=True, properties={ + 'some_c':relation(C, back_populates='many_b') + }) + mapper(C, c_table, properties={ + 'many_a':relation(A, back_populates='some_c'), + 'many_b':relation(B, back_populates='some_c'), + }) + + sess = sessionmaker()() + + c1 = C(cname='c1') + c2 = C(cname='c2') + a1 = A(some_c=c1, aname='a1') + a2 = A(some_c=c2, aname='a2') + b1 = B(some_c=c1, bname='b1') + b2 = B(some_c=c1, bname='b2') + + assert_raises(AttributeError, setattr, b1, 'aname', 'foo') + assert_raises(AttributeError, getattr, A, 'bname') + + assert c2.many_a == [a2] + assert c1.many_a == [a1] + assert c1.many_b == [b1, b2] + + sess.add_all([c1, c2]) + sess.commit() + + assert sess.query(C).filter(C.many_a.contains(a2)).one() is c2 + assert c2.many_a == [a2] + assert c1.many_a == [a1] + assert c1.many_b == [b1, b2] + + assert sess.query(B).filter(B.bname=='b1').one() is b1 + + @testing.resolve_artifact_names + def test_polymorphic_backref(self): + """test multiple backrefs to the same polymorphically-loading attribute.""" + + ajoin = polymorphic_union( + {'a':a_table, + 'b':b_table + }, 'type', 'ajoin' + ) + mapper(A, a_table, with_polymorphic=('*', ajoin), + polymorphic_on=ajoin.c.type, polymorphic_identity='a', + properties={ + 'some_c':relation(C, back_populates='many_a') + }) + mapper(B, b_table,inherits=A, concrete=True, + polymorphic_identity='b', + properties={ + 'some_c':relation(C, back_populates='many_a') + }) + mapper(C, c_table, properties={ + 'many_a':relation(A, back_populates='some_c', order_by=ajoin.c.id), + }) + + sess = sessionmaker()() + + c1 = C(cname='c1') + c2 = C(cname='c2') + a1 = A(some_c=c1, aname='a1', id=1) + a2 = A(some_c=c2, aname='a2', id=2) + b1 = B(some_c=c1, bname='b1', id=3) + b2 = B(some_c=c1, bname='b2', id=4) + + eq_([a2], c2.many_a) + eq_([a1, b1, b2], c1.many_a) + + sess.add_all([c1, c2]) + sess.commit() + + assert 
sess.query(C).filter(C.many_a.contains(a2)).one() is c2 + assert sess.query(C).filter(C.many_a.contains(b1)).one() is c1 + eq_(c2.many_a, [a2]) + eq_(c1.many_a, [a1, b1, b2]) + + sess.expire_all() + + def go(): + eq_( + [C(many_a=[A(aname='a1'), B(bname='b1'), B(bname='b2')]), C(many_a=[A(aname='a2')])], + sess.query(C).options(eagerload(C.many_a)).order_by(C.id).all(), + ) + self.assert_sql_count(testing.db, go, 1) + + +class ColKeysTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global offices_table, refugees_table + refugees_table = Table('refugee', metadata, + Column('refugee_fid', Integer, primary_key=True), + Column('refugee_name', Unicode(30), key='name')) + + offices_table = Table('office', metadata, + Column('office_fid', Integer, primary_key=True), + Column('office_name', Unicode(30), key='name')) + + @classmethod + def insert_data(cls): + refugees_table.insert().execute( + dict(refugee_fid=1, name=u"refugee1"), + dict(refugee_fid=2, name=u"refugee2") + ) + offices_table.insert().execute( + dict(office_fid=1, name=u"office1"), + dict(office_fid=2, name=u"office2") + ) + + def test_keys(self): + pjoin = polymorphic_union({ + 'refugee': refugees_table, + 'office': offices_table + }, 'type', 'pjoin') + class Location(object): + pass + + class Refugee(Location): + pass + + class Office(Location): + pass + + location_mapper = mapper(Location, pjoin, polymorphic_on=pjoin.c.type, + polymorphic_identity='location') + office_mapper = mapper(Office, offices_table, inherits=location_mapper, + concrete=True, polymorphic_identity='office') + refugee_mapper = mapper(Refugee, refugees_table, inherits=location_mapper, + concrete=True, polymorphic_identity='refugee') + + sess = create_session() + eq_(sess.query(Refugee).get(1).name, "refugee1") + eq_(sess.query(Refugee).get(2).name, "refugee2") + + eq_(sess.query(Office).get(1).name, "office1") + eq_(sess.query(Office).get(2).name, "office2") + diff --git 
a/test/orm/inheritance/test_magazine.py b/test/orm/inheritance/test_magazine.py new file mode 100644 index 000000000..067301251 --- /dev/null +++ b/test/orm/inheritance/test_magazine.py @@ -0,0 +1,219 @@ +from sqlalchemy import * +from sqlalchemy.orm import * + +from sqlalchemy.test import testing +from sqlalchemy.util import function_named +from test.orm import _base + +class BaseObject(object): + def __init__(self, *args, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) +class Publication(BaseObject): + pass + +class Issue(BaseObject): + pass + +class Location(BaseObject): + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, str(getattr(self, 'issue_id', None)), repr(str(self._name.name))) + + def _get_name(self): + return self._name + + def _set_name(self, name): + session = create_session() + s = session.query(LocationName).filter(LocationName.name==name).first() + session.expunge_all() + if s is not None: + self._name = s + + return + + found = False + + for i in session.new: + if isinstance(i, LocationName) and i.name == name: + self._name = i + found = True + + break + + if found == False: + self._name = LocationName(name=name) + + name = property(_get_name, _set_name) + +class LocationName(BaseObject): + def __repr__(self): + return "%s()" % (self.__class__.__name__) + +class PageSize(BaseObject): + def __repr__(self): + return "%s(%sx%s, %s)" % (self.__class__.__name__, self.width, self.height, self.name) + +class Magazine(BaseObject): + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, repr(self.location), repr(self.size)) + +class Page(BaseObject): + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, str(self.page_no)) + +class MagazinePage(Page): + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, str(self.page_no), repr(self.magazine)) + +class ClassifiedPage(MagazinePage): + pass + + +class MagazineTest(_base.MappedTest): + @classmethod + def 
define_tables(cls, metadata): + global publication_table, issue_table, location_table, location_name_table, magazine_table, \ + page_table, magazine_page_table, classified_page_table, page_size_table + + zerodefault = {} #{'default':0} + publication_table = Table('publication', metadata, + Column('id', Integer, primary_key=True, default=None), + Column('name', String(45), default=''), + ) + issue_table = Table('issue', metadata, + Column('id', Integer, primary_key=True, default=None), + Column('publication_id', Integer, ForeignKey('publication.id'), **zerodefault), + Column('issue', Integer, **zerodefault), + ) + location_table = Table('location', metadata, + Column('id', Integer, primary_key=True, default=None), + Column('issue_id', Integer, ForeignKey('issue.id'), **zerodefault), + Column('ref', CHAR(3), default=''), + Column('location_name_id', Integer, ForeignKey('location_name.id'), **zerodefault), + ) + location_name_table = Table('location_name', metadata, + Column('id', Integer, primary_key=True, default=None), + Column('name', String(45), default=''), + ) + magazine_table = Table('magazine', metadata, + Column('id', Integer, primary_key=True, default=None), + Column('location_id', Integer, ForeignKey('location.id'), **zerodefault), + Column('page_size_id', Integer, ForeignKey('page_size.id'), **zerodefault), + ) + page_table = Table('page', metadata, + Column('id', Integer, primary_key=True, default=None), + Column('page_no', Integer, **zerodefault), + Column('type', CHAR(1), default='p'), + ) + magazine_page_table = Table('magazine_page', metadata, + Column('page_id', Integer, ForeignKey('page.id'), primary_key=True, **zerodefault), + Column('magazine_id', Integer, ForeignKey('magazine.id'), **zerodefault), + Column('orders', TEXT, default=''), + ) + classified_page_table = Table('classified_page', metadata, + Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True, **zerodefault), + Column('titles', String(45), 
default=''), + ) + page_size_table = Table('page_size', metadata, + Column('id', Integer, primary_key=True, default=None), + Column('width', Integer, **zerodefault), + Column('height', Integer, **zerodefault), + Column('name', String(45), default=''), + ) + +def generate_round_trip_test(use_unions=False, use_joins=False): + def test_roundtrip(self): + publication_mapper = mapper(Publication, publication_table) + + issue_mapper = mapper(Issue, issue_table, properties = { + 'publication': relation(Publication, backref=backref('issues', cascade="all, delete-orphan")), + }) + + location_name_mapper = mapper(LocationName, location_name_table) + + location_mapper = mapper(Location, location_table, properties = { + 'issue': relation(Issue, backref=backref('locations', lazy=False, cascade="all, delete-orphan")), + '_name': relation(LocationName), + }) + + page_size_mapper = mapper(PageSize, page_size_table) + + magazine_mapper = mapper(Magazine, magazine_table, properties = { + 'location': relation(Location, backref=backref('magazine', uselist=False)), + 'size': relation(PageSize), + }) + + if use_unions: + page_join = polymorphic_union( + { + 'm': page_table.join(magazine_page_table), + 'c': page_table.join(magazine_page_table).join(classified_page_table), + 'p': page_table.select(page_table.c.type=='p'), + }, None, 'page_join') + page_mapper = mapper(Page, page_table, with_polymorphic=('*', page_join), polymorphic_on=page_join.c.type, polymorphic_identity='p') + elif use_joins: + page_join = page_table.outerjoin(magazine_page_table).outerjoin(classified_page_table) + page_mapper = mapper(Page, page_table, with_polymorphic=('*', page_join), polymorphic_on=page_table.c.type, polymorphic_identity='p') + else: + page_mapper = mapper(Page, page_table, polymorphic_on=page_table.c.type, polymorphic_identity='p') + + if use_unions: + magazine_join = polymorphic_union( + { + 'm': page_table.join(magazine_page_table), + 'c': 
page_table.join(magazine_page_table).join(classified_page_table), + }, None, 'page_join') + magazine_page_mapper = mapper(MagazinePage, magazine_page_table, with_polymorphic=('*', magazine_join), inherits=page_mapper, polymorphic_identity='m', properties={ + 'magazine': relation(Magazine, backref=backref('pages', order_by=magazine_join.c.page_no)) + }) + elif use_joins: + magazine_join = page_table.join(magazine_page_table).outerjoin(classified_page_table) + magazine_page_mapper = mapper(MagazinePage, magazine_page_table, with_polymorphic=('*', magazine_join), inherits=page_mapper, polymorphic_identity='m', properties={ + 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no)) + }) + else: + magazine_page_mapper = mapper(MagazinePage, magazine_page_table, inherits=page_mapper, polymorphic_identity='m', properties={ + 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no)) + }) + + classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id]) + #compile_mappers() + #print [str(s) for s in classified_page_mapper.primary_key] + #print classified_page_mapper.columntoproperty[page_table.c.id] + + + session = create_session() + + pub = Publication(name='Test') + issue = Issue(issue=46,publication=pub) + + location = Location(ref='ABC',name='London',issue=issue) + + page_size = PageSize(name='A4',width=210,height=297) + + magazine = Magazine(location=location,size=page_size) + page = ClassifiedPage(magazine=magazine,page_no=1) + page2 = MagazinePage(magazine=magazine,page_no=2) + page3 = ClassifiedPage(magazine=magazine,page_no=3) + session.add(pub) + + session.flush() + print [x for x in session] + session.expunge_all() + + session.flush() + session.expunge_all() + p = session.query(Publication).filter(Publication.name=="Test").one() + + print p.issues[0].locations[0].magazine.pages + print [page, page2, page3] + 
assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages) + + test_roundtrip = function_named( + test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions")) + setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip) + +for (use_union, use_join) in [(True, False), (False, True), (False, False)]: + generate_round_trip_test(use_union, use_join) + + diff --git a/test/orm/inheritance/test_manytomany.py b/test/orm/inheritance/test_manytomany.py new file mode 100644 index 000000000..f7e676bbb --- /dev/null +++ b/test/orm/inheritance/test_manytomany.py @@ -0,0 +1,249 @@ +from sqlalchemy.test.testing import eq_ +from sqlalchemy import * +from sqlalchemy.orm import * + +from sqlalchemy.test import testing +from test.orm import _base + + +class InheritTest(_base.MappedTest): + """deals with inheritance and many-to-many relationships""" + @classmethod + def define_tables(cls, metadata): + global principals + global users + global groups + global user_group_map + + principals = Table('principals', metadata, + Column('principal_id', Integer, + Sequence('principal_id_seq', optional=False), + primary_key=True), + Column('name', String(50), nullable=False)) + + users = Table('prin_users', metadata, + Column('principal_id', Integer, + ForeignKey('principals.principal_id'), primary_key=True), + Column('password', String(50), nullable=False), + Column('email', String(50), nullable=False), + Column('login_id', String(50), nullable=False)) + + groups = Table('prin_groups', metadata, + Column('principal_id', Integer, + ForeignKey('principals.principal_id'), primary_key=True)) + + user_group_map = Table('prin_user_group_map', metadata, + Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), + primary_key=True ), + Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), + primary_key=True ), + ) + + def testbasic(self): + class Principal(object): + 
def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + + class User(Principal): + pass + + class Group(Principal): + pass + + mapper(Principal, principals) + mapper(User, users, inherits=Principal) + + mapper(Group, groups, inherits=Principal, properties={ + 'users': relation(User, secondary=user_group_map, + lazy=True, backref="groups") + }) + + g = Group(name="group1") + g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1")) + sess = create_session() + sess.add(g) + sess.flush() + # TODO: put an assertion + +class InheritTest2(_base.MappedTest): + """deals with inheritance and many-to-many relationships""" + @classmethod + def define_tables(cls, metadata): + global foo, bar, foo_bar + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_id_seq', optional=True), + primary_key=True), + Column('data', String(20)), + ) + + bar = Table('bar', metadata, + Column('bid', Integer, ForeignKey('foo.id'), primary_key=True), + #Column('fid', Integer, ForeignKey('foo.id'), ) + ) + + foo_bar = Table('foo_bar', metadata, + Column('foo_id', Integer, ForeignKey('foo.id')), + Column('bar_id', Integer, ForeignKey('bar.bid'))) + + def testget(self): + class Foo(object): + def __init__(self, data=None): + self.data = data + class Bar(Foo):pass + + mapper(Foo, foo) + mapper(Bar, bar, inherits=Foo) + print foo.join(bar).primary_key + print class_mapper(Bar).primary_key + b = Bar('somedata') + sess = create_session() + sess.add(b) + sess.flush() + sess.expunge_all() + + # test that "bar.bid" does not need to be referenced in a get + # (ticket 185) + assert sess.query(Bar).get(b.id).id == b.id + + def testbasic(self): + class Foo(object): + def __init__(self, data=None): + self.data = data + + mapper(Foo, foo) + class Bar(Foo): + pass + + mapper(Bar, bar, inherits=Foo, properties={ + 'foos': relation(Foo, secondary=foo_bar, lazy=False) + }) + + sess = create_session() + b = Bar('barfoo') + 
sess.add(b) + sess.flush() + + f1 = Foo('subfoo1') + f2 = Foo('subfoo2') + b.foos.append(f1) + b.foos.append(f2) + + sess.flush() + sess.expunge_all() + + l = sess.query(Bar).all() + print l[0] + print l[0].foos + self.assert_unordered_result(l, Bar, +# {'id':1, 'data':'barfoo', 'bid':1, 'foos':(Foo, [{'id':2,'data':'subfoo1'}, {'id':3,'data':'subfoo2'}])}, + {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])}, + ) + +class InheritTest3(_base.MappedTest): + """deals with inheritance and many-to-many relationships""" + @classmethod + def define_tables(cls, metadata): + global foo, bar, blub, bar_foo, blub_bar, blub_foo + + # the 'data' columns are to appease SQLite which cant handle a blank INSERT + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), + Column('data', String(20))) + + bar = Table('bar', metadata, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(20))) + + blub = Table('blub', metadata, + Column('id', Integer, ForeignKey('bar.id'), primary_key=True), + Column('data', String(20))) + + bar_foo = Table('bar_foo', metadata, + Column('bar_id', Integer, ForeignKey('bar.id')), + Column('foo_id', Integer, ForeignKey('foo.id'))) + + blub_bar = Table('bar_blub', metadata, + Column('blub_id', Integer, ForeignKey('blub.id')), + Column('bar_id', Integer, ForeignKey('bar.id'))) + + blub_foo = Table('blub_foo', metadata, + Column('blub_id', Integer, ForeignKey('blub.id')), + Column('foo_id', Integer, ForeignKey('foo.id'))) + + def testbasic(self): + class Foo(object): + def __init__(self, data=None): + self.data = data + def __repr__(self): + return "Foo id %d, data %s" % (self.id, self.data) + mapper(Foo, foo) + + class Bar(Foo): + def __repr__(self): + return "Bar id %d, data %s" % (self.id, self.data) + + mapper(Bar, bar, inherits=Foo, properties={ + 'foos' :relation(Foo, secondary=bar_foo, lazy=True) + }) + + 
sess = create_session() + b = Bar('bar #1') + sess.add(b) + b.foos.append(Foo("foo #1")) + b.foos.append(Foo("foo #2")) + sess.flush() + compare = repr(b) + repr(sorted([repr(o) for o in b.foos])) + sess.expunge_all() + l = sess.query(Bar).all() + print repr(l[0]) + repr(l[0].foos) + found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos])) + eq_(found, compare) + + @testing.fails_on('maxdb', 'FIXME: unknown') + def testadvanced(self): + class Foo(object): + def __init__(self, data=None): + self.data = data + def __repr__(self): + return "Foo id %d, data %s" % (self.id, self.data) + mapper(Foo, foo) + + class Bar(Foo): + def __repr__(self): + return "Bar id %d, data %s" % (self.id, self.data) + mapper(Bar, bar, inherits=Foo) + + class Blub(Bar): + def __repr__(self): + return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos])) + + mapper(Blub, blub, inherits=Bar, properties={ + 'bars':relation(Bar, secondary=blub_bar, lazy=False), + 'foos':relation(Foo, secondary=blub_foo, lazy=False), + }) + + sess = create_session() + f1 = Foo("foo #1") + b1 = Bar("bar #1") + b2 = Bar("bar #2") + bl1 = Blub("blub #1") + for o in (f1, b1, b2, bl1): + sess.add(o) + bl1.foos.append(f1) + bl1.bars.append(b2) + sess.flush() + compare = repr(bl1) + blubid = bl1.id + sess.expunge_all() + + l = sess.query(Blub).all() + print l + self.assert_(repr(l[0]) == compare) + sess.expunge_all() + x = sess.query(Blub).filter_by(id=blubid).one() + print x + self.assert_(repr(x) == compare) + + diff --git a/test/orm/inheritance/test_poly_linked_list.py b/test/orm/inheritance/test_poly_linked_list.py new file mode 100644 index 000000000..67b543f31 --- /dev/null +++ b/test/orm/inheritance/test_poly_linked_list.py @@ -0,0 +1,197 @@ +from sqlalchemy import * +from sqlalchemy.orm import * + +from test.orm import _base +from sqlalchemy.test import testing + + +class PolymorphicCircularTest(_base.MappedTest): + run_setup_mappers = 
'once' + + @classmethod + def define_tables(cls, metadata): + global Table1, Table1B, Table2, Table3, Data + table1 = Table('table1', metadata, + Column('id', Integer, primary_key=True), + Column('related_id', Integer, ForeignKey('table1.id'), nullable=True), + Column('type', String(30)), + Column('name', String(30)) + ) + + table2 = Table('table2', metadata, + Column('id', Integer, ForeignKey('table1.id'), primary_key=True), + ) + + table3 = Table('table3', metadata, + Column('id', Integer, ForeignKey('table1.id'), primary_key=True), + ) + + data = Table('data', metadata, + Column('id', Integer, primary_key=True), + Column('node_id', Integer, ForeignKey('table1.id')), + Column('data', String(30)) + ) + + #join = polymorphic_union( + # { + # 'table3' : table1.join(table3), + # 'table2' : table1.join(table2), + # 'table1' : table1.select(table1.c.type.in_(['table1', 'table1b'])), + # }, None, 'pjoin') + + join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin') + #join = None + + class Table1(object): + def __init__(self, name, data=None): + self.name = name + if data is not None: + self.data = data + def __repr__(self): + return "%s(%s, %s, %s)" % (self.__class__.__name__, self.id, repr(str(self.name)), repr(self.data)) + + class Table1B(Table1): + pass + + class Table2(Table1): + pass + + class Table3(Table1): + pass + + class Data(object): + def __init__(self, data): + self.data = data + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, self.id, repr(str(self.data))) + + try: + # this is how the mapping used to work. 
ensure that this raises an error now + table1_mapper = mapper(Table1, table1, + select_table=join, + polymorphic_on=table1.c.type, + polymorphic_identity='table1', + properties={ + 'next': relation(Table1, + backref=backref('prev', foreignkey=join.c.id, uselist=False), + uselist=False, primaryjoin=join.c.id==join.c.related_id), + 'data':relation(mapper(Data, data)) + }, + order_by=table1.c.id) + table1_mapper.compile() + assert False + except: + assert True + clear_mappers() + + # currently, the "eager" relationships degrade to lazy relationships + # due to the polymorphic load. + # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" + # exception now. since eager loading would never work for that relation anyway, its better that the user + # gets an exception instead of it silently not eager loading. + table1_mapper = mapper(Table1, table1, + #select_table=join, + polymorphic_on=table1.c.type, + polymorphic_identity='table1', + properties={ + 'next': relation(Table1, + backref=backref('prev', remote_side=table1.c.id, uselist=False), + uselist=False, primaryjoin=table1.c.id==table1.c.related_id), + 'data':relation(mapper(Data, data), lazy=False, order_by=data.c.id) + }, + order_by=table1.c.id + ) + + table1b_mapper = mapper(Table1B, inherits=table1_mapper, polymorphic_identity='table1b') + + table2_mapper = mapper(Table2, table2, + inherits=table1_mapper, + polymorphic_identity='table2') + + table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3') + + table1_mapper.compile() + assert table1_mapper.primary_key == [table1.c.id], table1_mapper.primary_key + + @testing.fails_on('maxdb', 'FIXME: unknown') + def testone(self): + self._testlist([Table1, Table2, Table1, Table2]) + + @testing.fails_on('maxdb', 'FIXME: unknown') + def testtwo(self): + self._testlist([Table3]) + + @testing.fails_on('maxdb', 'FIXME: unknown') + def testthree(self): + self._testlist([Table2, Table1, 
Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1]) + + @testing.fails_on('maxdb', 'FIXME: unknown') + def testfour(self): + self._testlist([ + Table2('t2', [Data('data1'), Data('data2')]), + Table1('t1', []), + Table3('t3', [Data('data3')]), + Table1B('t1b', [Data('data4'), Data('data5')]) + ]) + + def _testlist(self, classes): + sess = create_session( ) + + # create objects in a linked list + count = 1 + obj = None + for c in classes: + if isinstance(c, type): + newobj = c('item %d' % count) + count += 1 + else: + newobj = c + if obj is not None: + obj.next = newobj + else: + t = newobj + obj = newobj + + # save to DB + sess.add(t) + sess.flush() + + # string version of the saved list + assertlist = [] + node = t + while (node): + assertlist.append(node) + n = node.next + if n is not None: + assert n.prev is node + node = n + original = repr(assertlist) + + + # clear and query forwards + sess.expunge_all() + node = sess.query(Table1).filter(Table1.id==t.id).first() + assertlist = [] + while (node): + assertlist.append(node) + n = node.next + if n is not None: + assert n.prev is node + node = n + forwards = repr(assertlist) + + # clear and query backwards + sess.expunge_all() + node = sess.query(Table1).filter(Table1.id==obj.id).first() + assertlist = [] + while (node): + assertlist.insert(0, node) + n = node.prev + if n is not None: + assert n.next is node + node = n + backwards = repr(assertlist) + + # everything should match ! 
+ assert original == forwards == backwards + diff --git a/test/orm/inheritance/test_polymorph.py b/test/orm/inheritance/test_polymorph.py new file mode 100644 index 000000000..cd3b2d89e --- /dev/null +++ b/test/orm/inheritance/test_polymorph.py @@ -0,0 +1,304 @@ +"""tests basic polymorphic mapper loading/saving, minimal relations""" + +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.test import Column, testing +from sqlalchemy.util import function_named +from test.orm import _fixtures, _base + +class Person(_fixtures.Base): + pass +class Engineer(Person): + pass +class Manager(Person): + pass +class Boss(Manager): + pass +class Company(_fixtures.Base): + pass + +class PolymorphTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global companies, people, engineers, managers, boss + + companies = Table('companies', metadata, + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('name', String(50))) + + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('company_id', Integer, ForeignKey('companies.company_id')), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('engineer_name', String(50)), + Column('primary_language', String(50)), + ) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('manager_name', String(50)) + ) + + boss = Table('boss', metadata, + Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True), + Column('golf_swing', String(30)), + ) + + metadata.create_all() + 
+class InsertOrderTest(PolymorphTest): + def test_insert_order(self): + """test that classes of multiple types mix up mapper inserts + so that insert order of individual tables is maintained""" + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + 'person':people.select(people.c.type=='person'), + }, None, 'pjoin') + + person_mapper = mapper(Person, people, with_polymorphic=('*', person_join), polymorphic_on=person_join.c.type, polymorphic_identity='person') + + mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + mapper(Company, companies, properties={ + 'employees': relation(Person, + backref='company', + order_by=person_join.c.person_id) + }) + + session = create_session() + c = Company(name='company1') + c.employees.append(Manager(status='AAB', manager_name='manager1', name='pointy haired boss')) + c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', name='dilbert')) + c.employees.append(Person(status='HHH', name='joesmith')) + c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', name='wally')) + c.employees.append(Manager(status='ABA', manager_name='manager2', name='jsmith')) + session.add(c) + session.flush() + session.expunge_all() + eq_(session.query(Company).get(c.company_id), c) + +class RelationToSubclassTest(PolymorphTest): + def test_basic(self): + """test a relation to an inheriting mapper where the relation is to a subclass + but the join condition is expressed by the parent table. + + also test that backrefs work in this case. 
+ + this test touches upon a lot of the join/foreign key determination code in properties.py + and creates the need for properties.py to search for conditions individually within + the mapper's local table as well as the mapper's 'mapped' table, so that relations + requiring lots of specificity (like self-referential joins) as well as relations requiring + more generalization (like the example here) both come up with proper results.""" + + mapper(Person, people) + + mapper(Engineer, engineers, inherits=Person) + mapper(Manager, managers, inherits=Person) + + mapper(Company, companies, properties={ + 'managers': relation(Manager, backref="company") + }) + + sess = create_session() + + c = Company(name='company1') + c.managers.append(Manager(status='AAB', manager_name='manager1', name='pointy haired boss')) + sess.add(c) + sess.flush() + sess.expunge_all() + + eq_(sess.query(Company).filter_by(company_id=c.company_id).one(), c) + assert c.managers[0].company is c + +class RoundTripTest(PolymorphTest): + pass + +def _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic): + """generates a round trip test. + + include_base - whether or not to include the base 'person' type in the union. + lazy_relation - whether or not the Company relation to People is lazy or eager. 
+ redefine_colprop - if we redefine the 'name' column to be 'people_name' on the base Person class + use_literal_join - primary join condition is explicitly specified + """ + def test_roundtrip(self): + if with_polymorphic == 'unions': + if include_base: + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + 'person':people.select(people.c.type=='person'), + }, None, 'pjoin') + else: + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, None, 'pjoin') + + manager_join = people.join(managers).outerjoin(boss) + person_with_polymorphic = ['*', person_join] + manager_with_polymorphic = ['*', manager_join] + elif with_polymorphic == 'joins': + person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss) + manager_join = people.join(managers).outerjoin(boss) + person_with_polymorphic = ['*', person_join] + manager_with_polymorphic = ['*', manager_join] + elif with_polymorphic == 'auto': + person_with_polymorphic = '*' + manager_with_polymorphic = '*' + else: + person_with_polymorphic = None + manager_with_polymorphic = None + + if redefine_colprop: + person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name}) + else: + person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person') + + mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + mapper(Manager, managers, inherits=person_mapper, with_polymorphic=manager_with_polymorphic, polymorphic_identity='manager') + + mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') + + mapper(Company, companies, properties={ + 'employees': relation(Person, lazy=lazy_relation, + cascade="all, delete-orphan", + backref="company", order_by=people.c.person_id + 
) + }) + + if redefine_colprop: + person_attribute_name = 'person_name' + else: + person_attribute_name = 'name' + + employees = [ + Manager(status='AAB', manager_name='manager1', **{person_attribute_name:'pointy haired boss'}), + Engineer(status='BBA', engineer_name='engineer1', primary_language='java', **{person_attribute_name:'dilbert'}), + ] + if include_base: + employees.append(Person(**{person_attribute_name:'joesmith'})) + employees += [ + Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'}), + Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'}) + ] + + pointy = employees[0] + jsmith = employees[-1] + dilbert = employees[1] + + session = create_session() + c = Company(name='company1') + c.employees = employees + session.add(c) + + session.flush() + session.expunge_all() + + eq_(session.query(Person).get(dilbert.person_id), dilbert) + session.expunge_all() + + eq_(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert) + session.expunge_all() + + def go(): + cc = session.query(Company).get(c.company_id) + eq_(cc.employees, employees) + + if not lazy_relation: + if with_polymorphic != 'none': + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 5) + + else: + if with_polymorphic != 'none': + self.assert_sql_count(testing.db, go, 2) + else: + self.assert_sql_count(testing.db, go, 6) + + # test selecting from the query, using the base mapped table (people) as the selection criterion. 
+ # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" + eq_( + session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first(), + dilbert + ) + + assert session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first().person_id + + eq_( + session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first(), + dilbert + ) + + # test selecting from the query, joining against an alias of the base "people" table. test that + # the "palias" alias does *not* get sucked up into the "person_join" conversion. + palias = people.alias("palias") + dilbert = session.query(Person).get(dilbert.person_id) + assert dilbert is session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() + assert dilbert is session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() + assert dilbert is session.query(Person).filter((Engineer.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)).first() + assert dilbert is session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0] + + dilbert.engineer_name = 'hes dibert!' 
+ + session.flush() + session.expunge_all() + + def go(): + session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + self.assert_sql_count(testing.db, go, 1) + session.expunge_all() + dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + def go(): + # assert that only primary table is queried for already-present-in-session + d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + self.assert_sql_count(testing.db, go, 1) + + # test standalone orphans + daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'}) + session.add(daboss) + assert_raises(orm_exc.FlushError, session.flush) + c = session.query(Company).first() + daboss.company = c + manager_list = [e for e in c.employees if isinstance(e, Manager)] + session.flush() + session.expunge_all() + + eq_(session.query(Manager).order_by(Manager.person_id).all(), manager_list) + c = session.query(Company).first() + + session.delete(c) + session.flush() + + eq_(people.count().scalar(), 0) + + test_roundtrip = function_named( + test_roundtrip, "test_%s%s%s_%s" % ( + (lazy_relation and "lazy" or "eager"), + (include_base and "_inclbase" or ""), + (redefine_colprop and "_redefcol" or ""), + with_polymorphic)) + setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip) + +for lazy_relation in [True, False]: + for redefine_colprop in [True, False]: + for with_polymorphic in ['unions', 'joins', 'auto', 'none']: + if with_polymorphic == 'unions': + for include_base in [True, False]: + _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic) + else: + _generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic) + diff --git a/test/orm/inheritance/test_polymorph2.py b/test/orm/inheritance/test_polymorph2.py new file mode 100644 index 000000000..51b6d4970 --- /dev/null +++ 
b/test/orm/inheritance/test_polymorph2.py @@ -0,0 +1,1121 @@ +"""this is a test suite consisting mainly of end-user test cases, testing all kinds of painful +inheritance setups for which we maintain compatibility. +""" + +from sqlalchemy.test.testing import eq_ +from sqlalchemy import * +from sqlalchemy import util +from sqlalchemy.orm import * + +from sqlalchemy.test import TestBase, AssertsExecutionResults, testing +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ + +class AttrSettable(object): + def __init__(self, **kwargs): + [setattr(self, k, v) for k, v in kwargs.iteritems()] + def __repr__(self): + return self.__class__.__name__ + "(%s)" % (hex(id(self))) + + +class RelationTest1(_base.MappedTest): + """test self-referential relationships on polymorphic mappers""" + @classmethod + def define_tables(cls, metadata): + global people, managers + + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('manager_id', Integer, ForeignKey('managers.person_id', use_alter=True, name="mpid_fq")), + Column('name', String(50)), + Column('type', String(30))) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('manager_name', String(50)) + ) + + def teardown(self): + people.update(values={people.c.manager_id:None}).execute() + super(RelationTest1, self).teardown() + + def test_parent_refs_descendant(self): + class Person(AttrSettable): + pass + class Manager(Person): + pass + + # note that up until recently (0.4.4), we had to specify "foreign_keys" here + # for this primary join. 
+ mapper(Person, people, properties={ + 'manager':relation(Manager, primaryjoin=(people.c.manager_id == + managers.c.person_id), + uselist=False, post_update=True) + }) + mapper(Manager, managers, inherits=Person, + inherit_condition=people.c.person_id==managers.c.person_id) + + eq_(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)]) + + session = create_session() + p = Person(name='some person') + m = Manager(name='some manager') + p.manager = m + session.add(p) + session.flush() + session.expunge_all() + + p = session.query(Person).get(p.person_id) + m = session.query(Manager).get(m.person_id) + print p, m, p.manager + assert p.manager is m + + def test_descendant_refs_parent(self): + class Person(AttrSettable): + pass + class Manager(Person): + pass + + mapper(Person, people) + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, properties={ + 'employee':relation(Person, primaryjoin=(people.c.manager_id == + managers.c.person_id), + foreign_keys=[people.c.manager_id], + uselist=False, post_update=True) + }) + + session = create_session() + p = Person(name='some person') + m = Manager(name='some manager') + m.employee = p + session.add(m) + session.flush() + session.expunge_all() + + p = session.query(Person).get(p.person_id) + m = session.query(Manager).get(m.person_id) + print p, m, m.employee + assert m.employee is p + +class RelationTest2(_base.MappedTest): + """test self-referential relationships on polymorphic mappers""" + @classmethod + def define_tables(cls, metadata): + global people, managers, data + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('manager_id', Integer, 
ForeignKey('people.person_id')), + Column('status', String(30)), + ) + + data = Table('data', metadata, + Column('person_id', Integer, ForeignKey('managers.person_id'), primary_key=True), + Column('data', String(30)) + ) + + def testrelationonsubclass_j1_nodata(self): + self.do_test("join1", False) + def testrelationonsubclass_j2_nodata(self): + self.do_test("join2", False) + def testrelationonsubclass_j1_data(self): + self.do_test("join1", True) + def testrelationonsubclass_j2_data(self): + self.do_test("join2", True) + def testrelationonsubclass_j3_nodata(self): + self.do_test("join3", False) + def testrelationonsubclass_j3_data(self): + self.do_test("join3", True) + + def do_test(self, jointype="join1", usedata=False): + class Person(AttrSettable): + pass + class Manager(Person): + pass + + if jointype == "join1": + poly_union = polymorphic_union({ + 'person':people.select(people.c.type=='person'), + 'manager':join(people, managers, people.c.person_id==managers.c.person_id) + }, None) + polymorphic_on=poly_union.c.type + elif jointype == "join2": + poly_union = polymorphic_union({ + 'person':people.select(people.c.type=='person'), + 'manager':managers.join(people, people.c.person_id==managers.c.person_id) + }, None) + polymorphic_on=poly_union.c.type + elif jointype == "join3": + poly_union = None + polymorphic_on = people.c.type + + if usedata: + class Data(object): + def __init__(self, data): + self.data = data + mapper(Data, data) + + mapper(Person, people, with_polymorphic=('*', poly_union), polymorphic_identity='person', polymorphic_on=polymorphic_on) + + if usedata: + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager', + properties={ + 'colleague':relation(Person, primaryjoin=managers.c.manager_id==people.c.person_id, lazy=True, uselist=False), + 'data':relation(Data, uselist=False) + } + ) + else: + mapper(Manager, managers, inherits=Person, 
class RelationTest3(_base.MappedTest):
    """test self-referential relationships on polymorphic mappers

    The actual test_* methods are generated below by ``_generate_test``
    and attached to this class, one per (jointype, usedata) combination.
    """
    @classmethod
    def define_tables(cls, metadata):
        global people, managers, data
        people = Table('people', metadata,
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
            Column('colleague_id', Integer, ForeignKey('people.person_id')),
            Column('name', String(50)),
            Column('type', String(30)))

        managers = Table('managers', metadata,
            Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
            Column('status', String(30)),
        )

        data = Table('data', metadata,
            Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
            Column('data', String(30))
        )

def _generate_test(jointype="join1", usedata=False):
    """Build one test method for RelationTest3.

    :param jointype: which polymorphic selectable strategy to exercise
      ("join1".."join4" -- two polymorphic_union shapes, an outerjoin,
      and no with_polymorphic selectable at all).
    :param usedata: if True, also map a one-to-one ``Data`` relation on
      the base class and assert it round-trips.
    """
    def do_test(self):
        class Person(AttrSettable):
            pass
        class Manager(Person):
            pass

        if usedata:
            class Data(object):
                def __init__(self, data):
                    self.data = data

        if jointype == "join1":
            poly_union = polymorphic_union({
                'manager':managers.join(people, people.c.person_id==managers.c.person_id),
                'person':people.select(people.c.type=='person')
            }, None)
        elif jointype == "join2":
            poly_union = polymorphic_union({
                'manager':join(people, managers, people.c.person_id==managers.c.person_id),
                'person':people.select(people.c.type=='person')
            }, None)
        elif jointype == 'join3':
            poly_union = people.outerjoin(managers)
        elif jointype == "join4":
            poly_union = None

        if usedata:
            mapper(Data, data)

        if usedata:
            mapper(Person, people, with_polymorphic=('*', poly_union), polymorphic_identity='person', polymorphic_on=people.c.type,
                properties={
                    'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, remote_side=people.c.colleague_id, uselist=True),
                    'data':relation(Data, uselist=False)
                }
            )
        else:
            mapper(Person, people, with_polymorphic=('*', poly_union), polymorphic_identity='person', polymorphic_on=people.c.type,
                properties={
                    'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id,
                        remote_side=people.c.colleague_id, uselist=True)
                }
            )

        mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager')

        sess = create_session()
        p = Person(name='person1')
        p2 = Person(name='person2')
        p3 = Person(name='person3')
        m = Manager(name='manager1')
        p.colleagues.append(p2)
        m.colleagues.append(p3)
        if usedata:
            p.data = Data('ps data')
            m.data = Data('ms data')

        sess.add(m)
        sess.add(p)
        sess.flush()

        sess.expunge_all()
        p = sess.query(Person).get(p.person_id)
        p2 = sess.query(Person).get(p2.person_id)
        p3 = sess.query(Person).get(p3.person_id)
        m = sess.query(Person).get(m.person_id)
        assert len(p.colleagues) == 1
        assert p.colleagues == [p2]
        assert m.colleagues == [p3]
        if usedata:
            assert p.data.data == 'ps data'
            assert m.data.data == 'ms data'

    # BUG FIX: the original named the test from the *global* name ``data``
    # (which at generation time was the module-level loop variable, and
    # otherwise the ``data`` Table) instead of the ``usedata`` parameter,
    # and inverted the labels -- the with-data variant was named "nodata"
    # and vice versa.  Use the parameter, with the labels the right way
    # around.
    do_test = function_named(
        do_test, 'test_relation_on_base_class_%s_%s' % (
            jointype, usedata and "data" or "nodata"))
    return do_test

# Attach one generated test per combination.  The inner loop variable is
# named ``usedata`` (the original used ``data``, shadowing the module-level
# ``data`` Table and leaving a stray boolean bound to that name after the
# loop ran at import time).
for jointype in ["join1", "join2", "join3", "join4"]:
    for usedata in (True, False):
        func = _generate_test(jointype, usedata)
        setattr(RelationTest3, func.__name__, func)
del func
RelationTest4(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global people, engineers, managers, cars + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True), + Column('name', String(50))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30))) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('longer_status', String(70))) + + cars = Table('cars', metadata, + Column('car_id', Integer, primary_key=True), + Column('owner', Integer, ForeignKey('people.person_id'))) + + def testmanytoonepolymorphic(self): + """in this test, the polymorphic union is between two subclasses, but does not include the base table by itself + in the union. however, the primaryjoin condition is going to be against the base table, and its a many-to-one + relationship (unlike the test in polymorph.py) so the column in the base table is explicit. Can the ClauseAdapter + figure out how to alias the primaryjoin to the polymorphic union ?""" + + # class definitions + class Person(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + def __repr__(self): + return "Ordinary person %s" % self.name + class Engineer(Person): + def __repr__(self): + return "Engineer %s, status %s" % (self.name, self.status) + class Manager(Person): + def __repr__(self): + return "Manager %s, status %s" % (self.name, self.longer_status) + class Car(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + def __repr__(self): + return "Car number %d" % self.car_id + + # create a union that represents both types of joins. 
+ employee_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, "type", 'employee_join') + + person_mapper = mapper(Person, people, with_polymorphic=('*', employee_join), polymorphic_on=employee_join.c.type, polymorphic_identity='person') + engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper)}) + + session = create_session() + + # creating 5 managers named from M1 to E5 + for i in range(1,5): + session.add(Manager(name="M%d" % i,longer_status="YYYYYYYYY")) + # creating 5 engineers named from E1 to E5 + for i in range(1,5): + session.add(Engineer(name="E%d" % i,status="X")) + + session.flush() + + engineer4 = session.query(Engineer).filter(Engineer.name=="E4").first() + manager3 = session.query(Manager).filter(Manager.name=="M3").first() + + car1 = Car(employee=engineer4) + session.add(car1) + car2 = Car(employee=manager3) + session.add(car2) + session.flush() + + session.expunge_all() + + def go(): + testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id) + assert str(testcar.employee) == "Engineer E4, status X" + self.assert_sql_count(testing.db, go, 1) + + print "----------------------------" + car1 = session.query(Car).get(car1.car_id) + print "----------------------------" + usingGet = session.query(person_mapper).get(car1.owner) + print "----------------------------" + usingProperty = car1.employee + print "----------------------------" + + # All print should output the same person (engineer E4) + assert str(engineer4) == "Engineer E4, status X" + print str(usingGet) + assert str(usingGet) == "Engineer E4, status X" + assert str(usingProperty) == "Engineer E4, status X" + + session.expunge_all() + print 
"-----------------------------------------------------------------" + # and now for the lightning round, eager ! + + def go(): + testcar = session.query(Car).options(eagerload('employee')).get(car1.car_id) + assert str(testcar.employee) == "Engineer E4, status X" + self.assert_sql_count(testing.db, go, 1) + + session.expunge_all() + s = session.query(Car) + c = s.join("employee").filter(Person.name=="E4")[0] + assert c.car_id==car1.car_id + +class RelationTest5(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global people, engineers, managers, cars + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True), + Column('name', String(50)), + Column('type', String(50))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30))) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('longer_status', String(70))) + + cars = Table('cars', metadata, + Column('car_id', Integer, primary_key=True), + Column('owner', Integer, ForeignKey('people.person_id'))) + + def testeagerempty(self): + """an easy one...test parent object with child relation to an inheriting mapper, using eager loads, + works when there are no child objects present""" + class Person(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + def __repr__(self): + return "Ordinary person %s" % self.name + class Engineer(Person): + def __repr__(self): + return "Engineer %s, status %s" % (self.name, self.status) + class Manager(Person): + def __repr__(self): + return "Manager %s, status %s" % (self.name, self.longer_status) + class Car(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + def __repr__(self): + return "Car number %d" % self.car_id + + person_mapper = 
mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') + engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + car_mapper = mapper(Car, cars, properties= {'manager':relation(manager_mapper, lazy=False)}) + + sess = create_session() + car1 = Car() + car2 = Car() + car2.manager = Manager() + sess.add(car1) + sess.add(car2) + sess.flush() + sess.expunge_all() + + carlist = sess.query(Car).all() + assert carlist[0].manager is None + assert carlist[1].manager.person_id == car2.manager.person_id + +class RelationTest6(_base.MappedTest): + """test self-referential relationships on a single joined-table inheritance mapper""" + @classmethod + def define_tables(cls, metadata): + global people, managers, data + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + ) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('colleague_id', Integer, ForeignKey('managers.person_id')), + Column('status', String(30)), + ) + + def testbasic(self): + class Person(AttrSettable): + pass + class Manager(Person): + pass + + mapper(Person, people) + # relationship is from people.join(managers) -> people.join(managers). 
self referential logic + # needs to be used to figure out the lazy clause, meaning create_lazy_clause must go from parent.mapped_table + # to parent.mapped_table + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, + properties={ + 'colleague':relation(Manager, primaryjoin=managers.c.colleague_id==managers.c.person_id, lazy=True, uselist=False) + } + ) + + sess = create_session() + m = Manager(name='manager1') + m2 =Manager(name='manager2') + m.colleague = m2 + sess.add(m) + sess.flush() + + sess.expunge_all() + m = sess.query(Manager).get(m.person_id) + m2 = sess.query(Manager).get(m2.person_id) + assert m.colleague is m2 + +class RelationTest7(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global people, engineers, managers, cars, offroad_cars + cars = Table('cars', metadata, + Column('car_id', Integer, primary_key=True), + Column('name', String(30))) + + offroad_cars = Table('offroad_cars', metadata, + Column('car_id',Integer, ForeignKey('cars.car_id'),nullable=False,primary_key=True)) + + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True), + Column('car_id', Integer, ForeignKey('cars.car_id'), nullable=False), + Column('name', String(50))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('field', String(30))) + + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('category', String(70))) + + def test_manytoone_lazyload(self): + """test that lazy load clause to a polymorphic child mapper generates correctly [ticket:493]""" + class PersistentObject(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + + class Status(PersistentObject): + def __repr__(self): + return "Status %s" % self.name + + class Person(PersistentObject): + def 
__repr__(self): + return "Ordinary person %s" % self.name + + class Engineer(Person): + def __repr__(self): + return "Engineer %s, field %s" % (self.name, self.field) + + class Manager(Person): + def __repr__(self): + return "Manager %s, category %s" % (self.name, self.category) + + class Car(PersistentObject): + def __repr__(self): + return "Car number %d, name %s" % (self.car_id, self.name) + + class Offraod_Car(Car): + def __repr__(self): + return "Offroad Car number %d, name %s" % (self.car_id,self.name) + + employee_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, "type", 'employee_join') + + car_join = polymorphic_union( + { + 'car' : cars.outerjoin(offroad_cars).select(offroad_cars.c.car_id == None, fold_equivalents=True), + 'offroad' : cars.join(offroad_cars) + }, "type", 'car_join') + + car_mapper = mapper(Car, cars, + with_polymorphic=('*', car_join) ,polymorphic_on=car_join.c.type, + polymorphic_identity='car', + ) + offroad_car_mapper = mapper(Offraod_Car, offroad_cars, inherits=car_mapper, polymorphic_identity='offroad') + person_mapper = mapper(Person, people, + with_polymorphic=('*', employee_join), polymorphic_on=employee_join.c.type, + polymorphic_identity='person', + properties={ + 'car':relation(car_mapper) + }) + engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + + session = create_session() + basic_car=Car(name="basic") + offroad_car=Offraod_Car(name="offroad") + + for i in range(1,4): + if i%2: + car=Car() + else: + car=Offraod_Car() + session.add(Manager(name="M%d" % i,category="YYYYYYYYY",car=car)) + session.add(Engineer(name="E%d" % i,field="X",car=car)) + session.flush() + session.expunge_all() + + r = session.query(Person).all() + for p in r: + assert p.car_id == p.car.car_id + +class RelationTest8(_base.MappedTest): + 
@classmethod + def define_tables(cls, metadata): + global taggable, users + taggable = Table('taggable', metadata, + Column('id', Integer, primary_key=True), + Column('type', String(30)), + Column('owner_id', Integer, ForeignKey('taggable.id')), + ) + users = Table ('users', metadata, + Column('id', Integer, ForeignKey('taggable.id'), primary_key=True), + Column('data', String(50)), + ) + + def test_selfref_onjoined(self): + class Taggable(_base.ComparableEntity): + pass + + class User(Taggable): + pass + + mapper( Taggable, taggable, polymorphic_on=taggable.c.type, polymorphic_identity='taggable', properties = { + 'owner' : relation (User, + primaryjoin=taggable.c.owner_id ==taggable.c.id, + remote_side=taggable.c.id + ), + }) + + + mapper(User, users, inherits=Taggable, polymorphic_identity='user', + inherit_condition=users.c.id == taggable.c.id, + ) + + + u1 = User(data='u1') + t1 = Taggable(owner=u1) + sess = create_session() + sess.add(t1) + sess.flush() + + sess.expunge_all() + eq_( + sess.query(Taggable).order_by(Taggable.id).all(), + [User(data='u1'), Taggable(owner=User(data='u1'))] + ) + +class GenerativeTest(TestBase, AssertsExecutionResults): + @classmethod + def setup_class(cls): + # cars---owned by--- people (abstract) --- has a --- status + # | ^ ^ | + # | | | | + # | engineers managers | + # | | + # +--------------------------------------- has a ------+ + + global metadata, status, people, engineers, managers, cars + metadata = MetaData(testing.db) + # table definitions + status = Table('status', metadata, + Column('status_id', Integer, primary_key=True), + Column('name', String(20))) + + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True), + Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), + Column('name', String(50))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('field', String(30))) + + managers = 
Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('category', String(70))) + + cars = Table('cars', metadata, + Column('car_id', Integer, primary_key=True), + Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), + Column('owner', Integer, ForeignKey('people.person_id'), nullable=False)) + + metadata.create_all() + + @classmethod + def teardown_class(cls): + metadata.drop_all() + def teardown(self): + clear_mappers() + for t in reversed(metadata.sorted_tables): + t.delete().execute() + + def testjointo(self): + # class definitions + class PersistentObject(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + class Status(PersistentObject): + def __repr__(self): + return "Status %s" % self.name + class Person(PersistentObject): + def __repr__(self): + return "Ordinary person %s" % self.name + class Engineer(Person): + def __repr__(self): + return "Engineer %s, field %s, status %s" % (self.name, self.field, self.status) + class Manager(Person): + def __repr__(self): + return "Manager %s, category %s, status %s" % (self.name, self.category, self.status) + class Car(PersistentObject): + def __repr__(self): + return "Car number %d" % self.car_id + + # create a union that represents both types of joins. 
+ employee_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, "type", 'employee_join') + + status_mapper = mapper(Status, status) + person_mapper = mapper(Person, people, + with_polymorphic=('*', employee_join), polymorphic_on=employee_join.c.type, + polymorphic_identity='person', properties={'status':relation(status_mapper)}) + engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper), 'status':relation(status_mapper)}) + + session = create_session() + + active = Status(name="active") + dead = Status(name="dead") + + session.add(active) + session.add(dead) + session.flush() + + # TODO: we haven't created assertions for all the data combinations created here + + # creating 5 managers named from M1 to M5 and 5 engineers named from E1 to E5 + # M4, M5, E4 and E5 are dead + for i in range(1,5): + if i<4: + st=active + else: + st=dead + session.add(Manager(name="M%d" % i,category="YYYYYYYYY",status=st)) + session.add(Engineer(name="E%d" % i,field="X",status=st)) + + session.flush() + + # get E4 + engineer4 = session.query(engineer_mapper).filter_by(name="E4").one() + + # create 2 cars for E4, one active and one dead + car1 = Car(employee=engineer4,status=active) + car2 = Car(employee=engineer4,status=dead) + session.add(car1) + session.add(car2) + session.flush() + + # this particular adapt used to cause a recursion overflow; + # added here for testing + e = exists([Car.owner], Car.owner==employee_join.c.person_id) + Query(Person)._adapt_clause(employee_join, False, False) + + r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active") + assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status 
active]" + r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active")).order_by(Person.name) + assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]" + + r = session.query(Person).filter(exists([1], Car.owner==Person.person_id)) + assert str(list(r)) == "[Engineer E4, field X, status Status dead]" + +class MultiLevelTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global table_Employee, table_Engineer, table_Manager + table_Employee = Table( 'Employee', metadata, + Column( 'name', type_= String(100), ), + Column( 'id', primary_key= True, type_= Integer, ), + Column( 'atype', type_= String(100), ), + ) + + table_Engineer = Table( 'Engineer', metadata, + Column( 'machine', type_= String(100), ), + Column( 'id', Integer, ForeignKey( 'Employee.id', ), primary_key= True, ), + ) + + table_Manager = Table( 'Manager', metadata, + Column( 'duties', type_= String(100), ), + Column( 'id', Integer, ForeignKey( 'Engineer.id', ), primary_key= True, ), + ) + def test_threelevels(self): + class Employee( object): + def set( me, **kargs): + for k,v in kargs.iteritems(): setattr( me, k, v) + return me + def __str__(me): return str(me.__class__.__name__)+':'+str(me.name) + __repr__ = __str__ + class Engineer( Employee): pass + class Manager( Engineer): pass + + pu_Employee = polymorphic_union( { + 'Manager': table_Employee.join( table_Engineer).join( table_Manager), + 'Engineer': select([table_Employee, table_Engineer.c.machine], table_Employee.c.atype == 'Engineer', from_obj=[table_Employee.join(table_Engineer)]), + 'Employee': table_Employee.select( table_Employee.c.atype == 'Employee'), + }, None, 'pu_employee', ) + +# pu_Employee = polymorphic_union( { +# 'Manager': table_Employee.join( table_Engineer).join( table_Manager), +# 'Engineer': table_Employee.join(table_Engineer).select(table_Employee.c.atype == 'Engineer'), +# 
'Employee': table_Employee.select( table_Employee.c.atype == 'Employee'), +# }, None, 'pu_employee', ) + + mapper_Employee = mapper( Employee, table_Employee, + polymorphic_identity= 'Employee', + polymorphic_on= pu_Employee.c.atype, + with_polymorphic=('*', pu_Employee), + ) + + pu_Engineer = polymorphic_union( { + 'Manager': table_Employee.join( table_Engineer).join( table_Manager), + 'Engineer': select([table_Employee, table_Engineer.c.machine], table_Employee.c.atype == 'Engineer', from_obj=[table_Employee.join(table_Engineer)]), + }, None, 'pu_engineer', ) + mapper_Engineer = mapper( Engineer, table_Engineer, + inherit_condition= table_Engineer.c.id == table_Employee.c.id, + inherits= mapper_Employee, + polymorphic_identity= 'Engineer', + polymorphic_on= pu_Engineer.c.atype, + with_polymorphic=('*', pu_Engineer), + ) + + mapper_Manager = mapper( Manager, table_Manager, + inherit_condition= table_Manager.c.id == table_Engineer.c.id, + inherits= mapper_Engineer, + polymorphic_identity= 'Manager', + ) + + a = Employee().set( name= 'one') + b = Engineer().set( egn= 'two', machine= 'any') + c = Manager().set( name= 'head', machine= 'fast', duties= 'many') + + session = create_session() + session.add(a) + session.add(b) + session.add(c) + session.flush() + assert set(session.query(Employee).all()) == set([a,b,c]) + assert set(session.query( Engineer).all()) == set([b,c]) + assert session.query( Manager).all() == [c] + +class ManyToManyPolyTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global base_item_table, item_table, base_item_collection_table, collection_table + base_item_table = Table( + 'base_item', metadata, + Column('id', Integer, primary_key=True), + Column('child_name', String(255), default=None)) + + item_table = Table( + 'item', metadata, + Column('id', Integer, ForeignKey('base_item.id'), primary_key=True), + Column('dummy', Integer, default=0)) # Dummy column to avoid weird insert problems + + base_item_collection_table = 
Table( + 'base_item_collection', metadata, + Column('item_id', Integer, ForeignKey('base_item.id')), + Column('collection_id', Integer, ForeignKey('collection.id'))) + + collection_table = Table( + 'collection', metadata, + Column('id', Integer, primary_key=True), + Column('name', Unicode(255))) + + def test_pjoin_compile(self): + """test that remote_side columns in the secondary join table arent attempted to be + matched to the target polymorphic selectable""" + class BaseItem(object): pass + class Item(BaseItem): pass + class Collection(object): pass + item_join = polymorphic_union( { + 'BaseItem':base_item_table.select(base_item_table.c.child_name=='BaseItem'), + 'Item':base_item_table.join(item_table), + }, None, 'item_join') + + mapper( + BaseItem, base_item_table, + with_polymorphic=('*', item_join), + polymorphic_on=base_item_table.c.child_name, + polymorphic_identity='BaseItem', + properties=dict(collections=relation(Collection, secondary=base_item_collection_table, backref="items"))) + + mapper( + Item, item_table, + inherits=BaseItem, + polymorphic_identity='Item') + + mapper(Collection, collection_table) + + class_mapper(BaseItem) + +class CustomPKTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global t1, t2 + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('type', String(30), nullable=False), + Column('data', String(30))) + # note that the primary key column in t2 is named differently + t2 = Table('t2', metadata, + Column('t2id', Integer, ForeignKey('t1.id'), primary_key=True), + Column('t2data', String(30))) + + def test_custompk(self): + """test that the primary_key attribute is propagated to the polymorphic mapper""" + + class T1(object):pass + class T2(T1):pass + + # create a polymorphic union with the select against the base table first. + # with the join being second, the alias of the union will + # pick up two "primary key" columns. 
technically the alias should have a + # 2-col pk in any case but the leading select has a NULL for the "t2id" column + d = util.OrderedDict() + d['t1'] = t1.select(t1.c.type=='t1') + d['t2'] = t1.join(t2) + pjoin = polymorphic_union(d, None, 'pjoin') + + mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', with_polymorphic=('*', pjoin), primary_key=[pjoin.c.id]) + mapper(T2, t2, inherits=T1, polymorphic_identity='t2') + print [str(c) for c in class_mapper(T1).primary_key] + ot1 = T1() + ot2 = T2() + sess = create_session() + sess.add(ot1) + sess.add(ot2) + sess.flush() + sess.expunge_all() + + # query using get(), using only one value. this requires the select_table mapper + # has the same single-col primary key. + assert sess.query(T1).get(ot1.id).id == ot1.id + + ot1 = sess.query(T1).get(ot1.id) + ot1.data = 'hi' + sess.flush() + + def test_pk_collapses(self): + """test that a composite primary key attribute formed by a join is "collapsed" into its + minimal columns""" + + class T1(object):pass + class T2(T1):pass + + # create a polymorphic union with the select against the base table first. + # with the join being second, the alias of the union will + # pick up two "primary key" columns. technically the alias should have a + # 2-col pk in any case but the leading select has a NULL for the "t2id" column + d = util.OrderedDict() + d['t1'] = t1.select(t1.c.type=='t1') + d['t2'] = t1.join(t2) + pjoin = polymorphic_union(d, None, 'pjoin') + + mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', with_polymorphic=('*', pjoin)) + mapper(T2, t2, inherits=T1, polymorphic_identity='t2') + assert len(class_mapper(T1).primary_key) == 1 + + print [str(c) for c in class_mapper(T1).primary_key] + ot1 = T1() + ot2 = T2() + sess = create_session() + sess.add(ot1) + sess.add(ot2) + sess.flush() + sess.expunge_all() + + # query using get(), using only one value. this requires the select_table mapper + # has the same single-col primary key. 
+ assert sess.query(T1).get(ot1.id).id == ot1.id + + ot1 = sess.query(T1).get(ot1.id) + ot1.data = 'hi' + sess.flush() + +class InheritingEagerTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global people, employees, tags, peopleTags + + people = Table('people', metadata, + Column('id', Integer, primary_key=True), + Column('_type', String(30), nullable=False), + ) + + + employees = Table('employees', metadata, + Column('id', Integer, ForeignKey('people.id'),primary_key=True), + ) + + tags = Table('tags', metadata, + Column('id', Integer, primary_key=True), + Column('label', String(50), nullable=False), + ) + + peopleTags = Table('peopleTags', metadata, + Column('person_id', Integer,ForeignKey('people.id')), + Column('tag_id', Integer,ForeignKey('tags.id')), + ) + + def test_basic(self): + """test that Query uses the full set of mapper._eager_loaders when generating SQL""" + + class Person(_fixtures.Base): + pass + + class Employee(Person): + def __init__(self, name='bob'): + self.name = name + + class Tag(_fixtures.Base): + def __init__(self, label): + self.label = label + + mapper(Person, people, polymorphic_on=people.c._type,polymorphic_identity='person', properties={ + 'tags': relation(Tag, secondary=peopleTags,backref='people', lazy=False) + }) + mapper(Employee, employees, inherits=Person,polymorphic_identity='employee') + mapper(Tag, tags) + + session = create_session() + + bob = Employee() + session.add(bob) + + tag = Tag('crazy') + bob.tags.append(tag) + + tag = Tag('funny') + bob.tags.append(tag) + session.flush() + + session.expunge_all() + # query from Employee with limit, query needs to apply eager limiting subquery + instance = session.query(Employee).filter_by(id=1).limit(1).first() + assert len(instance.tags) == 2 + +class MissingPolymorphicOnTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global tablea, tableb, tablec, tabled + tablea = Table('tablea', metadata, + Column('id', Integer, 
primary_key=True), + Column('adata', String(50)), + ) + tableb = Table('tableb', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('tablea.id')), + Column('data', String(50)), + ) + tablec = Table('tablec', metadata, + Column('id', Integer, ForeignKey('tablea.id'), primary_key=True), + Column('cdata', String(50)), + ) + tabled = Table('tabled', metadata, + Column('id', Integer, ForeignKey('tablec.id'), primary_key=True), + Column('ddata', String(50)), + ) + + def test_polyon_col_setsup(self): + class A(_fixtures.Base): + pass + class B(_fixtures.Base): + pass + class C(A): + pass + class D(C): + pass + + poly_select = select([tablea, tableb.c.data.label('discriminator')], from_obj=tablea.join(tableb)).alias('poly') + + mapper(B, tableb) + mapper(A, tablea, with_polymorphic=('*', poly_select), polymorphic_on=poly_select.c.discriminator, properties={ + 'b':relation(B, uselist=False) + }) + mapper(C, tablec, inherits=A,polymorphic_identity='c') + mapper(D, tabled, inherits=C, polymorphic_identity='d') + + c = C(cdata='c1', adata='a1', b=B(data='c')) + d = D(cdata='c2', adata='a2', ddata='d2', b=B(data='d')) + sess = create_session() + sess.add(c) + sess.add(d) + sess.flush() + sess.expunge_all() + eq_(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')]) + diff --git a/test/orm/inheritance/test_productspec.py b/test/orm/inheritance/test_productspec.py new file mode 100644 index 000000000..b2bcb85d5 --- /dev/null +++ b/test/orm/inheritance/test_productspec.py @@ -0,0 +1,318 @@ +from datetime import datetime +from sqlalchemy import * +from sqlalchemy.orm import * + + +from sqlalchemy.test import testing +from test.orm import _base + + +class InheritTest(_base.MappedTest): + """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" + @classmethod + def define_tables(cls, metadata): + global products_table, specification_table, 
documents_table + global Product, Detail, Assembly, SpecLine, Document, RasterDocument + + products_table = Table('products', metadata, + Column('product_id', Integer, primary_key=True), + Column('product_type', String(128)), + Column('name', String(128)), + Column('mark', String(128)), + ) + + specification_table = Table('specification', metadata, + Column('spec_line_id', Integer, primary_key=True), + Column('master_id', Integer, ForeignKey("products.product_id"), + nullable=True), + Column('slave_id', Integer, ForeignKey("products.product_id"), + nullable=True), + Column('quantity', Float, default=1.), + ) + + documents_table = Table('documents', metadata, + Column('document_id', Integer, primary_key=True), + Column('document_type', String(128)), + Column('product_id', Integer, ForeignKey('products.product_id')), + Column('create_date', DateTime, default=lambda:datetime.now()), + Column('last_updated', DateTime, default=lambda:datetime.now(), + onupdate=lambda:datetime.now()), + Column('name', String(128)), + Column('data', Binary), + Column('size', Integer, default=0), + ) + + class Product(object): + def __init__(self, name, mark=''): + self.name = name + self.mark = mark + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.name) + + class Detail(Product): + def __init__(self, name): + self.name = name + + class Assembly(Product): + def __repr__(self): + return Product.__repr__(self) + " " + " ".join([x + "=" + repr(getattr(self, x, None)) for x in ['specification', 'documents']]) + + class SpecLine(object): + def __init__(self, master=None, slave=None, quantity=1): + self.master = master + self.slave = slave + self.quantity = quantity + + def __repr__(self): + return '<%s %.01f %s>' % ( + self.__class__.__name__, + self.quantity or 0., + repr(self.slave) + ) + + class Document(object): + def __init__(self, name, data=None): + self.name = name + self.data = data + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.name) 
+ + class RasterDocument(Document): + pass + + def testone(self): + product_mapper = mapper(Product, products_table, + polymorphic_on=products_table.c.product_type, + polymorphic_identity='product') + + detail_mapper = mapper(Detail, inherits=product_mapper, + polymorphic_identity='detail') + + assembly_mapper = mapper(Assembly, inherits=product_mapper, + polymorphic_identity='assembly') + + specification_mapper = mapper(SpecLine, specification_table, + properties=dict( + master=relation(Assembly, + foreign_keys=[specification_table.c.master_id], + primaryjoin=specification_table.c.master_id==products_table.c.product_id, + lazy=True, backref=backref('specification'), + uselist=False), + slave=relation(Product, + foreign_keys=[specification_table.c.slave_id], + primaryjoin=specification_table.c.slave_id==products_table.c.product_id, + lazy=True, uselist=False), + quantity=specification_table.c.quantity, + ) + ) + + session = create_session( ) + + a1 = Assembly(name='a1') + + p1 = Product(name='p1') + a1.specification.append(SpecLine(slave=p1)) + + d1 = Detail(name='d1') + a1.specification.append(SpecLine(slave=d1)) + + session.add(a1) + orig = repr(a1) + session.flush() + session.expunge_all() + + a1 = session.query(Product).filter_by(name='a1').one() + new = repr(a1) + print orig + print new + assert orig == new == ' specification=[>, >] documents=None' + + def testtwo(self): + product_mapper = mapper(Product, products_table, + polymorphic_on=products_table.c.product_type, + polymorphic_identity='product') + + detail_mapper = mapper(Detail, inherits=product_mapper, + polymorphic_identity='detail') + + specification_mapper = mapper(SpecLine, specification_table, + properties=dict( + slave=relation(Product, + foreign_keys=[specification_table.c.slave_id], + primaryjoin=specification_table.c.slave_id==products_table.c.product_id, + lazy=True, uselist=False), + ) + ) + + session = create_session( ) + + s = SpecLine(slave=Product(name='p1')) + s2 = 
SpecLine(slave=Detail(name='d1')) + session.add(s) + session.add(s2) + orig = repr([s, s2]) + session.flush() + session.expunge_all() + new = repr(session.query(SpecLine).all()) + print orig + print new + assert orig == new == '[>, >]' + + def testthree(self): + product_mapper = mapper(Product, products_table, + polymorphic_on=products_table.c.product_type, + polymorphic_identity='product') + detail_mapper = mapper(Detail, inherits=product_mapper, + polymorphic_identity='detail') + assembly_mapper = mapper(Assembly, inherits=product_mapper, + polymorphic_identity='assembly') + + specification_mapper = mapper(SpecLine, specification_table, + properties=dict( + master=relation(Assembly, lazy=False, uselist=False, + foreign_keys=[specification_table.c.master_id], + primaryjoin=specification_table.c.master_id==products_table.c.product_id, + backref=backref('specification', cascade="all, delete-orphan"), + ), + slave=relation(Product, lazy=False, uselist=False, + foreign_keys=[specification_table.c.slave_id], + primaryjoin=specification_table.c.slave_id==products_table.c.product_id, + ), + quantity=specification_table.c.quantity, + ) + ) + + document_mapper = mapper(Document, documents_table, + polymorphic_on=documents_table.c.document_type, + polymorphic_identity='document', + properties=dict( + name=documents_table.c.name, + data=deferred(documents_table.c.data), + product=relation(Product, lazy=True, backref=backref('documents', cascade="all, delete-orphan")), + ), + ) + raster_document_mapper = mapper(RasterDocument, inherits=document_mapper, + polymorphic_identity='raster_document') + + session = create_session() + + a1 = Assembly(name='a1') + a1.specification.append(SpecLine(slave=Detail(name='d1'))) + a1.documents.append(Document('doc1')) + a1.documents.append(RasterDocument('doc2')) + session.add(a1) + orig = repr(a1) + session.flush() + session.expunge_all() + + a1 = session.query(Product).filter_by(name='a1').one() + new = repr(a1) + print orig + print new + 
assert orig == new == ' specification=[>] documents=[, ]' + + def testfour(self): + """this tests the RasterDocument being attached to the Assembly, but *not* the Document. this means only + a "sub-class" task, i.e. corresponding to an inheriting mapper but not the base mapper, is created. """ + product_mapper = mapper(Product, products_table, + polymorphic_on=products_table.c.product_type, + polymorphic_identity='product') + detail_mapper = mapper(Detail, inherits=product_mapper, + polymorphic_identity='detail') + assembly_mapper = mapper(Assembly, inherits=product_mapper, + polymorphic_identity='assembly') + + document_mapper = mapper(Document, documents_table, + polymorphic_on=documents_table.c.document_type, + polymorphic_identity='document', + properties=dict( + name=documents_table.c.name, + data=deferred(documents_table.c.data), + product=relation(Product, lazy=True, backref=backref('documents', cascade="all, delete-orphan")), + ), + ) + raster_document_mapper = mapper(RasterDocument, inherits=document_mapper, + polymorphic_identity='raster_document') + + session = create_session( ) + + a1 = Assembly(name='a1') + a1.documents.append(RasterDocument('doc2')) + session.add(a1) + orig = repr(a1) + session.flush() + session.expunge_all() + + a1 = session.query(Product).filter_by(name='a1').one() + new = repr(a1) + print orig + print new + assert orig == new == ' specification=None documents=[]' + + del a1.documents[0] + session.flush() + session.expunge_all() + + a1 = session.query(Product).filter_by(name='a1').one() + assert len(session.query(Document).all()) == 0 + + def testfive(self): + """tests the late compilation of mappers""" + + specification_mapper = mapper(SpecLine, specification_table, + properties=dict( + master=relation(Assembly, lazy=False, uselist=False, + foreign_keys=[specification_table.c.master_id], + primaryjoin=specification_table.c.master_id==products_table.c.product_id, + backref=backref('specification'), + ), + slave=relation(Product, 
lazy=False, uselist=False, + foreign_keys=[specification_table.c.slave_id], + primaryjoin=specification_table.c.slave_id==products_table.c.product_id, + ), + quantity=specification_table.c.quantity, + ) + ) + + product_mapper = mapper(Product, products_table, + polymorphic_on=products_table.c.product_type, + polymorphic_identity='product', properties={ + 'documents' : relation(Document, lazy=True, + backref='product', cascade='all, delete-orphan'), + }) + + detail_mapper = mapper(Detail, inherits=Product, + polymorphic_identity='detail') + + document_mapper = mapper(Document, documents_table, + polymorphic_on=documents_table.c.document_type, + polymorphic_identity='document', + properties=dict( + name=documents_table.c.name, + data=deferred(documents_table.c.data), + ), + ) + + raster_document_mapper = mapper(RasterDocument, inherits=Document, + polymorphic_identity='raster_document') + + assembly_mapper = mapper(Assembly, inherits=Product, + polymorphic_identity='assembly') + + session = create_session() + + a1 = Assembly(name='a1') + a1.specification.append(SpecLine(slave=Detail(name='d1'))) + a1.documents.append(Document('doc1')) + a1.documents.append(RasterDocument('doc2')) + session.add(a1) + orig = repr(a1) + session.flush() + session.expunge_all() + + a1 = session.query(Product).filter_by(name='a1').one() + new = repr(a1) + print orig + print new + assert orig == new == ' specification=[>] documents=[, ]' + diff --git a/test/orm/inheritance/test_query.py b/test/orm/inheritance/test_query.py new file mode 100644 index 000000000..5b57e8f45 --- /dev/null +++ b/test/orm/inheritance/test_query.py @@ -0,0 +1,1113 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy import exc as sa_exc +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.engine import default + +from sqlalchemy.test import AssertsCompiledSQL, testing +from test.orm import 
_base, _fixtures +from sqlalchemy.test.testing import eq_ + +class Company(_fixtures.Base): + pass + +class Person(_fixtures.Base): + pass +class Engineer(Person): + pass +class Manager(Person): + pass +class Boss(Manager): + pass + +class Machine(_fixtures.Base): + pass + +class Paperwork(_fixtures.Base): + pass + +def _produce_test(select_type): + class PolymorphicQueryTest(_base.MappedTest, AssertsCompiledSQL): + run_inserts = 'once' + run_setup_mappers = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + global companies, people, engineers, managers, boss, paperwork, machines + + companies = Table('companies', metadata, + Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True), + Column('name', String(50))) + + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('company_id', Integer, ForeignKey('companies.company_id')), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('engineer_name', String(50)), + Column('primary_language', String(50)), + ) + + machines = Table('machines', metadata, + Column('machine_id', Integer, primary_key=True), + Column('name', String(50)), + Column('engineer_id', Integer, ForeignKey('engineers.person_id'))) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + Column('manager_name', String(50)) + ) + + boss = Table('boss', metadata, + Column('boss_id', Integer, ForeignKey('managers.person_id'), primary_key=True), + Column('golf_swing', String(30)), + ) + + paperwork = Table('paperwork', metadata, + Column('paperwork_id', Integer, primary_key=True), + Column('description', String(50)), + Column('person_id', 
Integer, ForeignKey('people.person_id'))) + + clear_mappers() + + mapper(Company, companies, properties={ + 'employees':relation(Person, order_by=people.c.person_id) + }) + + mapper(Machine, machines) + + if select_type == '': + person_join = manager_join = None + person_with_polymorphic = None + manager_with_polymorphic = None + elif select_type == 'Polymorphic': + person_join = manager_join = None + person_with_polymorphic = '*' + manager_with_polymorphic = '*' + elif select_type == 'Unions': + person_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, None, 'pjoin') + + manager_join = people.join(managers).outerjoin(boss) + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ('*', manager_join) + elif select_type == 'AliasedJoins': + person_join = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin') + manager_join = people.join(managers).outerjoin(boss).select(use_labels=True).alias('mjoin') + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ('*', manager_join) + elif select_type == 'Joins': + person_join = people.outerjoin(engineers).outerjoin(managers) + manager_join = people.join(managers).outerjoin(boss) + person_with_polymorphic = ([Person, Manager, Engineer], person_join) + manager_with_polymorphic = ('*', manager_join) + + + # testing a order_by here as well; the surrogate mapper has to adapt it + mapper(Person, people, + with_polymorphic=person_with_polymorphic, + polymorphic_on=people.c.type, polymorphic_identity='person', order_by=people.c.person_id, + properties={ + 'paperwork':relation(Paperwork, order_by=paperwork.c.paperwork_id) + }) + mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer', properties={ + 'machines':relation(Machine, order_by=machines.c.machine_id) + }) + mapper(Manager, managers, with_polymorphic=manager_with_polymorphic, + 
inherits=Person, polymorphic_identity='manager') + mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss') + mapper(Paperwork, paperwork) + + + @classmethod + def insert_data(cls): + global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2 + + c1 = Company(name="MegaCorp, Inc.") + c2 = Company(name="Elbonia, Inc.") + e1 = Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", paperwork=[ + Paperwork(description="tps report #1"), + Paperwork(description="tps report #2") + ], machines=[ + Machine(name='IBM ThinkPad'), + Machine(name='IPhone'), + ]) + e2 = Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer", paperwork=[ + Paperwork(description="tps report #3"), + Paperwork(description="tps report #4") + ], machines=[ + Machine(name="Commodore 64") + ]) + b1 = Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss", paperwork=[ + Paperwork(description="review #1"), + ]) + m1 = Manager(name="dogbert", manager_name="dogbert", status="regular manager", paperwork=[ + Paperwork(description="review #2"), + Paperwork(description="review #3") + ]) + c1.employees = [e1, e2, b1, m1] + + e3 = Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer", paperwork=[ + Paperwork(description='elbonian missive #3') + ], machines=[ + Machine(name="Commodore 64"), + Machine(name="IBM 3270") + ]) + + c2.employees = [e3] + sess = create_session() + sess.add(c1) + sess.add(c2) + sess.flush() + sess.expunge_all() + + all_employees = [e1, e2, b1, m1, e3] + c1_employees = [e1, e2, b1, m1] + c2_employees = [e3] + + def test_loads_at_once(self): + """test that all objects load from the full query, when with_polymorphic is used""" + + sess = create_session() + def go(): + eq_(sess.query(Person).all(), all_employees) + self.assert_sql_count(testing.db, go, {'':14, 'Polymorphic':9}.get(select_type, 10)) + + 
def test_primary_eager_aliasing(self): + sess = create_session() + + def go(): + eq_(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) + self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4)) + + sess = create_session() + + # assert the JOINs dont over JOIN + assert sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().subquery().count().scalar() == 2 + + def go(): + eq_(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) + self.assert_sql_count(testing.db, go, 3) + + + def test_get(self): + sess = create_session() + + # for all mappers, ensure the primary key has been calculated as just the "person_id" + # column + eq_(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) + eq_(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) + eq_(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore")) + + def test_multi_join(self): + sess = create_session() + + e = aliased(Person) + c = aliased(Company) + + q = sess.query(Company, Person, c, e).join((Person, Company.employees)).join((e, c.employees)).\ + filter(Person.name=='dilbert').filter(e.name=='wally') + + eq_(q.count(), 1) + eq_(q.all(), [ + ( + Company(company_id=1,name=u'MegaCorp, Inc.'), + Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), + Company(company_id=1,name=u'MegaCorp, Inc.'), + Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer') + ) + ]) + + def test_filter_on_subclass(self): + sess = create_session() + eq_(sess.query(Engineer).all()[0], Engineer(name="dilbert")) + + eq_(sess.query(Engineer).first(), Engineer(name="dilbert")) + + 
eq_(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert")) + + eq_(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert")) + + eq_(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + + eq_(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + + def test_join_from_polymorphic(self): + sess = create_session() + + for aliased in (True, False): + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + + eq_(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) + + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + + def test_join_from_with_polymorphic(self): + sess = create_session() + + for aliased in (True, False): + sess.expunge_all() + eq_(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + + sess.expunge_all() + eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + + sess.expunge_all() + eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + + def test_join_to_polymorphic(self): + sess = create_session() + eq_(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) + + eq_(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) + + def test_polymorphic_any(self): + sess = 
create_session() + + eq_( + sess.query(Company).\ + filter(Company.employees.any(Person.name=='vlad')).all(), [c2] + ) + + # test that the aliasing on "Person" does not bleed into the + # EXISTS clause generated by any() + eq_( + sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ + filter(Company.employees.any(Person.name=='wally')).all(), [c1] + ) + + eq_( + sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ + filter(Company.employees.any(Person.name=='vlad')).all(), [] + ) + + eq_( + sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), + c2 + ) + + calias = aliased(Company) + eq_( + sess.query(calias).filter(calias.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), + c2 + ) + + eq_( + sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(), + c1 + ) + eq_( + sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(), + c1 + ) + + if select_type != '': + eq_( + sess.query(Person).filter(Engineer.machines.any(Machine.name=="Commodore 64")).all(), [e2, e3] + ) + + eq_( + sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1] + ) + + eq_( + sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(), + c2 + ) + + def test_join_from_columns_or_subclass(self): + sess = create_session() + + eq_( + sess.query(Manager.name).order_by(Manager.name).all(), + [(u'dogbert',), (u'pointy haired boss',)] + ) + + eq_( + sess.query(Manager.name).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), + [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] + ) + + eq_( + sess.query(Person.name).join((Paperwork, Person.paperwork)).order_by(Person.name).all(), + [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy 
haired boss',), (u'vlad',), (u'wally',), (u'wally',)] + ) + + eq_( + sess.query(Person.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Person.name).all(), + [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)] + ) + + eq_( + sess.query(Manager).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), + [m1, b1] + ) + + eq_( + sess.query(Manager.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), + [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] + ) + + eq_( + sess.query(Manager.person_id).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), + [(4,), (4,), (3,)] + ) + + eq_( + sess.query(Manager.name, Paperwork.description).join((Paperwork, Manager.person_id==Paperwork.person_id)).all(), + [(u'pointy haired boss', u'review #1'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3')] + ) + + malias = aliased(Manager) + eq_( + sess.query(malias.name).join((paperwork, malias.person_id==paperwork.c.person_id)).all(), + [(u'pointy haired boss',), (u'dogbert',), (u'dogbert',)] + ) + + def test_expire(self): + """test that individual column refresh doesn't get tripped up by the select_table mapper""" + + sess = create_session() + m1 = sess.query(Manager).filter(Manager.name=='dogbert').one() + sess.expire(m1) + assert m1.status == 'regular manager' + + m2 = sess.query(Manager).filter(Manager.name=='pointy haired boss').one() + sess.expire(m2, ['manager_name', 'golf_swing']) + assert m2.golf_swing=='fore' + + def test_with_polymorphic(self): + + sess = create_session() + + + assert_raises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork) + assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss) + assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person) + + # compare to entities without 
related collections to prevent additional lazy SQL from firing on + # loaded entities + emps_without_relations = [ + Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer"), + Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"), + Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), + Manager(name="dogbert", manager_name="dogbert", status="regular manager"), + Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") + ] + eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + + + def go(): + eq_(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + def go(): + eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + def go(): + eq_(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.expunge_all() + def go(): + eq_(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.expunge_all() + def go(): + # limit the polymorphic join down to just "Person", overriding select_table + eq_(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 6) + + def test_relation_to_polymorphic(self): + assert_result = [ + Company(name="MegaCorp, Inc.", employees=[ + Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")]), + Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"), + 
Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), + Manager(name="dogbert", manager_name="dogbert", status="regular manager"), + ]), + Company(name="Elbonia, Inc.", employees=[ + Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") + ]) + ] + + sess = create_session() + + def go(): + # test load Companies with lazy load to 'employees' + eq_(sess.query(Company).all(), assert_result) + self.assert_sql_count(testing.db, go, {'':9, 'Polymorphic':4}.get(select_type, 5)) + + sess = create_session() + def go(): + # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer). eagerloader doesn't + # pick up on the "of_type()" as of yet. + eq_(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result) + + # in the case of select_type='', the eagerload doesn't take in this case; + # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines" + self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2)) + + def test_eagerload_on_subclass(self): + sess = create_session() + def go(): + # test load People with eagerload to engineers + machines + eq_(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), + [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])] + ) + self.assert_sql_count(testing.db, go, 1) + + def test_join_to_subclass(self): + sess = create_session() + eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) + + if select_type == '': + eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + 
eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) + + ealias = aliased(Engineer) + eq_(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1]) + + eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2]) + eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + else: + eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2]) + eq_(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + + # non-polymorphic + eq_(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + + # here's the new way + eq_(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + + def test_join_through_polymorphic(self): + + sess = create_session() + + for aliased in (True, False): + eq_( + sess.query(Company).\ + 
join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + eq_( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + eq_( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + eq_( + sess.query(Company).\ + join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + + eq_( + sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ + join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), + [c1] + ) + + eq_( + sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ + join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(), + [c1, c2] + ) + def test_explicit_polymorphic_join(self): + sess = create_session() + + # join from Company to Engineer; join condition formulated by + # ORMJoin using regular table foreign key connections. Engineer + # is expressed as "(select * people join engineers) as anon_1" + # so the join is contained. + eq_( + sess.query(Company).join(Engineer).filter(Engineer.engineer_name=='vlad').one(), + c2 + ) + + # same, using explicit join condition. Query.join() must adapt the on clause + # here to match the subquery wrapped around "people join engineers". 
+ eq_( + sess.query(Company).join((Engineer, Company.company_id==Engineer.company_id)).filter(Engineer.engineer_name=='vlad').one(), + c2 + ) + + + def test_filter_on_baseclass(self): + sess = create_session() + + eq_(sess.query(Person).all(), all_employees) + + eq_(sess.query(Person).first(), all_employees[0]) + + eq_(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) + + def test_from_alias(self): + sess = create_session() + + palias = aliased(Person) + eq_( + sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(), + [e1, e2] + ) + + def test_self_referential(self): + sess = create_session() + + c1_employees = [e1, e2, b1, m1] + + palias = aliased(Person) + eq_( + sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ + filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), + [ + (m1, e1), + (m1, e2), + (m1, b1), + ] + ) + + eq_( + sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ + filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(), + [ + (m1, e1), + (m1, e2), + (m1, b1), + ] + ) + + def test_nesting_queries(self): + sess = create_session() + + # query.statement places a flag "no_adapt" on the returned statement. This prevents + # the polymorphic adaptation in the second "filter" from hitting it, which would pollute + # the subquery and usually results in recursion overflow errors within the adaption. 
+ subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar() + + eq_(sess.query(Person).filter(Person.person_id==subq).one(), e1) + + def test_mixed_entities(self): + sess = create_session() + + eq_( + sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), + [(u'Elbonia, Inc.', + Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))] + ) + + eq_( + sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), + [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), + u'Elbonia, Inc.')] + ) + + + eq_( + sess.query(Manager.name).all(), + [('pointy haired boss', ), ('dogbert',)] + ) + + eq_( + sess.query(Manager.name + " foo").all(), + [('pointy haired boss foo', ), ('dogbert foo',)] + ) + + row = sess.query(Engineer.name, Engineer.primary_language).filter(Engineer.name=='dilbert').first() + assert row.name == 'dilbert' + assert row.primary_language == 'java' + + + eq_( + sess.query(Engineer.name, Engineer.primary_language).all(), + [(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')] + ) + + eq_( + sess.query(Boss.name, Boss.golf_swing).all(), + [(u'pointy haired boss', u'fore')] + ) + + # TODO: I think raise error on these for now. 
different inheritance/loading schemes have different + # results here, all incorrect + # + # self.assertEquals( + # sess.query(Person.name, Engineer.primary_language).all(), + # [] + # ) + + # self.assertEquals( + # sess.query(Person.name, Engineer.primary_language, Manager.manager_name).all(), + # [] + # ) + + eq_( + sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), + [(u'vlad',u'Elbonia, Inc.')] + ) + + eq_( + sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(), + [(u'java',), (u'c++',), (u'cobol',)] + ) + + if select_type != '': + eq_( + sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(), + [ + (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'), + (Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer'), u'MegaCorp, Inc.'), + (Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',company_id=2,primary_language=u'cobol',person_id=5,type=u'engineer'), u'Elbonia, Inc.') + ] + ) + + eq_( + sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(), + [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')] + ) + + palias = aliased(Person) + eq_( + sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), + [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), + u'Elbonia, Inc.', + Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))] + ) + + eq_( + sess.query(palias, Company.name, 
Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), + [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), + u'Elbonia, Inc.', + Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),) + ] + ) + + eq_( + sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), + [(u'vlad', u'Elbonia, Inc.', u'dilbert')] + ) + + palias = aliased(Person) + eq_( + sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ + filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), + [(u'manager', u'dogbert', u'engineer', u'dilbert'), + (u'manager', u'dogbert', u'engineer', u'wally'), + (u'manager', u'dogbert', u'boss', u'pointy haired boss')] + ) + + eq_( + sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(), + [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'), + (u'dogbert', u'review #3'), + (u'pointy haired boss', u'review #1'), + (u'vlad', u'elbonian missive #3'), + (u'wally', u'tps report #3'), + (u'wally', u'tps report #4'), + ] + ) + + if select_type != '': + eq_( + sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(), + [(1, )] + ) + + eq_( + sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(), + [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)] + ) + + eq_( + sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(), + [(u'Elbonia, Inc.', 
1), (u'MegaCorp, Inc.', 4)] + ) + + + PolymorphicQueryTest.__name__ = "Polymorphic%sTest" % select_type + return PolymorphicQueryTest + +for select_type in ('', 'Polymorphic', 'Unions', 'AliasedJoins', 'Joins'): + testclass = _produce_test(select_type) + exec("%s = testclass" % testclass.__name__) + +del testclass + +class SelfReferentialTestJoinedToBase(_base.MappedTest): + run_setup_mappers = 'once' + + @classmethod + def define_tables(cls, metadata): + global people, engineers + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('primary_language', String(50)), + Column('reports_to_id', Integer, ForeignKey('people.person_id')) + ) + + @classmethod + def setup_mappers(cls): + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') + mapper(Engineer, engineers, inherits=Person, + inherit_condition=engineers.c.person_id==people.c.person_id, + polymorphic_identity='engineer', properties={ + 'reports_to':relation(Person, primaryjoin=people.c.person_id==engineers.c.reports_to_id) + }) + + def test_has(self): + + p1 = Person(name='dogbert') + e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1) + sess = create_session() + sess.add(p1) + sess.add(e1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert')) + + def test_oftype_aliases_in_exists(self): + e1 = Engineer(name='dilbert', primary_language='java') + e2 = Engineer(name='wally', primary_language='c++', reports_to=e1) + sess = create_session() + sess.add_all([e1, e2]) + sess.flush() + + eq_(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2) + + def 
test_join(self): + p1 = Person(name='dogbert') + e1 = Engineer(name='dilbert', primary_language='java', reports_to=p1) + sess = create_session() + sess.add(p1) + sess.add(e1) + sess.flush() + sess.expunge_all() + + eq_( + sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), + Engineer(name='dilbert')) + +class SelfReferentialJ2JTest(_base.MappedTest): + run_setup_mappers = 'once' + + @classmethod + def define_tables(cls, metadata): + global people, engineers, managers + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('primary_language', String(50)), + Column('reports_to_id', Integer, ForeignKey('managers.person_id')) + ) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + ) + + @classmethod + def setup_mappers(cls): + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') + mapper(Manager, managers, inherits=Person, polymorphic_identity='manager') + + mapper(Engineer, engineers, inherits=Person, + polymorphic_identity='engineer', properties={ + 'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id, backref='engineers') + }) + + def test_has(self): + + m1 = Manager(name='dogbert') + e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) + sess = create_session() + sess.add(m1) + sess.add(e1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert')) + + def test_join(self): + m1 = Manager(name='dogbert') + e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) + sess = create_session() + 
sess.add(m1) + sess.add(e1) + sess.flush() + sess.expunge_all() + + eq_( + sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), + Engineer(name='dilbert')) + + def test_filter_aliasing(self): + m1 = Manager(name='dogbert') + m2 = Manager(name='foo') + e1 = Engineer(name='wally', primary_language='java', reports_to=m1) + e2 = Engineer(name='dilbert', primary_language='c++', reports_to=m2) + e3 = Engineer(name='etc', primary_language='c++') + sess = create_session() + sess.add_all([m1, m2, e1, e2, e3]) + sess.flush() + sess.expunge_all() + + # filter aliasing applied to Engineer doesn't whack Manager + eq_( + sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(), + [m1] + ) + + eq_( + sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(), + [m2] + ) + + eq_( + sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(), + [ + (m2, e2), + (m1, e1), + ] + ) + + def test_relation_compare(self): + m1 = Manager(name='dogbert') + m2 = Manager(name='foo') + e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1) + e2 = Engineer(name='wally', primary_language='c++', reports_to=m2) + e3 = Engineer(name='etc', primary_language='c++') + sess = create_session() + sess.add(m1) + sess.add(m2) + sess.add(e1) + sess.add(e2) + sess.add(e3) + sess.flush() + sess.expunge_all() + + eq_( + sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), + [] + ) + + eq_( + sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), + [m1] + ) + + + +class M2MFilterTest(_base.MappedTest): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + global people, engineers, organizations, engineers_to_org + + organizations = Table('organizations', metadata, + Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True), + 
Column('name', String(50)), + ) + engineers_to_org = Table('engineers_org', metadata, + Column('org_id', Integer, ForeignKey('organizations.id')), + Column('engineer_id', Integer, ForeignKey('engineers.person_id')), + ) + + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('primary_language', String(50)), + ) + + @classmethod + def setup_mappers(cls): + global Organization + class Organization(_fixtures.Base): + pass + + mapper(Organization, organizations, properties={ + 'engineers':relation(Engineer, secondary=engineers_to_org, backref='organizations') + }) + + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') + mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') + + @classmethod + def insert_data(cls): + e1 = Engineer(name='e1') + e2 = Engineer(name='e2') + e3 = Engineer(name='e3') + e4 = Engineer(name='e4') + org1 = Organization(name='org1', engineers=[e1, e2]) + org2 = Organization(name='org2', engineers=[e3, e4]) + + sess = create_session() + sess.add(org1) + sess.add(org2) + sess.flush() + + def test_not_contains(self): + sess = create_session() + + e1 = sess.query(Person).filter(Engineer.name=='e1').one() + + # this works + eq_(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')]) + + # this had a bug + eq_(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')]) + + def test_any(self): + sess = create_session() + eq_(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')]) + 
eq_(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')]) + +class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): + run_setup_mappers = 'once' + + @classmethod + def define_tables(cls, metadata): + global Parent, Child1, Child2 + + Base = declarative_base(metadata=metadata) + + secondary_table = Table('secondary', Base.metadata, + Column('left_id', Integer, ForeignKey('parent.id'), nullable=False), + Column('right_id', Integer, ForeignKey('parent.id'), nullable=False)) + + class Parent(Base): + __tablename__ = 'parent' + id = Column(Integer, primary_key=True) + cls = Column(String(50)) + __mapper_args__ = dict(polymorphic_on = cls ) + + class Child1(Parent): + __tablename__ = 'child1' + id = Column(Integer, ForeignKey('parent.id'), primary_key=True) + __mapper_args__ = dict(polymorphic_identity = 'child1') + + class Child2(Parent): + __tablename__ = 'child2' + id = Column(Integer, ForeignKey('parent.id'), primary_key=True) + __mapper_args__ = dict(polymorphic_identity = 'child2') + + Child1.left_child2 = relation(Child2, secondary = secondary_table, + primaryjoin = Parent.id == secondary_table.c.right_id, + secondaryjoin = Parent.id == secondary_table.c.left_id, + uselist = False, backref="right_children" + ) + + + def test_query_crit(self): + session = create_session() + c11, c12, c13 = Child1(), Child1(), Child1() + c21, c22, c23 = Child2(), Child2(), Child2() + + c11.left_child2 = c22 + c12.left_child2 = c22 + c13.left_child2 = c23 + + session.add_all([c11, c12, c13, c21, c22, c23]) + session.flush() + + # test that the join to Child2 doesn't alias Child1 in the select + eq_( + set(session.query(Child1).join(Child1.left_child2)), + set([c11, c12, c13]) + ) + + eq_( + set(session.query(Child1, Child2).join(Child1.left_child2)), + set([(c11, c22), (c12, c22), (c13, c23)]) + ) + + # test __eq__() on property is annotating correctly + eq_( + 
set(session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22)), + set([c22]) + ) + + # test the same again + self.assert_compile( + session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22).with_labels().statement, + "SELECT parent.id AS parent_id, child2.id AS child2_id, parent.cls AS parent_cls FROM " + "secondary AS secondary_1, parent JOIN child2 ON parent.id = child2.id JOIN secondary AS secondary_2 " + "ON parent.id = secondary_2.left_id JOIN (SELECT parent.id AS parent_id, parent.cls AS parent_cls, " + "child1.id AS child1_id FROM parent JOIN child1 ON parent.id = child1.id) AS anon_1 ON " + "anon_1.parent_id = secondary_2.right_id WHERE anon_1.parent_id = secondary_1.right_id AND :param_1 = secondary_1.left_id", + dialect=default.DefaultDialect() + ) + + def test_eager_join(self): + session = create_session() + + c1 = Child1() + c1.left_child2 = Child2() + session.add(c1) + session.flush() + + q = session.query(Child1).options(eagerload('left_child2')) + + # test that the splicing of the join works here, doesnt break in the middle of "parent join child1" + self.assert_compile(q.limit(1).with_labels().statement, + "SELECT anon_1.parent_id AS anon_1_parent_id, anon_1.child1_id AS anon_1_child1_id, "\ + "anon_1.parent_cls AS anon_1_parent_cls, anon_2.parent_id AS anon_2_parent_id, "\ + "anon_2.child2_id AS anon_2_child2_id, anon_2.parent_cls AS anon_2_parent_cls FROM "\ + "(SELECT parent.id AS parent_id, child1.id AS child1_id, parent.cls AS parent_cls FROM parent "\ + "JOIN child1 ON parent.id = child1.id LIMIT 1) AS anon_1 LEFT OUTER JOIN secondary AS secondary_1 "\ + "ON anon_1.parent_id = secondary_1.right_id LEFT OUTER JOIN (SELECT parent.id AS parent_id, "\ + "parent.cls AS parent_cls, child2.id AS child2_id FROM parent JOIN child2 ON parent.id = child2.id) "\ + "AS anon_2 ON anon_2.parent_id = secondary_1.left_id" + , dialect=default.DefaultDialect()) + + # another way to check + assert 
q.limit(1).with_labels().subquery().count().scalar() == 1 + + assert q.first() is c1 + diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py new file mode 100644 index 000000000..a151af4fa --- /dev/null +++ b/test/orm/inheritance/test_selects.py @@ -0,0 +1,51 @@ +from sqlalchemy import * +from sqlalchemy.orm import * + +from sqlalchemy.test import testing +from test.orm._fixtures import Base +from test.orm._base import MappedTest + + +class InheritingSelectablesTest(MappedTest): + @classmethod + def define_tables(cls, metadata): + global foo, bar, baz + foo = Table('foo', metadata, + Column('a', String(30), primary_key=1), + Column('b', String(30), nullable=0)) + + bar = foo.select(foo.c.b == 'bar').alias('bar') + baz = foo.select(foo.c.b == 'baz').alias('baz') + + def test_load(self): + # TODO: add persistence test also + testing.db.execute(foo.insert(), a='not bar', b='baz') + testing.db.execute(foo.insert(), a='also not bar', b='baz') + testing.db.execute(foo.insert(), a='i am bar', b='bar') + testing.db.execute(foo.insert(), a='also bar', b='bar') + + class Foo(Base): pass + class Bar(Foo): pass + class Baz(Foo): pass + + mapper(Foo, foo, polymorphic_on=foo.c.b) + + mapper(Baz, baz, + with_polymorphic=('*', foo.join(baz, foo.c.b=='baz').alias('baz')), + inherits=Foo, + inherit_condition=(foo.c.a==baz.c.a), + inherit_foreign_keys=[baz.c.a], + polymorphic_identity='baz') + + mapper(Bar, bar, + with_polymorphic=('*', foo.join(bar, foo.c.b=='bar').alias('bar')), + inherits=Foo, + inherit_condition=(foo.c.a==bar.c.a), + inherit_foreign_keys=[bar.c.a], + polymorphic_identity='bar') + + s = sessionmaker(bind=testing.db)() + + assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all() + assert [Bar(), Bar()] == s.query(Bar).all() + diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py new file mode 100644 index 000000000..705826885 --- /dev/null +++ b/test/orm/inheritance/test_single.py @@ -0,0 +1,400 
@@ +from sqlalchemy.test.testing import eq_ +from sqlalchemy import * +from sqlalchemy.orm import * + +from sqlalchemy.test import testing +from test.orm import _fixtures +from test.orm._base import MappedTest, ComparableEntity + + +class SingleInheritanceTest(MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('manager_data', String(50)), + Column('engineer_info', String(50)), + Column('type', String(20))) + + Table('reports', metadata, + Column('report_id', Integer, primary_key=True), + Column('employee_id', ForeignKey('employees.employee_id')), + Column('name', String(50)), + ) + + @classmethod + def setup_classes(cls): + class Employee(ComparableEntity): + pass + class Manager(Employee): + pass + class Engineer(Employee): + pass + class JuniorEngineer(Engineer): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Employee, employees, polymorphic_on=employees.c.type) + mapper(Manager, inherits=Employee, polymorphic_identity='manager') + mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') + + @testing.resolve_artifact_names + def test_single_inheritance(self): + + session = create_session() + + m1 = Manager(name='Tom', manager_data='knows how to manage things') + e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') + session.add_all([m1, e1, e2]) + session.flush() + + assert session.query(Employee).all() == [m1, e1, e2] + assert session.query(Engineer).all() == [e1, e2] + assert session.query(Manager).all() == [m1] + assert session.query(JuniorEngineer).all() == [e2] + + m1 = session.query(Manager).one() + session.expire(m1, ['manager_data']) + eq_(m1.manager_data, "knows how to manage things") + + row = 
session.query(Engineer.name, Engineer.employee_id).filter(Engineer.name=='Kurt').first() + assert row.name == 'Kurt' + assert row.employee_id == e1.employee_id + + @testing.resolve_artifact_names + def test_multi_qualification(self): + session = create_session() + + m1 = Manager(name='Tom', manager_data='knows how to manage things') + e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') + + session.add_all([m1, e1, e2]) + session.flush() + + ealias = aliased(Engineer) + eq_( + session.query(Manager, ealias).all(), + [(m1, e1), (m1, e2)] + ) + + eq_( + session.query(Manager.name).all(), + [("Tom",)] + ) + + eq_( + session.query(Manager.name, ealias.name).all(), + [("Tom", "Kurt"), ("Tom", "Ed")] + ) + + eq_( + session.query(func.upper(Manager.name), func.upper(ealias.name)).all(), + [("TOM", "KURT"), ("TOM", "ED")] + ) + + eq_( + session.query(Manager).add_entity(ealias).all(), + [(m1, e1), (m1, e2)] + ) + + eq_( + session.query(Manager.name).add_column(ealias.name).all(), + [("Tom", "Kurt"), ("Tom", "Ed")] + ) + + # TODO: I think raise error on this for now + # self.assertEquals( + # session.query(Employee.name, Manager.manager_data, Engineer.engineer_info).all(), + # [] + # ) + + @testing.resolve_artifact_names + def test_select_from(self): + sess = create_session() + m1 = Manager(name='Tom', manager_data='data1') + m2 = Manager(name='Tom2', manager_data='data2') + e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') + sess.add_all([m1, m2, e1, e2]) + sess.flush() + + eq_( + sess.query(Manager).select_from(employees.select().limit(10)).all(), + [m1, m2] + ) + + @testing.resolve_artifact_names + def test_count(self): + sess = create_session() + m1 = Manager(name='Tom', manager_data='data1') + m2 = Manager(name='Tom2', manager_data='data2') + e1 = Engineer(name='Kurt', engineer_info='data3') + e2 = JuniorEngineer(name='marvin', 
engineer_info='data4') + sess.add_all([m1, m2, e1, e2]) + sess.flush() + + eq_(sess.query(Manager).count(), 2) + eq_(sess.query(Engineer).count(), 2) + eq_(sess.query(Employee).count(), 4) + + eq_(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) + eq_(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) + + @testing.resolve_artifact_names + def test_type_filtering(self): + class Report(ComparableEntity): pass + + mapper(Report, reports, properties={ + 'employee': relation(Employee, backref='reports')}) + sess = create_session() + + m1 = Manager(name='Tom', manager_data='data1') + r1 = Report(employee=m1) + sess.add_all([m1, r1]) + sess.flush() + rq = sess.query(Report) + + assert len(rq.filter(Report.employee.of_type(Manager).has()).all()) == 1 + assert len(rq.filter(Report.employee.of_type(Engineer).has()).all()) == 0 + + @testing.resolve_artifact_names + def test_type_joins(self): + class Report(ComparableEntity): pass + + mapper(Report, reports, properties={ + 'employee': relation(Employee, backref='reports')}) + sess = create_session() + + m1 = Manager(name='Tom', manager_data='data1') + r1 = Report(employee=m1) + sess.add_all([m1, r1]) + sess.flush() + + rq = sess.query(Report) + + assert len(rq.join(Report.employee.of_type(Manager)).all()) == 1 + assert len(rq.join(Report.employee.of_type(Engineer)).all()) == 0 + + +class RelationToSingleTest(MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('employees', metadata, + Column('employee_id', Integer, primary_key=True), + Column('name', String(50)), + Column('manager_data', String(50)), + Column('engineer_info', String(50)), + Column('type', String(20)), + Column('company_id', Integer, ForeignKey('companies.company_id')) + ) + + Table('companies', metadata, + Column('company_id', Integer, primary_key=True), + Column('name', String(50)), + ) + + @classmethod + def setup_classes(cls): + class Company(ComparableEntity): + pass + + class Employee(ComparableEntity): + 
pass + class Manager(Employee): + pass + class Engineer(Employee): + pass + class JuniorEngineer(Engineer): + pass + + @testing.resolve_artifact_names + def test_of_type(self): + mapper(Company, companies, properties={ + 'employees':relation(Employee, backref='company') + }) + mapper(Employee, employees, polymorphic_on=employees.c.type) + mapper(Manager, inherits=Employee, polymorphic_identity='manager') + mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') + sess = sessionmaker()() + + c1 = Company(name='c1') + c2 = Company(name='c2') + + m1 = Manager(name='Tom', manager_data='data1', company=c1) + m2 = Manager(name='Tom2', manager_data='data2', company=c2) + e1 = Engineer(name='Kurt', engineer_info='knows how to hack', company=c2) + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1) + sess.add_all([c1, c2, m1, m2, e1, e2]) + sess.commit() + sess.expunge_all() + eq_( + sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(), + [ + Company(name='c1'), + ] + ) + + eq_( + sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(), + [ + Company(name='c1'), + ] + ) + + + @testing.resolve_artifact_names + def test_relation_to_subclass(self): + mapper(Company, companies, properties={ + 'engineers':relation(Engineer) + }) + mapper(Employee, employees, polymorphic_on=employees.c.type, properties={ + 'company':relation(Company) + }) + mapper(Manager, inherits=Employee, polymorphic_identity='manager') + mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') + mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') + sess = sessionmaker()() + + c1 = Company(name='c1') + c2 = Company(name='c2') + + m1 = Manager(name='Tom', manager_data='data1', company=c1) + m2 = Manager(name='Tom2', manager_data='data2', company=c2) + e1 = Engineer(name='Kurt', engineer_info='knows how to hack', 
company=c2) + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1) + sess.add_all([c1, c2, m1, m2, e1, e2]) + sess.commit() + + eq_(c1.engineers, [e2]) + eq_(c2.engineers, [e1]) + + sess.expunge_all() + eq_(sess.query(Company).order_by(Company.name).all(), + [ + Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), + Company(name='c2', engineers=[Engineer(name='Kurt')]) + ] + ) + + # eager load join should limit to only "Engineer" + sess.expunge_all() + eq_(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), + [ + Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), + Company(name='c2', engineers=[Engineer(name='Kurt')]) + ] + ) + + # join() to Company.engineers, Employee as the requested entity + sess.expunge_all() + eq_(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(), + [ + (Company(name='c1'), JuniorEngineer(name='Ed')), + (Company(name='c2'), Engineer(name='Kurt')) + ] + ) + + # join() to Company.engineers, Engineer as the requested entity. + # this actually applies the IN criterion twice which is less than ideal. + sess.expunge_all() + eq_(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(), + [ + (Company(name='c1'), JuniorEngineer(name='Ed')), + (Company(name='c2'), Engineer(name='Kurt')) + ] + ) + + # join() to Company.engineers without any Employee/Engineer entity + sess.expunge_all() + eq_(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), + [ + Company(name='c2') + ] + ) + + # this however fails as it does not limit the subtypes to just "Engineer". + # with joins constructed by filter(), we seem to be following a policy where + # we don't try to make decisions on how to join to the target class, whereas when using join() we + # seem to have a lot more capabilities. + # we might want to document "advantages of join() vs. 
straight filtering", or add a large + # section to "inheritance" laying out all the various behaviors Query has. + @testing.fails_on_everything_except() + def go(): + sess.expunge_all() + eq_(sess.query(Company).\ + filter(Company.company_id==Engineer.company_id).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), + [ + Company(name='c2') + ] + ) + go() + +class SingleOnJoinedTest(MappedTest): + @classmethod + def define_tables(cls, metadata): + global persons_table, employees_table + + persons_table = Table('persons', metadata, + Column('person_id', Integer, primary_key=True), + Column('name', String(50)), + Column('type', String(20), nullable=False) + ) + + employees_table = Table('employees', metadata, + Column('person_id', Integer, ForeignKey('persons.person_id'),primary_key=True), + Column('employee_data', String(50)), + Column('manager_data', String(50)), + ) + + def test_single_on_joined(self): + class Person(_fixtures.Base): + pass + class Employee(Person): + pass + class Manager(Employee): + pass + + mapper(Person, persons_table, polymorphic_on=persons_table.c.type, polymorphic_identity='person') + mapper(Employee, employees_table, inherits=Person,polymorphic_identity='engineer') + mapper(Manager, inherits=Employee,polymorphic_identity='manager') + + sess = create_session() + sess.add(Person(name='p1')) + sess.add(Employee(name='e1', employee_data='ed1')) + sess.add(Manager(name='m1', employee_data='ed2', manager_data='md1')) + sess.flush() + sess.expunge_all() + + eq_(sess.query(Person).order_by(Person.person_id).all(), [ + Person(name='p1'), + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.expunge_all() + + eq_(sess.query(Employee).order_by(Person.person_id).all(), [ + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + sess.expunge_all() + + eq_(sess.query(Manager).order_by(Person.person_id).all(), [ + Manager(name='m1', 
employee_data='ed2', manager_data='md1') + ]) + sess.expunge_all() + + def go(): + eq_(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ + Person(name='p1'), + Employee(name='e1', employee_data='ed1'), + Manager(name='m1', employee_data='ed2', manager_data='md1') + ]) + self.assert_sql_count(testing.db, go, 1) + diff --git a/test/orm/instrumentation.py b/test/orm/instrumentation.py deleted file mode 100644 index fd15420d0..000000000 --- a/test/orm/instrumentation.py +++ /dev/null @@ -1,765 +0,0 @@ -import testenv; testenv.configure_for_tests() - -from testlib import sa -from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util -from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers -from testlib.testing import eq_, ne_ -from testlib.compat import _function_named -from orm import _base - - -def modifies_instrumentation_finders(fn): - def decorated(*args, **kw): - pristine = attributes.instrumentation_finders[:] - try: - fn(*args, **kw) - finally: - del attributes.instrumentation_finders[:] - attributes.instrumentation_finders.extend(pristine) - return _function_named(decorated, fn.func_name) - -def with_lookup_strategy(strategy): - def decorate(fn): - def wrapped(*args, **kw): - try: - attributes._install_lookup_strategy(strategy) - return fn(*args, **kw) - finally: - attributes._install_lookup_strategy(sa.util.symbol('native')) - return _function_named(wrapped, fn.func_name) - return decorate - - -class InitTest(_base.ORMTest): - def fixture(self): - return Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column('type', Integer), - Column('x', Integer), - Column('y', Integer)) - - def register(self, cls, canary): - original_init = cls.__init__ - attributes.register_class(cls) - ne_(cls.__init__, original_init) - manager = attributes.manager_of_class(cls) - def on_init(state, instance, args, kwargs): - canary.append((cls, 'on_init', type(instance))) - 
manager.events.add_listener('on_init', on_init) - - def test_ai(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - - obj = A() - eq_(inits, [(A, '__init__')]) - - def test_A(self): - inits = [] - - class A(object): pass - self.register(A, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A)]) - - def test_Ai(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - def test_ai_B(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - - class B(A): pass - self.register(B, inits) - - obj = A() - eq_(inits, [(A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B), (A, '__init__')]) - - def test_ai_Bi(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - - class B(A): - def __init__(self): - inits.append((B, '__init__')) - super(B, self).__init__() - self.register(B, inits) - - obj = A() - eq_(inits, [(A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')]) - - def test_Ai_bi(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): - def __init__(self): - inits.append((B, '__init__')) - super(B, self).__init__() - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')]) - - def test_Ai_Bi(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): - def __init__(self): - inits.append((B, '__init__')) - super(B, self).__init__() - self.register(B, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 
'on_init', B), (B, '__init__'), (A, '__init__')]) - - def test_Ai_B(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): pass - self.register(B, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B), (A, '__init__')]) - - def test_Ai_Bi_Ci(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): - def __init__(self): - inits.append((B, '__init__')) - super(B, self).__init__() - self.register(B, inits) - - class C(B): - def __init__(self): - inits.append((C, '__init__')) - super(C, self).__init__() - self.register(C, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'), - (A, '__init__')]) - - def test_Ai_bi_Ci(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): - def __init__(self): - inits.append((B, '__init__')) - super(B, self).__init__() - - class C(B): - def __init__(self): - inits.append((C, '__init__')) - super(C, self).__init__() - self.register(C, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'), - (A, '__init__')]) - - def test_Ai_b_Ci(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): pass - - class C(B): - def __init__(self): - inits.append((C, '__init__')) - super(C, self).__init__() - self.register(C, inits) - - obj = A() 
- eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(A, 'on_init', B), (A, '__init__')]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')]) - - def test_Ai_B_Ci(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): pass - self.register(B, inits) - - class C(B): - def __init__(self): - inits.append((C, '__init__')) - super(C, self).__init__() - self.register(C, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B), (A, '__init__')]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')]) - - def test_Ai_B_C(self): - inits = [] - - class A(object): - def __init__(self): - inits.append((A, '__init__')) - self.register(A, inits) - - class B(A): pass - self.register(B, inits) - - class C(B): pass - self.register(C, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A), (A, '__init__')]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B), (A, '__init__')]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C), (A, '__init__')]) - - def test_A_Bi_C(self): - inits = [] - - class A(object): pass - self.register(A, inits) - - class B(A): - def __init__(self): - inits.append((B, '__init__')) - self.register(B, inits) - - class C(B): pass - self.register(C, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A)]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B), (B, '__init__')]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C), (B, '__init__')]) - - def test_A_B_Ci(self): - inits = [] - - class A(object): pass - self.register(A, inits) - - class B(A): pass - self.register(B, inits) - - class C(B): - def __init__(self): - inits.append((C, '__init__')) - self.register(C, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A)]) - - del inits[:] - - 
obj = B() - eq_(inits, [(B, 'on_init', B)]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C), (C, '__init__')]) - - def test_A_B_C(self): - inits = [] - - class A(object): pass - self.register(A, inits) - - class B(A): pass - self.register(B, inits) - - class C(B): pass - self.register(C, inits) - - obj = A() - eq_(inits, [(A, 'on_init', A)]) - - del inits[:] - - obj = B() - eq_(inits, [(B, 'on_init', B)]) - - del inits[:] - obj = C() - eq_(inits, [(C, 'on_init', C)]) - - def test_defaulted_init(self): - class X(object): - def __init__(self_, a, b=123, c='abc'): - self_.a = a - self_.b = b - self_.c = c - attributes.register_class(X) - - o = X('foo') - eq_(o.a, 'foo') - eq_(o.b, 123) - eq_(o.c, 'abc') - - class Y(object): - unique = object() - - class OutOfScopeForEval(object): - def __repr__(self_): - # misleading repr - return '123' - - outofscope = OutOfScopeForEval() - - def __init__(self_, u=unique, o=outofscope): - self_.u = u - self_.o = o - - attributes.register_class(Y) - - o = Y() - assert o.u is Y.unique - assert o.o is Y.outofscope - - -class MapperInitTest(_base.ORMTest): - - def fixture(self): - return Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column('type', Integer), - Column('x', Integer), - Column('y', Integer)) - - def test_partially_mapped_inheritance(self): - class A(object): - pass - - class B(A): - pass - - class C(B): - def __init__(self, x): - pass - - m = mapper(A, self.fixture()) - - # B is not mapped in the current implementation - self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B) - - # C is not mapped in the current implementation - self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C) - -class InstrumentationCollisionTest(_base.ORMTest): - def test_none(self): - class A(object): pass - attributes.register_class(A) - - mgr_factory = lambda cls: attributes.ClassManager(cls) - class B(object): - __sa_instrumentation_manager__ = staticmethod(mgr_factory) - 
attributes.register_class(B) - - class C(object): - __sa_instrumentation_manager__ = attributes.ClassManager - attributes.register_class(C) - - def test_single_down(self): - class A(object): pass - attributes.register_class(A) - - mgr_factory = lambda cls: attributes.ClassManager(cls) - class B(A): - __sa_instrumentation_manager__ = staticmethod(mgr_factory) - - self.assertRaises(TypeError, attributes.register_class, B) - - def test_single_up(self): - - class A(object): pass - # delay registration - - mgr_factory = lambda cls: attributes.ClassManager(cls) - class B(A): - __sa_instrumentation_manager__ = staticmethod(mgr_factory) - attributes.register_class(B) - self.assertRaises(TypeError, attributes.register_class, A) - - def test_diamond_b1(self): - mgr_factory = lambda cls: attributes.ClassManager(cls) - - class A(object): pass - class B1(A): pass - class B2(A): - __sa_instrumentation_manager__ = mgr_factory - class C(object): pass - - self.assertRaises(TypeError, attributes.register_class, B1) - - def test_diamond_b2(self): - mgr_factory = lambda cls: attributes.ClassManager(cls) - - class A(object): pass - class B1(A): pass - class B2(A): - __sa_instrumentation_manager__ = mgr_factory - class C(object): pass - - self.assertRaises(TypeError, attributes.register_class, B2) - - def test_diamond_c_b(self): - mgr_factory = lambda cls: attributes.ClassManager(cls) - - class A(object): pass - class B1(A): pass - class B2(A): - __sa_instrumentation_manager__ = mgr_factory - class C(object): pass - - attributes.register_class(C) - self.assertRaises(TypeError, attributes.register_class, B1) - - -class OnLoadTest(_base.ORMTest): - """Check that Events.on_load is not hit in regular attributes operations.""" - - def test_basic(self): - import pickle - - global A - class A(object): - pass - - def canary(instance): assert False - - try: - attributes.register_class(A) - manager = attributes.manager_of_class(A) - manager.events.add_listener('on_load', canary) - - a = A() - p_a 
= pickle.dumps(a) - re_a = pickle.loads(p_a) - finally: - del A - - def tearDownAll(self): - clear_mappers() - attributes._install_lookup_strategy(util.symbol('native')) - - -class ExtendedEventsTest(_base.ORMTest): - """Allow custom Events implementations.""" - - @modifies_instrumentation_finders - def test_subclassed(self): - class MyEvents(attributes.Events): - pass - class MyClassManager(attributes.ClassManager): - event_registry_factory = MyEvents - - attributes.instrumentation_finders.insert(0, lambda cls: MyClassManager) - - class A(object): pass - - attributes.register_class(A) - manager = attributes.manager_of_class(A) - assert isinstance(manager.events, MyEvents) - - - -class NativeInstrumentationTest(_base.ORMTest): - @with_lookup_strategy(sa.util.symbol('native')) - def test_register_reserved_attribute(self): - class T(object): pass - - attributes.register_class(T) - manager = attributes.manager_of_class(T) - - sa = attributes.ClassManager.STATE_ATTR - ma = attributes.ClassManager.MANAGER_ATTR - - fails = lambda method, attr: self.assertRaises( - KeyError, getattr(manager, method), attr, property()) - - fails('install_member', sa) - fails('install_member', ma) - fails('install_descriptor', sa) - fails('install_descriptor', ma) - - @with_lookup_strategy(sa.util.symbol('native')) - def test_mapped_stateattr(self): - t = Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column(attributes.ClassManager.STATE_ATTR, Integer)) - - class T(object): pass - - self.assertRaises(KeyError, mapper, T, t) - - @with_lookup_strategy(sa.util.symbol('native')) - def test_mapped_managerattr(self): - t = Table('t', MetaData(), - Column('id', Integer, primary_key=True), - Column(attributes.ClassManager.MANAGER_ATTR, Integer)) - - class T(object): pass - self.assertRaises(KeyError, mapper, T, t) - - -class MiscTest(_base.ORMTest): - """Seems basic, but not directly covered elsewhere!""" - - def test_compileonattr(self): - t = Table('t', MetaData(), - 
Column('id', Integer, primary_key=True), - Column('x', Integer)) - class A(object): pass - mapper(A, t) - - a = A() - assert a.id is None - - def test_compileonattr_rel(self): - m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('t1.id'))) - class A(object): pass - class B(object): pass - mapper(A, t1, properties=dict(bs=relation(B))) - mapper(B, t2) - - a = A() - assert not a.bs - - def test_uninstrument(self): - class A(object):pass - - manager = attributes.register_class(A) - - assert attributes.manager_of_class(A) is manager - attributes.unregister_class(A) - assert attributes.manager_of_class(A) is None - - def test_compileonattr_rel_backref_a(self): - m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('t1.id'))) - - class Base(object): - def __init__(self, *args, **kwargs): - pass - - for base in object, Base: - class A(base): pass - class B(base): pass - mapper(A, t1, properties=dict(bs=relation(B, backref='a'))) - mapper(B, t2) - - b = B() - assert b.a is None - a = A() - b.a = a - - session = create_session() - session.add(b) - assert a in session, "base is %s" % base - - def test_compileonattr_rel_backref_b(self): - m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('x', Integer)) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, ForeignKey('t1.id'))) - - class Base(object): - def __init__(self): pass - class Base_AKW(object): - def __init__(self, *args, **kwargs): pass - - for base in object, Base, Base_AKW: - class A(base): pass - class B(base): pass - mapper(A, t1) - mapper(B, t2, properties=dict(a=relation(A, backref='bs'))) - - a = A() - b = B() - b.a = a - - 
session = create_session() - session.add(a) - assert b in session, 'base: %s' % base - - -class FinderTest(_base.ORMTest): - def test_standard(self): - class A(object): pass - - attributes.register_class(A) - - eq_(type(attributes.manager_of_class(A)), attributes.ClassManager) - - def test_nativeext_interfaceexact(self): - class A(object): - __sa_instrumentation_manager__ = sa.orm.interfaces.InstrumentationManager - - attributes.register_class(A) - ne_(type(attributes.manager_of_class(A)), attributes.ClassManager) - - def test_nativeext_submanager(self): - class Mine(attributes.ClassManager): pass - class A(object): - __sa_instrumentation_manager__ = Mine - - attributes.register_class(A) - eq_(type(attributes.manager_of_class(A)), Mine) - - @modifies_instrumentation_finders - def test_customfinder_greedy(self): - class Mine(attributes.ClassManager): pass - class A(object): pass - def find(cls): - return Mine - - attributes.instrumentation_finders.insert(0, find) - attributes.register_class(A) - eq_(type(attributes.manager_of_class(A)), Mine) - - @modifies_instrumentation_finders - def test_customfinder_pass(self): - class A(object): pass - def find(cls): - return None - - attributes.instrumentation_finders.insert(0, find) - attributes.register_class(A) - eq_(type(attributes.manager_of_class(A)), attributes.ClassManager) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py deleted file mode 100644 index b5c3b3669..000000000 --- a/test/orm/lazy_relations.py +++ /dev/null @@ -1,416 +0,0 @@ -"""basic tests of lazy loaded attributes""" - -import testenv; testenv.configure_for_tests() -import datetime -from sqlalchemy import exc as sa_exc -from sqlalchemy.orm import attributes -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base, _fixtures - - 
-class LazyTest(_fixtures.FixtureTest): - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def test_basic(self): - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=True) - }) - sess = create_session() - q = sess.query(User) - assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all() - - @testing.resolve_artifact_names - def test_needs_parent(self): - """test the error raised when parent object is not bound.""" - - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=True) - }) - sess = create_session() - q = sess.query(User) - u = q.filter(users.c.id == 7).first() - sess.expunge(u) - self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses') - - @testing.resolve_artifact_names - def test_orderby(self): - mapper(User, users, properties = { - 'addresses':relation(mapper(Address, addresses), lazy=True, order_by=addresses.c.email_address), - }) - q = create_session().query(User) - assert [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=2, email_address='ed@wood.com') - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ] == q.all() - - @testing.resolve_artifact_names - def test_orderby_secondary(self): - """tests that a regular mapper select on a single table can order by a relation to a second table""" - - mapper(Address, addresses) - - mapper(User, users, properties = dict( - addresses = relation(Address, lazy=True), - )) - q = create_session().query(User) - l = q.filter(users.c.id==addresses.c.user_id).order_by(addresses.c.email_address).all() - assert [ - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=3, email_address='ed@bettyboop.com'), - Address(id=4, email_address='ed@lala.com'), - ]), - 
User(id=9, addresses=[ - Address(id=5) - ]), - User(id=7, addresses=[ - Address(id=1) - ]), - ] == l - - @testing.resolve_artifact_names - def test_orderby_desc(self): - mapper(Address, addresses) - - mapper(User, users, properties = dict( - addresses = relation(Address, lazy=True, order_by=[sa.desc(addresses.c.email_address)]), - )) - sess = create_session() - assert [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=8, addresses=[ - Address(id=2, email_address='ed@wood.com'), - Address(id=4, email_address='ed@lala.com'), - Address(id=3, email_address='ed@bettyboop.com'), - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - User(id=10, addresses=[]) - ] == sess.query(User).all() - - @testing.resolve_artifact_names - def test_no_orphan(self): - """test that a lazily loaded child object is not marked as an orphan""" - - mapper(User, users, properties={ - 'addresses':relation(Address, cascade="all,delete-orphan", lazy=True) - }) - mapper(Address, addresses) - - sess = create_session() - user = sess.query(User).get(7) - assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True) - assert not sa.orm.class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0])) - - @testing.resolve_artifact_names - def test_limit(self): - """test limit operations combined with lazy-load relationships.""" - - mapper(Item, items) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, lazy=True) - }) - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=True), - 'orders':relation(Order, lazy=True) - }) - - sess = create_session() - q = sess.query(User) - - if testing.against('maxdb', 'mssql'): - l = q.limit(2).all() - assert self.static.user_all_result[:2] == l - else: - l = q.limit(2).offset(1).all() - assert self.static.user_all_result[1:3] == l - - @testing.resolve_artifact_names - def test_distinct(self): - mapper(Item, items) - mapper(Order, 
orders, properties={ - 'items':relation(Item, secondary=order_items, lazy=True) - }) - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=True), - 'orders':relation(Order, lazy=True) - }) - - sess = create_session() - q = sess.query(User) - - # use a union all to get a lot of rows to join against - u2 = users.alias('u2') - s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') - print [key for key in s.c.keys()] - l = q.filter(s.c.u2_id==User.id).distinct().all() - assert self.static.user_all_result == l - - @testing.resolve_artifact_names - def test_one_to_many_scalar(self): - mapper(User, users, properties = dict( - address = relation(mapper(Address, addresses), lazy=True, uselist=False) - )) - q = create_session().query(User) - l = q.filter(users.c.id == 7).all() - assert [User(id=7, address=Address(id=1))] == l - - @testing.resolve_artifact_names - def test_many_to_one_binds(self): - mapper(Address, addresses, primary_key=[addresses.c.user_id, addresses.c.email_address]) - - mapper(User, users, properties = dict( - address = relation(Address, uselist=False, - primaryjoin=sa.and_(users.c.id==addresses.c.user_id, addresses.c.email_address=='ed@bettyboop.com') - ) - )) - q = create_session().query(User) - eq_( - [ - User(id=7, address=None), - User(id=8, address=Address(id=3)), - User(id=9, address=None), - User(id=10, address=None), - ], - list(q) - ) - - - @testing.resolve_artifact_names - def test_double(self): - """tests lazy loading with two relations simulatneously, from the same table, using aliases. 
""" - - openorders = sa.alias(orders, 'openorders') - closedorders = sa.alias(orders, 'closedorders') - - mapper(Address, addresses) - - mapper(Order, orders) - - open_mapper = mapper(Order, openorders, non_primary=True) - closed_mapper = mapper(Order, closedorders, non_primary=True) - mapper(User, users, properties = dict( - addresses = relation(Address, lazy = True), - open_orders = relation(open_mapper, primaryjoin = sa.and_(openorders.c.isopen == 1, users.c.id==openorders.c.user_id), lazy=True), - closed_orders = relation(closed_mapper, primaryjoin = sa.and_(closedorders.c.isopen == 0, users.c.id==closedorders.c.user_id), lazy=True) - )) - q = create_session().query(User) - - assert [ - User( - id=7, - addresses=[Address(id=1)], - open_orders = [Order(id=3)], - closed_orders = [Order(id=1), Order(id=5)] - ), - User( - id=8, - addresses=[Address(id=2), Address(id=3), Address(id=4)], - open_orders = [], - closed_orders = [] - ), - User( - id=9, - addresses=[Address(id=5)], - open_orders = [Order(id=4)], - closed_orders = [Order(id=2)] - ), - User(id=10) - - ] == q.all() - - sess = create_session() - user = sess.query(User).get(7) - assert [Order(id=1), Order(id=5)] == create_session().query(closed_mapper).with_parent(user, property='closed_orders').all() - assert [Order(id=3)] == create_session().query(open_mapper).with_parent(user, property='open_orders').all() - - @testing.resolve_artifact_names - def test_many_to_many(self): - - mapper(Keyword, keywords) - mapper(Item, items, properties = dict( - keywords = relation(Keyword, secondary=item_keywords, lazy=True), - )) - - q = create_session().query(Item) - assert self.static.item_keyword_result == q.all() - - assert self.static.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all() - - @testing.resolve_artifact_names - def test_uses_get(self): - """test that a simple many-to-one lazyload optimizes to use query.get().""" - - for pj in ( - None, - users.c.id==addresses.c.user_id, - 
addresses.c.user_id==users.c.id - ): - mapper(Address, addresses, properties = dict( - user = relation(mapper(User, users), lazy=True, primaryjoin=pj) - )) - - sess = create_session() - - # load address - a1 = sess.query(Address).filter_by(email_address="ed@wood.com").one() - - # load user that is attached to the address - u1 = sess.query(User).get(8) - - def go(): - # lazy load of a1.user should get it from the session - assert a1.user is u1 - self.assert_sql_count(testing.db, go, 0) - sa.orm.clear_mappers() - - @testing.resolve_artifact_names - def test_many_to_one(self): - mapper(Address, addresses, properties = dict( - user = relation(mapper(User, users), lazy=True) - )) - sess = create_session() - q = sess.query(Address) - a = q.filter(addresses.c.id==1).one() - - assert a.user is not None - - u1 = sess.query(User).get(7) - - assert a.user is u1 - - @testing.resolve_artifact_names - def test_backrefs_dont_lazyload(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref='user') - }) - mapper(Address, addresses) - sess = create_session() - ad = sess.query(Address).filter_by(id=1).one() - assert ad.user.id == 7 - def go(): - ad.user = None - assert ad.user is None - self.assert_sql_count(testing.db, go, 0) - - u1 = sess.query(User).filter_by(id=7).one() - def go(): - assert ad not in u1.addresses - self.assert_sql_count(testing.db, go, 1) - - sess.expire(u1, ['addresses']) - def go(): - assert ad in u1.addresses - self.assert_sql_count(testing.db, go, 1) - - sess.expire(u1, ['addresses']) - ad2 = Address() - def go(): - ad2.user = u1 - assert ad2.user is u1 - self.assert_sql_count(testing.db, go, 0) - - def go(): - assert ad2 in u1.addresses - self.assert_sql_count(testing.db, go, 1) - - -class M2OGetTest(_fixtures.FixtureTest): - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def test_m2o_noload(self): - """test that a NULL foreign key doesn't trigger a lazy load""" - mapper(User, users) - - 
mapper(Address, addresses, properties={ - 'user':relation(User) - }) - - sess = create_session() - ad1 = Address(email_address='somenewaddress', id=12) - sess.add(ad1) - sess.flush() - sess.expunge_all() - - ad2 = sess.query(Address).get(1) - ad3 = sess.query(Address).get(ad1.id) - def go(): - # one lazy load - assert ad2.user.name == 'jack' - # no lazy load - assert ad3.user is None - self.assert_sql_count(testing.db, go, 1) - -class CorrelatedTest(_base.MappedTest): - - def define_tables(self, meta): - Table('user_t', meta, - Column('id', Integer, primary_key=True), - Column('name', String(50))) - - Table('stuff', meta, - Column('id', Integer, primary_key=True), - Column('date', sa.Date), - Column('user_id', Integer, ForeignKey('user_t.id'))) - - @testing.resolve_artifact_names - def insert_data(self): - user_t.insert().execute( - {'id':1, 'name':'user1'}, - {'id':2, 'name':'user2'}, - {'id':3, 'name':'user3'}) - - stuff.insert().execute( - {'id':1, 'user_id':1, 'date':datetime.date(2007, 10, 15)}, - {'id':2, 'user_id':1, 'date':datetime.date(2007, 12, 15)}, - {'id':3, 'user_id':1, 'date':datetime.date(2007, 11, 15)}, - {'id':4, 'user_id':2, 'date':datetime.date(2008, 1, 15)}, - {'id':5, 'user_id':3, 'date':datetime.date(2007, 6, 15)}) - - @testing.resolve_artifact_names - def test_correlated_lazyload(self): - class User(_base.ComparableEntity): - pass - - class Stuff(_base.ComparableEntity): - pass - - mapper(Stuff, stuff) - - stuff_view = sa.select([stuff.c.id]).where(stuff.c.user_id==user_t.c.id).correlate(user_t).order_by(sa.desc(stuff.c.date)).limit(1) - - mapper(User, user_t, properties={ - 'stuff':relation(Stuff, primaryjoin=sa.and_(user_t.c.id==stuff.c.user_id, stuff.c.id==(stuff_view.as_scalar()))) - }) - - sess = create_session() - - eq_(sess.query(User).all(), [ - User(name='user1', stuff=[Stuff(date=datetime.date(2007, 12, 15), id=2)]), - User(name='user2', stuff=[Stuff(id=4, date=datetime.date(2008, 1 , 15))]), - User(name='user3', stuff=[Stuff(id=5, 
date=datetime.date(2007, 6, 15))]) - ]) - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/lazytest1.py b/test/orm/lazytest1.py deleted file mode 100644 index 5ebb8feeb..000000000 --- a/test/orm/lazytest1.py +++ /dev/null @@ -1,90 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base - - -class LazyTest(_base.MappedTest): - def define_tables(self, metadata): - Table('infos', metadata, - Column('pk', Integer, primary_key=True), - Column('info', String(128))) - - Table('data', metadata, - Column('data_pk', Integer, primary_key=True), - Column('info_pk', Integer, - ForeignKey('infos.pk')), - Column('timeval', Integer), - Column('data_val', String(128))) - - Table('rels', metadata, - Column('rel_pk', Integer, primary_key=True), - Column('info_pk', Integer, - ForeignKey('infos.pk')), - Column('start', Integer), - Column('finish', Integer)) - - @testing.resolve_artifact_names - def insert_data(self): - infos.insert().execute( - {'pk':1, 'info':'pk_1_info'}, - {'pk':2, 'info':'pk_2_info'}, - {'pk':3, 'info':'pk_3_info'}, - {'pk':4, 'info':'pk_4_info'}, - {'pk':5, 'info':'pk_5_info'}) - - rels.insert().execute( - {'rel_pk':1, 'info_pk':1, 'start':10, 'finish':19}, - {'rel_pk':2, 'info_pk':1, 'start':100, 'finish':199}, - {'rel_pk':3, 'info_pk':2, 'start':20, 'finish':29}, - {'rel_pk':4, 'info_pk':3, 'start':13, 'finish':23}, - {'rel_pk':5, 'info_pk':5, 'start':15, 'finish':25}) - - data.insert().execute( - {'data_pk':1, 'info_pk':1, 'timeval':11, 'data_val':'11_data'}, - {'data_pk':2, 'info_pk':1, 'timeval':9, 'data_val':'9_data'}, - {'data_pk':3, 'info_pk':1, 'timeval':13, 'data_val':'13_data'}, - {'data_pk':4, 'info_pk':2, 'timeval':23, 'data_val':'23_data'}, - {'data_pk':5, 'info_pk':2, 'timeval':13, 'data_val':'13_data'}, - {'data_pk':6, 'info_pk':1, 'timeval':15, 
'data_val':'15_data'}) - - @testing.resolve_artifact_names - def testone(self): - """A lazy load which has multiple join conditions. - - Including two that are against the same column in the child table. - - """ - class Information(object): - pass - - class Relation(object): - pass - - class Data(object): - pass - - session = create_session() - - mapper(Data, data) - mapper(Relation, rels, properties={ - 'datas': relation(Data, - primaryjoin=sa.and_( - rels.c.info_pk == - data.c.info_pk, - data.c.timeval >= rels.c.start, - data.c.timeval <= rels.c.finish), - foreign_keys=[data.c.info_pk])}) - mapper(Information, infos, properties={ - 'rels': relation(Relation) - }) - - info = session.query(Information).get(1) - assert info - assert len(info.rels) == 2 - assert len(info.rels[0].datas) == 3 - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py deleted file mode 100644 index 23af3bd1f..000000000 --- a/test/orm/manytomany.py +++ /dev/null @@ -1,324 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base - - -class M2MTest(_base.MappedTest): - def define_tables(self, metadata): - Table('place', metadata, - Column('place_id', Integer, sa.Sequence('pid_seq', optional=True), - primary_key=True), - Column('name', String(30), nullable=False)) - - Table('transition', metadata, - Column('transition_id', Integer, - sa.Sequence('tid_seq', optional=True), primary_key=True), - Column('name', String(30), nullable=False)) - - Table('place_thingy', metadata, - Column('thingy_id', Integer, sa.Sequence('thid_seq', optional=True), - primary_key=True), - Column('place_id', Integer, ForeignKey('place.place_id'), - nullable=False), - Column('name', String(30), nullable=False)) - - # association table #1 - Table('place_input', metadata, - 
Column('place_id', Integer, ForeignKey('place.place_id')), - Column('transition_id', Integer, - ForeignKey('transition.transition_id'))) - - # association table #2 - Table('place_output', metadata, - Column('place_id', Integer, ForeignKey('place.place_id')), - Column('transition_id', Integer, - ForeignKey('transition.transition_id'))) - - Table('place_place', metadata, - Column('pl1_id', Integer, ForeignKey('place.place_id')), - Column('pl2_id', Integer, ForeignKey('place.place_id'))) - - def setup_classes(self): - class Place(_base.BasicEntity): - def __init__(self, name=None): - self.name = name - def __str__(self): - return "(Place '%s')" % self.name - __repr__ = __str__ - - class PlaceThingy(_base.BasicEntity): - def __init__(self, name=None): - self.name = name - - class Transition(_base.BasicEntity): - def __init__(self, name=None): - self.name = name - self.inputs = [] - self.outputs = [] - def __repr__(self): - return ' '.join((object.__repr__(self), - repr(self.inputs), - repr(self.outputs))) - - @testing.resolve_artifact_names - def test_error(self): - mapper(Place, place, properties={ - 'transitions':relation(Transition, secondary=place_input, backref='places') - }) - mapper(Transition, transition, properties={ - 'places':relation(Place, secondary=place_input, backref='transitions') - }) - self.assertRaisesMessage(sa.exc.ArgumentError, "Error creating backref", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_circular(self): - """test a many-to-many relationship from a table to itself.""" - - Place.mapper = mapper(Place, place) - - Place.mapper.add_property('places', relation( - Place.mapper, secondary=place_place, primaryjoin=place.c.place_id==place_place.c.pl1_id, - secondaryjoin=place.c.place_id==place_place.c.pl2_id, - order_by=place_place.c.pl2_id, - lazy=True, - )) - - sess = create_session() - p1 = Place('place1') - p2 = Place('place2') - p3 = Place('place3') - p4 = Place('place4') - p5 = Place('place5') - p6 = 
Place('place6') - p7 = Place('place7') - sess.add_all((p1, p2, p3, p4, p5, p6, p7)) - p1.places.append(p2) - p1.places.append(p3) - p5.places.append(p6) - p6.places.append(p1) - p7.places.append(p1) - p1.places.append(p5) - p4.places.append(p3) - p3.places.append(p4) - sess.flush() - - sess.expunge_all() - l = sess.query(Place).order_by(place.c.place_id).all() - (p1, p2, p3, p4, p5, p6, p7) = l - assert p1.places == [p2,p3,p5] - assert p5.places == [p6] - assert p7.places == [p1] - assert p6.places == [p1] - assert p4.places == [p3] - assert p3.places == [p4] - assert p2.places == [] - - for p in l: - pp = p.places - print "Place " + str(p) +" places " + repr(pp) - - [sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7] - sess.flush() - - @testing.resolve_artifact_names - def test_double(self): - """test that a mapper can have two eager relations to the same table, via - two different association tables. aliases are required.""" - - Place.mapper = mapper(Place, place, properties = { - 'thingies':relation(mapper(PlaceThingy, place_thingy), lazy=False) - }) - - Transition.mapper = mapper(Transition, transition, properties = dict( - inputs = relation(Place.mapper, place_output, lazy=False), - outputs = relation(Place.mapper, place_input, lazy=False), - ) - ) - - tran = Transition('transition1') - tran.inputs.append(Place('place1')) - tran.outputs.append(Place('place2')) - tran.outputs.append(Place('place3')) - sess = create_session() - sess.add(tran) - sess.flush() - - sess.expunge_all() - r = sess.query(Transition).all() - self.assert_unordered_result(r, Transition, - {'name': 'transition1', - 'inputs': (Place, [{'name':'place1'}]), - 'outputs': (Place, [{'name':'place2'}, {'name':'place3'}]) - }) - - @testing.resolve_artifact_names - def test_bidirectional(self): - """tests a many-to-many backrefs""" - Place.mapper = mapper(Place, place) - Transition.mapper = mapper(Transition, transition, properties = dict( - inputs = relation(Place.mapper, place_output, lazy=True, 
backref='inputs'), - outputs = relation(Place.mapper, place_input, lazy=True, backref='outputs'), - ) - ) - - t1 = Transition('transition1') - t2 = Transition('transition2') - t3 = Transition('transition3') - p1 = Place('place1') - p2 = Place('place2') - p3 = Place('place3') - - t1.inputs.append(p1) - t1.inputs.append(p2) - t1.outputs.append(p3) - t2.inputs.append(p1) - p2.inputs.append(t2) - p3.inputs.append(t2) - p1.outputs.append(t1) - sess = create_session() - sess.add_all((t1, t2, t3,p1, p2, p3)) - sess.flush() - - self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])}) - self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])}) - - -class M2MTest2(_base.MappedTest): - def define_tables(self, metadata): - Table('student', metadata, - Column('name', String(20), primary_key=True)) - - Table('course', metadata, - Column('name', String(20), primary_key=True)) - - Table('enroll', metadata, - Column('student_id', String(20), ForeignKey('student.name'), - primary_key=True), - Column('course_id', String(20), ForeignKey('course.name'), - primary_key=True)) - - def setup_classes(self): - class Student(_base.BasicEntity): - def __init__(self, name=''): - self.name = name - class Course(_base.BasicEntity): - def __init__(self, name=''): - self.name = name - - @testing.resolve_artifact_names - def test_circular(self): - - mapper(Student, student) - mapper(Course, course, properties={ - 'students': relation(Student, enroll, backref='courses')}) - - sess = create_session() - s1 = Student('Student1') - c1 = Course('Course1') - c2 = Course('Course2') - c3 = Course('Course3') - s1.courses.append(c1) - s1.courses.append(c2) - c3.students.append(s1) - self.assert_(len(s1.courses) == 3) - self.assert_(len(c1.students) == 1) - sess.add(s1) - sess.flush() - sess.expunge_all() - s = sess.query(Student).filter_by(name='Student1').one() - c = sess.query(Course).filter_by(name='Course3').one() - 
self.assert_(len(s.courses) == 3) - del s.courses[1] - self.assert_(len(s.courses) == 2) - - @testing.resolve_artifact_names - def test_dupliates_raise(self): - """test constraint error is raised for dupe entries in a list""" - - mapper(Student, student) - mapper(Course, course, properties={ - 'students': relation(Student, enroll, backref='courses')}) - - sess = create_session() - s1 = Student("s1") - c1 = Course('c1') - s1.courses.append(c1) - s1.courses.append(c1) - sess.add(s1) - self.assertRaises(sa.exc.DBAPIError, sess.flush) - - @testing.resolve_artifact_names - def test_delete(self): - """A many-to-many table gets cleared out with deletion from the backref side""" - - mapper(Student, student) - mapper(Course, course, properties = { - 'students': relation(Student, enroll, lazy=True, - backref='courses')}) - - sess = create_session() - s1 = Student('Student1') - c1 = Course('Course1') - c2 = Course('Course2') - c3 = Course('Course3') - s1.courses.append(c1) - s1.courses.append(c2) - c3.students.append(s1) - sess.add(s1) - sess.flush() - sess.delete(s1) - sess.flush() - assert enroll.count().scalar() == 0 - -class M2MTest3(_base.MappedTest): - def define_tables(self, metadata): - Table('c', metadata, - Column('c1', Integer, primary_key = True), - Column('c2', String(20))) - - Table('a', metadata, - Column('a1', Integer, primary_key=True), - Column('a2', String(20)), - Column('c1', Integer, ForeignKey('c.c1'))) - - Table('c2a1', metadata, - Column('c1', Integer, ForeignKey('c.c1')), - Column('a1', Integer, ForeignKey('a.a1'))) - - Table('c2a2', metadata, - Column('c1', Integer, ForeignKey('c.c1')), - Column('a1', Integer, ForeignKey('a.a1'))) - - Table('b', metadata, - Column('b1', Integer, primary_key=True), - Column('a1', Integer, ForeignKey('a.a1')), - Column('b2', sa.Boolean)) - - @testing.resolve_artifact_names - def test_basic(self): - class C(object):pass - class A(object):pass - class B(object):pass - - mapper(B, b) - - mapper(A, a, properties={ - 'tbs': 
relation(B, primaryjoin=sa.and_(b.c.a1 == a.c.a1, - b.c.b2 == True), - lazy=False)}) - - mapper(C, c, properties={ - 'a1s': relation(A, secondary=c2a1, lazy=False), - 'a2s': relation(A, secondary=c2a2, lazy=False)}) - - assert create_session().query(C).with_labels().statement - - # TODO: seems like just a test for an ancient exception throw. - # how about some data/inserts/queries/assertions for this one - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/mapper.py b/test/orm/mapper.py deleted file mode 100644 index 13e02a38a..000000000 --- a/test/orm/mapper.py +++ /dev/null @@ -1,2467 +0,0 @@ -"""General mapper operations with an emphasis on selecting/loading.""" - -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, func -from testlib.sa.engine import default -from testlib.sa.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased -from testlib.sa.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property -from testlib.testing import eq_, AssertsCompiledSQL -import pickleable -from orm import _base, _fixtures - - -class MapperTest(_fixtures.FixtureTest): - - @testing.resolve_artifact_names - def test_prop_shadow(self): - """A backref name may not shadow an existing property name.""" - - mapper(Address, addresses) - mapper(User, users, - properties={ - 'addresses':relation(Address, backref='email_address') - }) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_update_attr_keys(self): - """test that update()/insert() use the correct key when given InstrumentedAttributes.""" - - mapper(User, users, properties={ - 'foobar':users.c.name - }) - - users.insert().values({User.foobar:'name1'}).execute() - 
eq_(sa.select([User.foobar]).where(User.foobar=='name1').execute().fetchall(), [('name1',)]) - - users.update().values({User.foobar:User.foobar + 'foo'}).execute() - eq_(sa.select([User.foobar]).where(User.foobar=='name1foo').execute().fetchall(), [('name1foo',)]) - - @testing.resolve_artifact_names - def test_utils(self): - from sqlalchemy.orm.util import _is_mapped_class, _is_aliased_class - - class Foo(object): - x = "something" - @property - def y(self): - return "somethign else" - m = mapper(Foo, users) - a1 = aliased(Foo) - - f = Foo() - - for fn, arg, ret in [ - (_is_mapped_class, Foo.x, False), - (_is_mapped_class, Foo.y, False), - (_is_mapped_class, Foo, True), - (_is_mapped_class, f, False), - (_is_mapped_class, a1, True), - (_is_mapped_class, m, True), - (_is_aliased_class, a1, True), - (_is_aliased_class, Foo.x, False), - (_is_aliased_class, Foo.y, False), - (_is_aliased_class, Foo, False), - (_is_aliased_class, f, False), - (_is_aliased_class, a1, True), - (_is_aliased_class, m, False), - ]: - assert fn(arg) == ret - - - - @testing.resolve_artifact_names - def test_prop_accessor(self): - mapper(User, users) - self.assertRaises(NotImplementedError, - getattr, sa.orm.class_mapper(User), 'properties') - - - @testing.resolve_artifact_names - def test_bad_cascade(self): - mapper(Address, addresses) - self.assertRaises(sa.exc.ArgumentError, - relation, Address, cascade="fake, all, delete-orphan") - - @testing.resolve_artifact_names - def test_exceptions_sticky(self): - """test preservation of mapper compile errors raised during hasattr().""" - - mapper(Address, addresses, properties={ - 'user':relation(User) - }) - - hasattr(Address.user, 'property') - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) - - @testing.resolve_artifact_names - def test_column_prefix(self): - mapper(User, users, column_prefix='_', properties={ - 'user_name': synonym('_name') - }) - - s = create_session() - u = 
s.query(User).get(7) - eq_(u._name, 'jack') - eq_(u._id,7) - u2 = s.query(User).filter_by(user_name='jack').one() - assert u is u2 - - @testing.resolve_artifact_names - def test_no_pks_1(self): - s = sa.select([users.c.name]).alias('foo') - self.assertRaises(sa.exc.ArgumentError, mapper, User, s) - - @testing.emits_warning( - 'mapper Mapper|User|Select object creating an alias for ' - 'the given selectable - use Class attributes for queries') - @testing.resolve_artifact_names - def test_no_pks_2(self): - s = sa.select([users.c.name]) - self.assertRaises(sa.exc.ArgumentError, mapper, User, s) - - @testing.resolve_artifact_names - def test_recompile_on_other_mapper(self): - """A compile trigger on an already-compiled mapper still triggers a check against all mappers.""" - mapper(User, users) - sa.orm.compile_mappers() - assert sa.orm.mapperlib._new_mappers is False - - m = mapper(Address, addresses, properties={ - 'user': relation(User, backref="addresses")}) - - assert m.compiled is False - assert sa.orm.mapperlib._new_mappers is True - u = User() - assert User.addresses - assert sa.orm.mapperlib._new_mappers is False - - @testing.resolve_artifact_names - def test_compile_on_session(self): - m = mapper(User, users) - session = create_session() - session.connection(m) - - @testing.resolve_artifact_names - def test_incomplete_columns(self): - """Loading from a select which does not contain all columns""" - mapper(Address, addresses) - s = create_session() - a = s.query(Address).from_statement( - sa.select([addresses.c.id, addresses.c.user_id])).first() - eq_(a.user_id, 7) - eq_(a.id, 1) - # email address auto-defers - assert 'email_addres' not in a.__dict__ - eq_(a.email_address, 'jack@bean.com') - - @testing.resolve_artifact_names - def test_bad_constructor(self): - """If the construction of a mapped class fails, the instance does not get placed in the session""" - class Foo(object): - def __init__(self, one, two, _sa_session=None): - pass - - mapper(Foo, users, 
extension=sa.orm.scoped_session( - create_session).extension) - - sess = create_session() - self.assertRaises(TypeError, Foo, 'one', _sa_session=sess) - eq_(len(list(sess)), 0) - self.assertRaises(TypeError, Foo, 'one') - Foo('one', 'two', _sa_session=sess) - eq_(len(list(sess)), 1) - - @testing.resolve_artifact_names - def test_constructor_exc_1(self): - """Exceptions raised in the mapped class are not masked by sa decorations""" - ex = AssertionError('oops') - sess = create_session() - - class Foo(object): - def __init__(self, **kw): - raise ex - mapper(Foo, users) - - try: - Foo() - assert False - except Exception, e: - assert e is ex - - sa.orm.clear_mappers() - mapper(Foo, users, extension=sa.orm.scoped_session( - create_session).extension) - def bad_expunge(foo): - raise Exception("this exception should be stated as a warning") - - sess.expunge = bad_expunge - self.assertRaises(sa.exc.SAWarning, Foo, _sa_session=sess) - - @testing.resolve_artifact_names - def test_constructor_exc_2(self): - """TypeError is raised for illegal constructor args, whether or not explicit __init__ is present [ticket:908].""" - - class Foo(object): - def __init__(self): - pass - class Bar(object): - pass - - mapper(Foo, users) - mapper(Bar, addresses) - self.assertRaises(TypeError, Foo, x=5) - self.assertRaises(TypeError, Bar, x=5) - - @testing.resolve_artifact_names - def test_props(self): - m = mapper(User, users, properties = { - 'addresses' : relation(mapper(Address, addresses)) - }).compile() - assert User.addresses.property is m.get_property('addresses') - - @testing.resolve_artifact_names - def test_compile_on_prop_1(self): - mapper(User, users, properties = { - 'addresses' : relation(mapper(Address, addresses)) - }) - User.addresses.any(Address.email_address=='foo@bar.com') - - @testing.resolve_artifact_names - def test_compile_on_prop_2(self): - mapper(User, users, properties = { - 'addresses' : relation(mapper(Address, addresses)) - }) - eq_(str(User.id == 3), 
str(users.c.id==3)) - - @testing.resolve_artifact_names - def test_compile_on_prop_3(self): - class Foo(User):pass - mapper(User, users) - mapper(Foo, addresses, inherits=User) - assert getattr(Foo().__class__, 'name').impl is not None - - @testing.resolve_artifact_names - def test_deferred_subclass_attribute_instrument(self): - class Foo(User):pass - mapper(User, users) - compile_mappers() - mapper(Foo, addresses, inherits=User) - assert getattr(Foo().__class__, 'name').impl is not None - - @testing.resolve_artifact_names - def test_compile_on_get_props_1(self): - m =mapper(User, users) - assert not m.compiled - assert list(m.iterate_properties) - assert m.compiled - - @testing.resolve_artifact_names - def test_compile_on_get_props_2(self): - m= mapper(User, users) - assert not m.compiled - assert m.get_property('name') - assert m.compiled - - @testing.resolve_artifact_names - def test_add_property(self): - assert_col = [] - - class User(_base.ComparableEntity): - def _get_name(self): - assert_col.append(('get', self._name)) - return self._name - def _set_name(self, name): - assert_col.append(('set', name)) - self._name = name - name = property(_get_name, _set_name) - - def _uc_name(self): - if self._name is None: - return None - return self._name.upper() - uc_name = property(_uc_name) - uc_name2 = property(_uc_name) - - m = mapper(User, users) - mapper(Address, addresses) - - class UCComparator(sa.orm.PropComparator): - __hash__ = None - - def __eq__(self, other): - cls = self.prop.parent.class_ - col = getattr(cls, 'name') - if other is None: - return col == None - else: - return sa.func.upper(col) == sa.func.upper(other) - - m.add_property('_name', deferred(users.c.name)) - m.add_property('name', synonym('_name')) - m.add_property('addresses', relation(Address)) - m.add_property('uc_name', sa.orm.comparable_property(UCComparator)) - m.add_property('uc_name2', sa.orm.comparable_property( - UCComparator, User.uc_name2)) - - sess = create_session(autocommit=False) 
- assert sess.query(User).get(7) - - u = sess.query(User).filter_by(name='jack').one() - - def go(): - eq_(len(u.addresses), - len(self.static.user_address_result[0].addresses)) - eq_(u.name, 'jack') - eq_(u.uc_name, 'JACK') - eq_(u.uc_name2, 'JACK') - eq_(assert_col, [('get', 'jack')], str(assert_col)) - self.sql_count_(2, go) - - u.name = 'ed' - u3 = User() - u3.name = 'some user' - sess.add(u3) - sess.flush() - sess.rollback() - - @testing.resolve_artifact_names - def test_replace_property(self): - m = mapper(User, users) - m.add_property('_name',users.c.name) - m.add_property('name', synonym('_name', proxy=True)) - - sess = create_session() - u = sess.query(User).filter_by(name='jack').one() - eq_(u._name, 'jack') - eq_(u.name, 'jack') - u.name = 'jacko' - assert m._columntoproperty[users.c.name] is m.get_property('_name') - - sa.orm.clear_mappers() - - m = mapper(User, users) - m.add_property('name', synonym('_name', map_column=True)) - - sess.expunge_all() - u = sess.query(User).filter_by(name='jack').one() - eq_(u._name, 'jack') - eq_(u.name, 'jack') - u.name = 'jacko' - assert m._columntoproperty[users.c.name] is m.get_property('_name') - - @testing.resolve_artifact_names - def test_synonym_replaces_backref(self): - assert_calls = [] - class Address(object): - def _get_user(self): - assert_calls.append("get") - return self._user - def _set_user(self, user): - assert_calls.append("set") - self._user = user - user = property(_get_user, _set_user) - - # synonym is created against nonexistent prop - mapper(Address, addresses, properties={ - 'user':synonym('_user') - }) - sa.orm.compile_mappers() - - # later, backref sets up the prop - mapper(User, users, properties={ - 'addresses':relation(Address, backref='_user') - }) - - sess = create_session() - u1 = sess.query(User).get(7) - u2 = sess.query(User).get(8) - # comparaison ops need to work - a1 = sess.query(Address).filter(Address.user==u1).one() - eq_(a1.id, 1) - a1.user = u2 - assert a1.user is u2 - 
eq_(assert_calls, ["set", "get"]) - - @testing.resolve_artifact_names - def test_self_ref_synonym(self): - t = Table('nodes', MetaData(), - Column('id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('nodes.id'))) - - class Node(object): - pass - - mapper(Node, t, properties={ - '_children':relation(Node, backref=backref('_parent', remote_side=t.c.id)), - 'children':synonym('_children'), - 'parent':synonym('_parent') - }) - - n1 = Node() - n2 = Node() - n1.children.append(n2) - assert n2.parent is n2._parent is n1 - assert n1.children[0] is n1._children[0] is n2 - eq_(str(Node.parent == n2), ":param_1 = nodes.parent_id") - - @testing.resolve_artifact_names - def test_illegal_non_primary(self): - mapper(User, users) - mapper(Address, addresses) - try: - mapper(User, users, non_primary=True, properties={ - 'addresses':relation(Address) - }).compile() - assert False - except sa.exc.ArgumentError, e: - assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e) - - @testing.resolve_artifact_names - def test_illegal_non_primary_2(self): - try: - mapper(User, users, non_primary=True) - assert False - except sa.exc.InvalidRequestError, e: - assert "Configure a primary mapper first" in str(e) - - @testing.resolve_artifact_names - def test_prop_filters(self): - t = Table('person', MetaData(), - Column('id', Integer, primary_key=True), - Column('type', String(128)), - Column('name', String(128)), - Column('employee_number', Integer), - Column('boss_id', Integer, ForeignKey('person.id')), - Column('vendor_id', Integer)) - - class Person(object): pass - class Vendor(Person): pass - class Employee(Person): pass - class Manager(Employee): pass - class Hoho(object): pass - class Lala(object): pass - - class HasDef(object): - def name(self): - pass - - p_m = mapper(Person, t, polymorphic_on=t.c.type, - include_properties=('id', 'type', 'name')) - e_m = mapper(Employee, inherits=p_m, 
polymorphic_identity='employee', - properties={ - 'boss': relation(Manager, backref=backref('peon', ), remote_side=t.c.id) - }, - exclude_properties=('vendor_id',)) - - m_m = mapper(Manager, inherits=e_m, polymorphic_identity='manager', - include_properties=('id', 'type')) - - v_m = mapper(Vendor, inherits=p_m, polymorphic_identity='vendor', - exclude_properties=('boss_id', 'employee_number')) - h_m = mapper(Hoho, t, include_properties=('id', 'type', 'name')) - l_m = mapper(Lala, t, exclude_properties=('vendor_id', 'boss_id'), - column_prefix="p_") - - hd_m = mapper(HasDef, t, column_prefix="h_") - - p_m.compile() - #sa.orm.compile_mappers() - - def assert_props(cls, want): - have = set([n for n in dir(cls) if not n.startswith('_')]) - want = set(want) - eq_(have, want) - - def assert_instrumented(cls, want): - have = set([p.key for p in class_mapper(cls).iterate_properties]) - want = set(want) - eq_(have, want) - - assert_props(HasDef, ['h_boss_id', 'h_employee_number', 'h_id', 'name', 'h_name', 'h_vendor_id', 'h_type']) - assert_props(Person, ['id', 'name', 'type']) - assert_instrumented(Person, ['id', 'name', 'type']) - assert_props(Employee, ['boss', 'boss_id', 'employee_number', - 'id', 'name', 'type']) - assert_instrumented(Employee,['boss', 'boss_id', 'employee_number', - 'id', 'name', 'type']) - assert_props(Manager, ['boss', 'boss_id', 'employee_number', 'peon', - 'id', 'name', 'type']) - - # 'peon' and 'type' are both explicitly stated properties - assert_instrumented(Manager, ['peon', 'type', 'id']) - - assert_props(Vendor, ['vendor_id', 'id', 'name', 'type']) - assert_props(Hoho, ['id', 'name', 'type']) - assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type']) - - # excluding the discriminator column is currently not allowed - class Foo(Person): - pass - self.assertRaises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) ) - - @testing.resolve_artifact_names - def 
test_mapping_to_join(self): - """Mapping to a join""" - usersaddresses = sa.join(users, addresses, - users.c.id == addresses.c.user_id) - mapper(User, usersaddresses, primary_key=[users.c.id]) - l = create_session().query(User).order_by(users.c.id).all() - eq_(l, self.static.user_result[:3]) - - @testing.resolve_artifact_names - def test_mapping_to_join_no_pk(self): - m = mapper(Address, addresses.join(email_bounces)) - m.compile() - assert addresses in m._pks_by_table - assert email_bounces not in m._pks_by_table - - sess = create_session() - a = Address(id=10, email_address='e1') - sess.add(a) - sess.flush() - - eq_(addresses.count().scalar(), 6) - eq_(email_bounces.count().scalar(), 5) - - @testing.resolve_artifact_names - def test_mapping_to_outerjoin(self): - """Mapping to an outer join with a nullable composite primary key.""" - - - mapper(User, users.outerjoin(addresses), - allow_null_pks=True, - primary_key=[users.c.id, addresses.c.id], - properties=dict( - address_id=addresses.c.id)) - - session = create_session() - l = session.query(User).order_by(User.id, User.address_id).all() - - eq_(l, [ - User(id=7, address_id=1), - User(id=8, address_id=2), - User(id=8, address_id=3), - User(id=8, address_id=4), - User(id=9, address_id=5), - User(id=10, address_id=None)]) - - @testing.resolve_artifact_names - def test_custom_join(self): - """select_from totally replace the FROM parameters.""" - - mapper(Item, items) - - mapper(Order, orders, properties=dict( - items=relation(Item, order_items))) - - mapper(User, users, properties=dict( - orders=relation(Order))) - - session = create_session() - l = (session.query(User). - select_from(users.join(orders). - join(order_items). - join(items)). 
- filter(items.c.description == 'item 4')).all() - - eq_(l, [self.static.user_result[0]]) - - @testing.resolve_artifact_names - def test_cancel_order_by(self): - mapper(User, users, order_by=users.c.name.desc()) - - assert "order by users.name desc" in str(create_session().query(User).statement).lower() - assert "order by" not in str(create_session().query(User).order_by(None).statement).lower() - assert "order by users.name asc" in str(create_session().query(User).order_by(User.name.asc()).statement).lower() - - eq_( - create_session().query(User).all(), - [User(id=7, name=u'jack'), User(id=9, name=u'fred'), User(id=8, name=u'ed'), User(id=10, name=u'chuck')] - ) - - eq_( - create_session().query(User).order_by(User.name).all(), - [User(id=10, name=u'chuck'), User(id=8, name=u'ed'), User(id=9, name=u'fred'), User(id=7, name=u'jack')] - ) - - # 'Raises a "expression evaluation not supported" error at prepare time - @testing.fails_on('firebird', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_function(self): - """Mapping to a SELECT statement that has functions in it.""" - - s = sa.select([users, - (users.c.id * 2).label('concat'), - sa.func.count(addresses.c.id).label('count')], - users.c.id == addresses.c.user_id, - group_by=[c for c in users.c]).alias('myselect') - - mapper(User, s, order_by=s.c.id) - sess = create_session() - l = sess.query(User).all() - - for idx, total in enumerate((14, 16)): - eq_(l[idx].concat, l[idx].id * 2) - eq_(l[idx].concat, total) - - @testing.resolve_artifact_names - def test_count(self): - """The count function on Query.""" - - mapper(User, users) - - session = create_session() - q = session.query(User) - - eq_(q.count(), 4) - eq_(q.filter(User.id.in_([8,9])).count(), 2) - eq_(q.filter(users.c.id.in_([8,9])).count(), 2) - - eq_(session.query(User.id).count(), 4) - eq_(session.query(User.id).filter(User.id.in_((8, 9))).count(), 2) - - @testing.resolve_artifact_names - def test_many_to_many_count(self): - mapper(Keyword, 
keywords) - mapper(Item, items, properties=dict( - keywords = relation(Keyword, item_keywords, lazy=True))) - - session = create_session() - q = (session.query(Item). - join('keywords'). - distinct(). - filter(Keyword.name == "red")) - eq_(q.count(), 2) - - @testing.resolve_artifact_names - def test_override_1(self): - """Overriding a column raises an error.""" - def go(): - mapper(User, users, - properties=dict( - name=relation(mapper(Address, addresses)))) - - self.assertRaises(sa.exc.ArgumentError, go) - - @testing.resolve_artifact_names - def test_override_2(self): - """exclude_properties cancels the error.""" - - mapper(User, users, - exclude_properties=['name'], - properties=dict( - name=relation(mapper(Address, addresses)))) - - assert bool(User.name) - - @testing.resolve_artifact_names - def test_override_3(self): - """The column being named elsewhere also cancels the error,""" - mapper(User, users, - properties=dict( - name=relation(mapper(Address, addresses)), - foo=users.c.name)) - - @testing.resolve_artifact_names - def test_synonym(self): - - assert_col = [] - class extendedproperty(property): - attribute = 123 - def __getitem__(self, key): - return 'value' - - class User(object): - def _get_name(self): - assert_col.append(('get', self.name)) - return self.name - def _set_name(self, name): - assert_col.append(('set', name)) - self.name = name - uname = extendedproperty(_get_name, _set_name) - - mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=True), - uname = synonym('name'), - adlist = synonym('addresses', proxy=True), - adname = synonym('addresses') - )) - - # ensure the synonym can get at the proxied comparators without - # an explicit compile - User.name == 'ed' - User.adname.any() - - assert hasattr(User, 'adlist') - # as of 0.4.2, synonyms always create a property - assert hasattr(User, 'adname') - - # test compile - assert not isinstance(User.uname == 'jack', bool) - - assert User.uname.property - 
assert User.adlist.property - - sess = create_session() - - # test RowTuple names - row = sess.query(User.id, User.uname).first() - assert row.uname == row[1] - - u = sess.query(User).filter(User.uname=='jack').one() - - fixture = self.static.user_address_result[0].addresses - eq_(u.adlist, fixture) - - addr = sess.query(Address).filter_by(id=fixture[0].id).one() - u = sess.query(User).filter(User.adname.contains(addr)).one() - u2 = sess.query(User).filter(User.adlist.contains(addr)).one() - - assert u is u2 - - assert u not in sess.dirty - u.uname = "some user name" - assert len(assert_col) > 0 - eq_(assert_col, [('set', 'some user name')]) - eq_(u.uname, "some user name") - eq_(assert_col, [('set', 'some user name'), ('get', 'some user name')]) - eq_(u.name, "some user name") - assert u in sess.dirty - - eq_(User.uname.attribute, 123) - eq_(User.uname['key'], 'value') - - @testing.resolve_artifact_names - def test_synonym_column_location(self): - def go(): - mapper(User, users, properties={ - 'not_name':synonym('_name', map_column=True)}) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - ("Can't compile synonym '_name': no column on table " - "'users' named 'not_name'"), - go) - - @testing.resolve_artifact_names - def test_column_synonyms(self): - """Synonyms which automatically instrument properties, set up aliased column, etc.""" - - - assert_col = [] - class User(object): - def _get_name(self): - assert_col.append(('get', self._name)) - return self._name - def _set_name(self, name): - assert_col.append(('set', name)) - self._name = name - name = property(_get_name, _set_name) - - mapper(Address, addresses) - mapper(User, users, properties = { - 'addresses':relation(Address, lazy=True), - 'name':synonym('_name', map_column=True) - }) - - # test compile - assert not isinstance(User.name == 'jack', bool) - - assert hasattr(User, 'name') - assert hasattr(User, '_name') - - sess = create_session() - u = sess.query(User).filter(User.name == 'jack').one() - 
eq_(u.name, 'jack') - u.name = 'foo' - eq_(u.name, 'foo') - eq_(assert_col, [('get', 'jack'), ('set', 'foo'), ('get', 'foo')]) - - @testing.resolve_artifact_names - def test_comparable(self): - class extendedproperty(property): - attribute = 123 - - def method1(self): - return "method1" - - def __getitem__(self, key): - return 'value' - - class UCComparator(sa.orm.PropComparator): - __hash__ = None - - def method1(self): - return "uccmethod1" - - def method2(self, other): - return "method2" - - def __eq__(self, other): - cls = self.prop.parent.class_ - col = getattr(cls, 'name') - if other is None: - return col == None - else: - return sa.func.upper(col) == sa.func.upper(other) - - def map_(with_explicit_property): - class User(object): - @extendedproperty - def uc_name(self): - if self.name is None: - return None - return self.name.upper() - if with_explicit_property: - args = (UCComparator, User.uc_name) - else: - args = (UCComparator,) - - mapper(User, users, properties=dict( - uc_name = sa.orm.comparable_property(*args))) - return User - - for User in (map_(True), map_(False)): - sess = create_session() - sess.begin() - q = sess.query(User) - - assert hasattr(User, 'name') - assert hasattr(User, 'uc_name') - - eq_(User.uc_name.method1(), "method1") - eq_(User.uc_name.method2('x'), "method2") - - self.assertRaisesMessage( - AttributeError, - "Neither 'extendedproperty' object nor 'UCComparator' object has an attribute 'nonexistent'", - getattr, User.uc_name, 'nonexistent') - - # test compile - assert not isinstance(User.uc_name == 'jack', bool) - u = q.filter(User.uc_name=='JACK').one() - - assert u.uc_name == "JACK" - assert u not in sess.dirty - - u.name = "some user name" - eq_(u.name, "some user name") - assert u in sess.dirty - eq_(u.uc_name, "SOME USER NAME") - - sess.flush() - sess.expunge_all() - - q = sess.query(User) - u2 = q.filter(User.name=='some user name').one() - u3 = q.filter(User.uc_name=='SOME USER NAME').one() - - assert u2 is u3 - - 
eq_(User.uc_name.attribute, 123) - eq_(User.uc_name['key'], 'value') - sess.rollback() - - @testing.resolve_artifact_names - def test_comparable_column(self): - class MyComparator(sa.orm.properties.ColumnProperty.Comparator): - def __eq__(self, other): - # lower case comparison - return func.lower(self.__clause_element__()) == func.lower(other) - - def intersects(self, other): - # non-standard comparator - return self.__clause_element__().op('&=')(other) - - mapper(User, users, properties={ - 'name':sa.orm.column_property(users.c.name, comparator_factory=MyComparator) - }) - - self.assertRaisesMessage( - AttributeError, - "Neither 'InstrumentedAttribute' object nor 'MyComparator' object has an attribute 'nonexistent'", - getattr, User.name, "nonexistent") - - eq_(str((User.name == 'ed').compile(dialect=sa.engine.default.DefaultDialect())) , "lower(users.name) = lower(:lower_1)") - eq_(str((User.name.intersects('ed')).compile(dialect=sa.engine.default.DefaultDialect())), "users.name &= :name_1") - - - @testing.resolve_artifact_names - def test_reconstructor(self): - recon = [] - - class User(object): - @reconstructor - def reconstruct(self): - recon.append('go') - - mapper(User, users) - - User() - eq_(recon, []) - create_session().query(User).first() - eq_(recon, ['go']) - - @testing.resolve_artifact_names - def test_reconstructor_inheritance(self): - recon = [] - class A(object): - @reconstructor - def reconstruct(self): - recon.append('A') - - class B(A): - @reconstructor - def reconstruct(self): - recon.append('B') - - class C(A): - @reconstructor - def reconstruct(self): - recon.append('C') - - mapper(A, users, polymorphic_on=users.c.name, - polymorphic_identity='jack') - mapper(B, inherits=A, polymorphic_identity='ed') - mapper(C, inherits=A, polymorphic_identity='chuck') - - A() - B() - C() - eq_(recon, []) - - sess = create_session() - sess.query(A).first() - sess.query(B).first() - sess.query(C).first() - eq_(recon, ['A', 'B', 'C']) - - 
@testing.resolve_artifact_names - def test_unmapped_reconstructor_inheritance(self): - recon = [] - class Base(object): - @reconstructor - def reconstruct(self): - recon.append('go') - - class User(Base): - pass - - mapper(User, users) - - User() - eq_(recon, []) - - create_session().query(User).first() - eq_(recon, ['go']) - - @testing.resolve_artifact_names - def test_unmapped_error(self): - mapper(Address, addresses) - sa.orm.clear_mappers() - - mapper(User, users, properties={ - 'addresses':relation(Address) - }) - - self.assertRaises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_oldstyle_mixin(self): - class OldStyle: - pass - class NewStyle(object): - pass - - class A(NewStyle, OldStyle): - pass - - mapper(A, users) - - class B(OldStyle, NewStyle): - pass - - mapper(B, users) - - -class OptionsTest(_fixtures.FixtureTest): - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_synonym_options(self): - mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=True, - order_by=addresses.c.id), - adlist = synonym('addresses', proxy=True))) - - - def go(): - sess = create_session() - u = (sess.query(User). - order_by(User.id). - options(sa.orm.eagerload('adlist')). - filter_by(name='jack')).one() - eq_(u.adlist, - [self.static.user_address_result[0].addresses[0]]) - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_eager_options(self): - """A lazy relation can be upgraded to an eager relation.""" - mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), - order_by=addresses.c.id))) - - sess = create_session() - l = (sess.query(User). - order_by(User.id). 
- options(sa.orm.eagerload('addresses'))).all() - - def go(): - eq_(l, self.static.user_address_result) - self.sql_count_(0, go) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_eager_options_with_limit(self): - mapper(User, users, properties=dict( - addresses=relation(mapper(Address, addresses), lazy=True))) - - sess = create_session() - u = (sess.query(User). - options(sa.orm.eagerload('addresses')). - filter_by(id=8)).one() - - def go(): - eq_(u.id, 8) - eq_(len(u.addresses), 3) - self.sql_count_(0, go) - - sess.expunge_all() - - u = sess.query(User).filter_by(id=8).one() - eq_(u.id, 8) - eq_(len(u.addresses), 3) - - @testing.fails_on('maxdb', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_lazy_options_with_limit(self): - mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=False))) - - sess = create_session() - u = (sess.query(User). - options(sa.orm.lazyload('addresses')). - filter_by(id=8)).one() - - def go(): - eq_(u.id, 8) - eq_(len(u.addresses), 3) - self.sql_count_(1, go) - - @testing.resolve_artifact_names - def test_eager_degrade(self): - """An eager relation automatically degrades to a lazy relation if eager columns are not available""" - mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=False))) - - sess = create_session() - # first test straight eager load, 1 statement - def go(): - l = sess.query(User).order_by(User.id).all() - eq_(l, self.static.user_address_result) - self.sql_count_(1, go) - - sess.expunge_all() - - # then select just from users. run it into instances. 
- # then assert the data, which will launch 3 more lazy loads - # (previous users in session fell out of scope and were removed from - # session's identity map) - r = users.select().order_by(users.c.id).execute() - def go(): - l = list(sess.query(User).instances(r)) - eq_(l, self.static.user_address_result) - self.sql_count_(4, go) - - - @testing.resolve_artifact_names - def test_eager_degrade_deep(self): - # test with a deeper set of eager loads. when we first load the three - # users, they will have no addresses or orders. the number of lazy - # loads when traversing the whole thing will be three for the - # addresses and three for the orders. - mapper(Address, addresses) - - mapper(Keyword, keywords) - - mapper(Item, items, properties=dict( - keywords=relation(Keyword, secondary=item_keywords, - lazy=False, - order_by=item_keywords.c.keyword_id))) - - mapper(Order, orders, properties=dict( - items=relation(Item, secondary=order_items, lazy=False, - order_by=order_items.c.item_id))) - - mapper(User, users, properties=dict( - addresses=relation(Address, lazy=False, - order_by=addresses.c.id), - orders=relation(Order, lazy=False, - order_by=orders.c.id))) - - sess = create_session() - - # first test straight eager load, 1 statement - def go(): - l = sess.query(User).order_by(User.id).all() - eq_(l, self.static.user_all_result) - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - - # then select just from users. run it into instances. - # then assert the data, which will launch 6 more lazy loads - r = users.select().execute() - def go(): - l = list(sess.query(User).instances(r)) - eq_(l, self.static.user_all_result) - self.assert_sql_count(testing.db, go, 6) - - @testing.resolve_artifact_names - def test_lazy_options(self): - """An eager relation can be upgraded to a lazy relation.""" - mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=False) - )) - - sess = create_session() - l = (sess.query(User). 
- order_by(User.id). - options(sa.orm.lazyload('addresses'))).all() - - def go(): - eq_(l, self.static.user_address_result) - self.sql_count_(4, go) - - -class DeepOptionsTest(_fixtures.FixtureTest): - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Keyword, keywords) - - mapper(Item, items, properties=dict( - keywords=relation(Keyword, item_keywords, - order_by=item_keywords.c.item_id))) - - mapper(Order, orders, properties=dict( - items=relation(Item, order_items, - order_by=items.c.id))) - - mapper(User, users, order_by=users.c.id, properties=dict( - orders=relation(Order, order_by=orders.c.id))) - - @testing.resolve_artifact_names - def test_deep_options_1(self): - sess = create_session() - - # eagerload nothing. - u = sess.query(User).all() - def go(): - x = u[0].orders[1].items[0].keywords[1] - self.assert_sql_count(testing.db, go, 3) - - @testing.resolve_artifact_names - def test_deep_options_2(self): - sess = create_session() - - # eagerload orders.items.keywords; eagerload_all() implies eager load - # of orders, orders.items - l = (sess.query(User). - options(sa.orm.eagerload_all('orders.items.keywords'))).all() - def go(): - x = l[0].orders[1].items[0].keywords[1] - self.sql_count_(0, go) - - - @testing.resolve_artifact_names - def test_deep_options_3(self): - sess = create_session() - - # same thing, with separate options calls - q2 = (sess.query(User). - options(sa.orm.eagerload('orders')). - options(sa.orm.eagerload('orders.items')). - options(sa.orm.eagerload('orders.items.keywords'))) - u = q2.all() - def go(): - x = u[0].orders[1].items[0].keywords[1] - self.sql_count_(0, go) - - @testing.resolve_artifact_names - def test_deep_options_4(self): - sess = create_session() - - self.assertRaisesMessage( - sa.exc.ArgumentError, - r"Can't find entity Mapper\|Order\|orders in Query. " - r"Current list: \['Mapper\|User\|users'\]", - sess.query(User).options, sa.orm.eagerload(Order.items)) - - # eagerload "keywords" on items. 
it will lazy load "orders", then - # lazy load the "items" on the order, but on "items" it will eager - # load the "keywords" - q3 = sess.query(User).options(sa.orm.eagerload('orders.items.keywords')) - u = q3.all() - def go(): - x = u[0].orders[1].items[0].keywords[1] - self.sql_count_(2, go) - -class ValidatorTest(_fixtures.FixtureTest): - @testing.resolve_artifact_names - def test_scalar(self): - class User(_base.ComparableEntity): - @validates('name') - def validate_name(self, key, name): - assert name != 'fred' - return name + ' modified' - - mapper(User, users) - sess = create_session() - u1 = User(name='ed') - eq_(u1.name, 'ed modified') - self.assertRaises(AssertionError, setattr, u1, "name", "fred") - eq_(u1.name, 'ed modified') - sess.add(u1) - sess.flush() - sess.expunge_all() - eq_(sess.query(User).filter_by(name='ed modified').one(), User(name='ed')) - - - @testing.resolve_artifact_names - def test_collection(self): - class User(_base.ComparableEntity): - @validates('addresses') - def validate_address(self, key, ad): - assert '@' in ad.email_address - return ad - - mapper(User, users, properties={'addresses':relation(Address)}) - mapper(Address, addresses) - sess = create_session() - u1 = User(name='edward') - self.assertRaises(AssertionError, u1.addresses.append, Address(email_address='noemail')) - u1.addresses.append(Address(id=15, email_address='foo@bar.com')) - sess.add(u1) - sess.flush() - sess.expunge_all() - eq_( - sess.query(User).filter_by(name='edward').one(), - User(name='edward', addresses=[Address(email_address='foo@bar.com')]) - ) - -class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): - @testing.resolve_artifact_names - def test_kwarg_accepted(self): - class DummyComposite(object): - def __init__(self, x, y): - pass - - from sqlalchemy.orm.interfaces import PropComparator - - class MyFactory(PropComparator): - pass - - for args in ( - (column_property, users.c.name), - (deferred, users.c.name), - (synonym, 'name'), - 
(composite, DummyComposite, users.c.id, users.c.name), - (relation, Address), - (backref, 'address'), - (comparable_property, ), - (dynamic_loader, Address) - ): - fn = args[0] - args = args[1:] - fn(comparator_factory=MyFactory, *args) - - @testing.resolve_artifact_names - def test_column(self): - from sqlalchemy.orm.properties import ColumnProperty - - class MyFactory(ColumnProperty.Comparator): - __hash__ = None - def __eq__(self, other): - return func.foobar(self.__clause_element__()) == func.foobar(other) - mapper(User, users, properties={'name':column_property(users.c.name, comparator_factory=MyFactory)}) - self.assert_compile(User.name == 'ed', "foobar(users.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - self.assert_compile(aliased(User).name == 'ed', "foobar(users_1.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - - @testing.resolve_artifact_names - def test_synonym(self): - from sqlalchemy.orm.properties import ColumnProperty - class MyFactory(ColumnProperty.Comparator): - __hash__ = None - def __eq__(self, other): - return func.foobar(self.__clause_element__()) == func.foobar(other) - mapper(User, users, properties={'name':synonym('_name', map_column=True, comparator_factory=MyFactory)}) - self.assert_compile(User.name == 'ed', "foobar(users.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - self.assert_compile(aliased(User).name == 'ed', "foobar(users_1.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - - @testing.resolve_artifact_names - def test_relation(self): - from sqlalchemy.orm.properties import PropertyLoader - - class MyFactory(PropertyLoader.Comparator): - __hash__ = None - def __eq__(self, other): - return func.foobar(self.__clause_element__().c.user_id) == func.foobar(other.id) - - class MyFactory2(PropertyLoader.Comparator): - __hash__ = None - def __eq__(self, other): - return func.foobar(self.__clause_element__().c.id) == func.foobar(other.user_id) - - mapper(User, users) - 
mapper(Address, addresses, properties={ - 'user':relation(User, comparator_factory=MyFactory, - backref=backref("addresses", comparator_factory=MyFactory2) - ) - } - ) - self.assert_compile(Address.user == User(id=5), "foobar(addresses.user_id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - self.assert_compile(User.addresses == Address(id=5, user_id=7), "foobar(users.id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - - self.assert_compile(aliased(Address).user == User(id=5), "foobar(addresses_1.user_id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - self.assert_compile(aliased(User).addresses == Address(id=5, user_id=7), "foobar(users_1.id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) - - -class DeferredTest(_fixtures.FixtureTest): - - @testing.resolve_artifact_names - def test_basic(self): - """A basic deferred load.""" - - mapper(Order, orders, order_by=orders.c.id, properties={ - 'description': deferred(orders.c.description)}) - - o = Order() - self.assert_(o.description is None) - - q = create_session().query(Order) - def go(): - l = q.all() - o2 = l[2] - x = o2.description - - self.sql_eq_(go, [ - ("SELECT orders.id AS orders_id, " - "orders.user_id AS orders_user_id, " - "orders.address_id AS orders_address_id, " - "orders.isopen AS orders_isopen " - "FROM orders ORDER BY orders.id", {}), - ("SELECT orders.description AS orders_description " - "FROM orders WHERE orders.id = :param_1", - {'param_1':3})]) - - @testing.resolve_artifact_names - def test_unsaved(self): - """Deferred loading does not kick in when just PK cols are set.""" - - mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) - - sess = create_session() - o = Order() - sess.add(o) - o.id = 7 - def go(): - o.description = "some description" - self.sql_count_(0, go) - - @testing.resolve_artifact_names - def test_synonym_group_bug(self): - mapper(Order, orders, properties={ - 'isopen':synonym('_isopen', map_column=True), - 
'description':deferred(orders.c.description, group='foo') - }) - - sess = create_session() - o1 = sess.query(Order).get(1) - eq_(o1.description, "order 1") - - @testing.resolve_artifact_names - def test_unsaved_2(self): - mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) - - sess = create_session() - o = Order() - sess.add(o) - def go(): - o.description = "some description" - self.sql_count_(0, go) - - @testing.resolve_artifact_names - def test_unsaved_group(self): - """Deferred loading doesnt kick in when just PK cols are set""" - - mapper(Order, orders, order_by=orders.c.id, properties=dict( - description=deferred(orders.c.description, group='primary'), - opened=deferred(orders.c.isopen, group='primary'))) - - sess = create_session() - o = Order() - sess.add(o) - o.id = 7 - def go(): - o.description = "some description" - self.sql_count_(0, go) - - @testing.resolve_artifact_names - def test_unsaved_group_2(self): - mapper(Order, orders, order_by=orders.c.id, properties=dict( - description=deferred(orders.c.description, group='primary'), - opened=deferred(orders.c.isopen, group='primary'))) - - sess = create_session() - o = Order() - sess.add(o) - def go(): - o.description = "some description" - self.sql_count_(0, go) - - @testing.resolve_artifact_names - def test_save(self): - m = mapper(Order, orders, properties={ - 'description': deferred(orders.c.description)}) - - sess = create_session() - o2 = sess.query(Order).get(2) - o2.isopen = 1 - sess.flush() - - @testing.resolve_artifact_names - def test_group(self): - """Deferred load with a group""" - mapper(Order, orders, properties={ - 'userident': deferred(orders.c.user_id, group='primary'), - 'addrident': deferred(orders.c.address_id, group='primary'), - 'description': deferred(orders.c.description, group='primary'), - 'opened': deferred(orders.c.isopen, group='primary') - }) - - sess = create_session() - q = sess.query(Order).order_by(Order.id) - def go(): - l = q.all() - o2 = 
l[2] - eq_(o2.opened, 1) - eq_(o2.userident, 7) - eq_(o2.description, 'order 3') - - self.sql_eq_(go, [ - ("SELECT orders.id AS orders_id " - "FROM orders ORDER BY orders.id", {}), - ("SELECT orders.user_id AS orders_user_id, " - "orders.address_id AS orders_address_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen " - "FROM orders WHERE orders.id = :param_1", - {'param_1':3})]) - - o2 = q.all()[2] - eq_(o2.description, 'order 3') - assert o2 not in sess.dirty - o2.description = 'order 3' - def go(): - sess.flush() - self.sql_count_(0, go) - - @testing.resolve_artifact_names - def test_preserve_changes(self): - """A deferred load operation doesn't revert modifications on attributes""" - mapper(Order, orders, properties = { - 'userident': deferred(orders.c.user_id, group='primary'), - 'description': deferred(orders.c.description, group='primary'), - 'opened': deferred(orders.c.isopen, group='primary') - }) - sess = create_session() - o = sess.query(Order).get(3) - assert 'userident' not in o.__dict__ - o.description = 'somenewdescription' - eq_(o.description, 'somenewdescription') - def go(): - eq_(o.opened, 1) - self.assert_sql_count(testing.db, go, 1) - eq_(o.description, 'somenewdescription') - assert o in sess.dirty - - @testing.resolve_artifact_names - def test_commits_state(self): - """ - When deferred elements are loaded via a group, they get the proper - CommittedState and don't result in changes being committed - - """ - mapper(Order, orders, properties = { - 'userident':deferred(orders.c.user_id, group='primary'), - 'description':deferred(orders.c.description, group='primary'), - 'opened':deferred(orders.c.isopen, group='primary')}) - - sess = create_session() - o2 = sess.query(Order).get(3) - - # this will load the group of attributes - eq_(o2.description, 'order 3') - assert o2 not in sess.dirty - # this will mark it as 'dirty', but nothing actually changed - o2.description = 'order 3' - # therefore the flush() 
shouldnt actually issue any SQL - self.assert_sql_count(testing.db, sess.flush, 0) - - @testing.resolve_artifact_names - def test_options(self): - """Options on a mapper to create deferred and undeferred columns""" - - mapper(Order, orders) - - sess = create_session() - q = sess.query(Order).order_by(Order.id).options(defer('user_id')) - - def go(): - q.all()[0].user_id - - self.sql_eq_(go, [ - ("SELECT orders.id AS orders_id, " - "orders.address_id AS orders_address_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen " - "FROM orders ORDER BY orders.id", {}), - ("SELECT orders.user_id AS orders_user_id " - "FROM orders WHERE orders.id = :param_1", - {'param_1':1})]) - sess.expunge_all() - - q2 = q.options(sa.orm.undefer('user_id')) - self.sql_eq_(q2.all, [ - ("SELECT orders.id AS orders_id, " - "orders.user_id AS orders_user_id, " - "orders.address_id AS orders_address_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen " - "FROM orders ORDER BY orders.id", - {})]) - - @testing.resolve_artifact_names - def test_undefer_group(self): - mapper(Order, orders, properties={ - 'userident':deferred(orders.c.user_id, group='primary'), - 'description':deferred(orders.c.description, group='primary'), - 'opened':deferred(orders.c.isopen, group='primary')}) - - sess = create_session() - q = sess.query(Order).order_by(Order.id) - def go(): - l = q.options(sa.orm.undefer_group('primary')).all() - o2 = l[2] - eq_(o2.opened, 1) - eq_(o2.userident, 7) - eq_(o2.description, 'order 3') - - self.sql_eq_(go, [ - ("SELECT orders.user_id AS orders_user_id, " - "orders.description AS orders_description, " - "orders.isopen AS orders_isopen, " - "orders.id AS orders_id, " - "orders.address_id AS orders_address_id " - "FROM orders ORDER BY orders.id", - {})]) - - @testing.resolve_artifact_names - def test_locates_col(self): - """Manually adding a column to the result undefers the column.""" - - mapper(Order, orders, 
properties={ - 'description':deferred(orders.c.description)}) - - sess = create_session() - o1 = sess.query(Order).order_by(Order.id).first() - def go(): - eq_(o1.description, 'order 1') - self.sql_count_(1, go) - - sess = create_session() - o1 = (sess.query(Order). - order_by(Order.id). - add_column(orders.c.description).first())[0] - def go(): - eq_(o1.description, 'order 1') - self.sql_count_(0, go) - - @testing.resolve_artifact_names - def test_deep_options(self): - mapper(Item, items, properties=dict( - description=deferred(items.c.description))) - mapper(Order, orders, properties=dict( - items=relation(Item, secondary=order_items))) - mapper(User, users, properties=dict( - orders=relation(Order, order_by=orders.c.id))) - - sess = create_session() - q = sess.query(User).order_by(User.id) - l = q.all() - item = l[0].orders[1].items[1] - def go(): - eq_(item.description, 'item 4') - self.sql_count_(1, go) - eq_(item.description, 'item 4') - - sess.expunge_all() - l = q.options(sa.orm.undefer('orders.items.description')).all() - item = l[0].orders[1].items[1] - def go(): - eq_(item.description, 'item 4') - self.sql_count_(0, go) - eq_(item.description, 'item 4') - -class DeferredPopulationTest(_base.MappedTest): - def define_tables(self, metadata): - Table("thing", metadata, - Column("id", Integer, primary_key=True), - Column("name", String(20))) - - Table("human", metadata, - Column("id", Integer, primary_key=True), - Column("thing_id", Integer, ForeignKey("thing.id")), - Column("name", String(20))) - - @testing.resolve_artifact_names - def setup_mappers(self): - class Human(_base.BasicEntity): pass - class Thing(_base.BasicEntity): pass - - mapper(Human, human, properties={"thing": relation(Thing)}) - mapper(Thing, thing, properties={"name": deferred(thing.c.name)}) - - @testing.resolve_artifact_names - def insert_data(self): - thing.insert().execute([ - {"id": 1, "name": "Chair"}, - ]) - - human.insert().execute([ - {"id": 1, "thing_id": 1, "name": "Clark 
Kent"}, - ]) - - def _test(self, thing): - assert "name" in attributes.instance_state(thing).dict - - @testing.resolve_artifact_names - def test_no_previous_query(self): - session = create_session() - thing = session.query(Thing).options(sa.orm.undefer("name")).first() - self._test(thing) - - @testing.resolve_artifact_names - def test_query_twice_with_clear(self): - session = create_session() - result = session.query(Thing).first() - session.expunge_all() - thing = session.query(Thing).options(sa.orm.undefer("name")).first() - self._test(thing) - - @testing.resolve_artifact_names - def test_query_twice_no_clear(self): - session = create_session() - result = session.query(Thing).first() - thing = session.query(Thing).options(sa.orm.undefer("name")).first() - self._test(thing) - - @testing.resolve_artifact_names - def test_eagerload_with_clear(self): - session = create_session() - human = session.query(Human).options(sa.orm.eagerload("thing")).first() - session.expunge_all() - thing = session.query(Thing).options(sa.orm.undefer("name")).first() - self._test(thing) - - @testing.resolve_artifact_names - def test_eagerload_no_clear(self): - session = create_session() - human = session.query(Human).options(sa.orm.eagerload("thing")).first() - thing = session.query(Thing).options(sa.orm.undefer("name")).first() - self._test(thing) - - @testing.resolve_artifact_names - def test_join_with_clear(self): - session = create_session() - result = session.query(Human).add_entity(Thing).join("thing").first() - session.expunge_all() - thing = session.query(Thing).options(sa.orm.undefer("name")).first() - self._test(thing) - - @testing.resolve_artifact_names - def test_join_no_clear(self): - session = create_session() - result = session.query(Human).add_entity(Thing).join("thing").first() - thing = session.query(Thing).options(sa.orm.undefer("name")).first() - self._test(thing) - - -class CompositeTypesTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('graphs', 
metadata, - Column('id', Integer, primary_key=True), - Column('version_id', Integer, primary_key=True, nullable=True), - Column('name', String(30))) - - Table('edges', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('graph_id', Integer, nullable=False), - Column('graph_version_id', Integer, nullable=False), - Column('x1', Integer), - Column('y1', Integer), - Column('x2', Integer), - Column('y2', Integer), - sa.ForeignKeyConstraint( - ['graph_id', 'graph_version_id'], - ['graphs.id', 'graphs.version_id'])) - - Table('foobars', metadata, - Column('id', Integer, primary_key=True), - Column('x1', Integer, default=2), - Column('x2', Integer), - Column('x3', Integer, default=15), - Column('x4', Integer) - ) - - @testing.resolve_artifact_names - def test_basic(self): - class Point(object): - def __init__(self, x, y): - self.x = x - self.y = y - def __composite_values__(self): - return [self.x, self.y] - __hash__ = None - def __eq__(self, other): - return isinstance(other, Point) and other.x == self.x and other.y == self.y - def __ne__(self, other): - return not isinstance(other, Point) or not self.__eq__(other) - - class Graph(object): - pass - class Edge(object): - def __init__(self, start, end): - self.start = start - self.end = end - - mapper(Graph, graphs, properties={ - 'edges':relation(Edge) - }) - mapper(Edge, edges, properties={ - 'start':sa.orm.composite(Point, edges.c.x1, edges.c.y1), - 'end': sa.orm.composite(Point, edges.c.x2, edges.c.y2) - }) - - sess = create_session() - g = Graph() - g.id = 1 - g.version_id=1 - g.edges.append(Edge(Point(3, 4), Point(5, 6))) - g.edges.append(Edge(Point(14, 5), Point(2, 7))) - sess.add(g) - sess.flush() - - sess.expunge_all() - g2 = sess.query(Graph).get([g.id, g.version_id]) - for e1, e2 in zip(g.edges, g2.edges): - eq_(e1.start, e2.start) - eq_(e1.end, e2.end) - - g2.edges[1].end = Point(18, 4) - sess.flush() - sess.expunge_all() - e = sess.query(Edge).get(g2.edges[1].id) - 
eq_(e.end, Point(18, 4)) - - e.end.x = 19 - e.end.y = 5 - sess.flush() - sess.expunge_all() - eq_(sess.query(Edge).get(g2.edges[1].id).end, Point(19, 5)) - - g.edges[1].end = Point(19, 5) - - sess.expunge_all() - def go(): - g2 = (sess.query(Graph). - options(sa.orm.eagerload('edges'))).get([g.id, g.version_id]) - for e1, e2 in zip(g.edges, g2.edges): - eq_(e1.start, e2.start) - eq_(e1.end, e2.end) - self.assert_sql_count(testing.db, go, 1) - - # test comparison of CompositeProperties to their object instances - g = sess.query(Graph).get([1, 1]) - assert sess.query(Edge).filter(Edge.start==Point(3, 4)).one() is g.edges[0] - - assert sess.query(Edge).filter(Edge.start!=Point(3, 4)).first() is g.edges[1] - - eq_(sess.query(Edge).filter(Edge.start==None).all(), []) - - # query by columns - eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)]) - - e = g.edges[1] - e.end.x = e.end.y = None - sess.flush() - eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)]) - - - @testing.resolve_artifact_names - def test_pk(self): - """Using a composite type as a primary key""" - - class Version(object): - def __init__(self, id, version): - self.id = id - self.version = version - def __composite_values__(self): - return (self.id, self.version) - __hash__ = None - def __eq__(self, other): - return other.id == self.id and other.version == self.version - def __ne__(self, other): - return not self.__eq__(other) - - class Graph(object): - def __init__(self, version): - self.version = version - - mapper(Graph, graphs, allow_null_pks=True, properties={ - 'version':sa.orm.composite(Version, graphs.c.id, - graphs.c.version_id)}) - - sess = create_session() - g = Graph(Version(1, 1)) - sess.add(g) - sess.flush() - - sess.expunge_all() - g2 = sess.query(Graph).get([1, 1]) - eq_(g.version, g2.version) - sess.expunge_all() - - g2 = sess.query(Graph).get(Version(1, 1)) - eq_(g.version, g2.version) - - # test pk mutation - @testing.fails_on('mssql', 
'Cannot update identity columns.') - def update_pk(): - g2.version = Version(2, 1) - sess.flush() - g3 = sess.query(Graph).get(Version(2, 1)) - eq_(g2.version, g3.version) - update_pk() - - # test pk with one column NULL - # TODO: can't seem to get NULL in for a PK value - # in either mysql or postgres, autoincrement=False etc. - # notwithstanding - @testing.fails_on_everything_except("sqlite") - def go(): - g = Graph(Version(2, None)) - sess.add(g) - sess.flush() - sess.expunge_all() - g2 = sess.query(Graph).filter_by(version=Version(2, None)).one() - eq_(g.version, g2.version) - go() - - @testing.resolve_artifact_names - def test_attributes_with_defaults(self): - class Foobar(object): - pass - - class FBComposite(object): - def __init__(self, x1, x2, x3, x4): - self.x1 = x1 - self.x2 = x2 - self.x3 = x3 - self.x4 = x4 - def __composite_values__(self): - return self.x1, self.x2, self.x3, self.x4 - __hash__ = None - def __eq__(self, other): - return other.x1 == self.x1 and other.x2 == self.x2 and other.x3 == self.x3 and other.x4 == self.x4 - def __ne__(self, other): - return not self.__eq__(other) - - mapper(Foobar, foobars, properties=dict( - foob=sa.orm.composite(FBComposite, foobars.c.x1, foobars.c.x2, foobars.c.x3, foobars.c.x4) - )) - - sess = create_session() - f1 = Foobar() - f1.foob = FBComposite(None, 5, None, None) - sess.add(f1) - sess.flush() - - assert f1.foob == FBComposite(2, 5, 15, None) - - - f2 = Foobar() - sess.add(f2) - sess.flush() - assert f2.foob == FBComposite(2, None, 15, None) - - - @testing.resolve_artifact_names - def test_set_composite_values(self): - class Foobar(object): - pass - - class FBComposite(object): - def __init__(self, x1, x2, x3, x4): - self.x1val = x1 - self.x2val = x2 - self.x3 = x3 - self.x4 = x4 - def __composite_values__(self): - return self.x1val, self.x2val, self.x3, self.x4 - def __set_composite_values__(self, x1, x2, x3, x4): - self.x1val = x1 - self.x2val = x2 - self.x3 = x3 - self.x4 = x4 - __hash__ = None - def 
__eq__(self, other): - return other.x1val == self.x1val and other.x2val == self.x2val and other.x3 == self.x3 and other.x4 == self.x4 - def __ne__(self, other): - return not self.__eq__(other) - - mapper(Foobar, foobars, properties=dict( - foob=sa.orm.composite(FBComposite, foobars.c.x1, foobars.c.x2, foobars.c.x3, foobars.c.x4) - )) - - sess = create_session() - f1 = Foobar() - f1.foob = FBComposite(None, 5, None, None) - sess.add(f1) - sess.flush() - - assert f1.foob == FBComposite(2, 5, 15, None) - - @testing.resolve_artifact_names - def test_save_null(self): - """test saving a null composite value - - See google groups thread for more context: - http://groups.google.com/group/sqlalchemy/browse_thread/thread/0c6580a1761b2c29 - - """ - class Point(object): - def __init__(self, x, y): - self.x = x - self.y = y - def __composite_values__(self): - return [self.x, self.y] - __hash__ = None - def __eq__(self, other): - return other.x == self.x and other.y == self.y - def __ne__(self, other): - return not self.__eq__(other) - - class Graph(object): - pass - class Edge(object): - def __init__(self, start, end): - self.start = start - self.end = end - - mapper(Graph, graphs, properties={ - 'edges':relation(Edge) - }) - mapper(Edge, edges, properties={ - 'start':sa.orm.composite(Point, edges.c.x1, edges.c.y1), - 'end':sa.orm.composite(Point, edges.c.x2, edges.c.y2) - }) - - sess = create_session() - g = Graph() - g.id = 1 - g.version_id=1 - e = Edge(None, None) - g.edges.append(e) - - sess.add(g) - sess.flush() - - sess.expunge_all() - - g2 = sess.query(Graph).get([1, 1]) - assert g2.edges[-1].start.x is None - assert g2.edges[-1].start.y is None - - -class NoLoadTest(_fixtures.FixtureTest): - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def test_basic(self): - """A basic one-to-many lazy load""" - m = mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=None) - )) - q = create_session().query(m) - 
l = [None] - def go(): - x = q.filter(User.id == 7).all() - x[0].addresses - l[0] = x - self.assert_sql_count(testing.db, go, 1) - - self.assert_result(l[0], User, - {'id' : 7, 'addresses' : (Address, [])}, - ) - - @testing.resolve_artifact_names - def test_options(self): - m = mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=None) - )) - q = create_session().query(m).options(sa.orm.lazyload('addresses')) - l = [None] - def go(): - x = q.filter(User.id == 7).all() - x[0].addresses - l[0] = x - self.sql_count_(2, go) - - self.assert_result(l[0], User, - {'id' : 7, 'addresses' : (Address, [{'id' : 1}])}, - ) - - -class MapperExtensionTest(_fixtures.FixtureTest): - run_inserts = None - - def extension(self): - methods = [] - - class Ext(sa.orm.MapperExtension): - def instrument_class(self, mapper, cls): - methods.append('instrument_class') - return sa.orm.EXT_CONTINUE - - def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): - methods.append('init_instance') - return sa.orm.EXT_CONTINUE - - def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): - methods.append('init_failed') - return sa.orm.EXT_CONTINUE - - def translate_row(self, mapper, context, row): - methods.append('translate_row') - return sa.orm.EXT_CONTINUE - - def create_instance(self, mapper, selectcontext, row, class_): - methods.append('create_instance') - return sa.orm.EXT_CONTINUE - - def reconstruct_instance(self, mapper, instance): - methods.append('reconstruct_instance') - return sa.orm.EXT_CONTINUE - - def append_result(self, mapper, selectcontext, row, instance, result, **flags): - methods.append('append_result') - return sa.orm.EXT_CONTINUE - - def populate_instance(self, mapper, selectcontext, row, instance, **flags): - methods.append('populate_instance') - return sa.orm.EXT_CONTINUE - - def before_insert(self, mapper, connection, instance): - methods.append('before_insert') - return sa.orm.EXT_CONTINUE - - def 
after_insert(self, mapper, connection, instance): - methods.append('after_insert') - return sa.orm.EXT_CONTINUE - - def before_update(self, mapper, connection, instance): - methods.append('before_update') - return sa.orm.EXT_CONTINUE - - def after_update(self, mapper, connection, instance): - methods.append('after_update') - return sa.orm.EXT_CONTINUE - - def before_delete(self, mapper, connection, instance): - methods.append('before_delete') - return sa.orm.EXT_CONTINUE - - def after_delete(self, mapper, connection, instance): - methods.append('after_delete') - return sa.orm.EXT_CONTINUE - - return Ext, methods - - @testing.resolve_artifact_names - def test_basic(self): - """test that common user-defined methods get called.""" - Ext, methods = self.extension() - - mapper(User, users, extension=Ext()) - sess = create_session() - u = User(name='u1') - sess.add(u) - sess.flush() - u = sess.query(User).populate_existing().get(u.id) - sess.expunge_all() - u = sess.query(User).get(u.id) - u.name = 'u1 changed' - sess.flush() - sess.delete(u) - sess.flush() - eq_(methods, - ['instrument_class', 'init_instance', 'before_insert', - 'after_insert', 'translate_row', 'populate_instance', - 'append_result', 'translate_row', 'create_instance', - 'populate_instance', 'reconstruct_instance', 'append_result', - 'before_update', 'after_update', 'before_delete', 'after_delete']) - - @testing.resolve_artifact_names - def test_inheritance(self): - Ext, methods = self.extension() - - class AdminUser(User): - pass - - mapper(User, users, extension=Ext()) - mapper(AdminUser, addresses, inherits=User) - - sess = create_session() - am = AdminUser(name='au1', email_address='au1@e1') - sess.add(am) - sess.flush() - am = sess.query(AdminUser).populate_existing().get(am.id) - sess.expunge_all() - am = sess.query(AdminUser).get(am.id) - am.name = 'au1 changed' - sess.flush() - sess.delete(am) - sess.flush() - eq_(methods, - ['instrument_class', 'instrument_class', 'init_instance', - 
'before_insert', 'after_insert', 'translate_row', - 'populate_instance', 'append_result', 'translate_row', - 'create_instance', 'populate_instance', 'reconstruct_instance', - 'append_result', 'before_update', 'after_update', 'before_delete', - 'after_delete']) - - @testing.resolve_artifact_names - def test_after_with_no_changes(self): - """after_update is called even if no columns were updated.""" - - Ext, methods = self.extension() - - mapper(Item, items, extension=Ext() , properties={ - 'keywords': relation(Keyword, secondary=item_keywords)}) - mapper(Keyword, keywords, extension=Ext()) - - sess = create_session() - i1 = Item(description="i1") - k1 = Keyword(name="k1") - sess.add(i1) - sess.add(k1) - sess.flush() - eq_(methods, - ['instrument_class', 'instrument_class', 'init_instance', - 'init_instance', 'before_insert', 'after_insert', - 'before_insert', 'after_insert']) - - del methods[:] - i1.keywords.append(k1) - sess.flush() - eq_(methods, ['before_update', 'after_update']) - - - @testing.resolve_artifact_names - def test_inheritance_with_dupes(self): - """Inheritance with the same extension instance on both mappers.""" - Ext, methods = self.extension() - - class AdminUser(User): - pass - - ext = Ext() - mapper(User, users, extension=ext) - mapper(AdminUser, addresses, inherits=User, extension=ext) - - sess = create_session() - am = AdminUser(name="au1", email_address="au1@e1") - sess.add(am) - sess.flush() - am = sess.query(AdminUser).populate_existing().get(am.id) - sess.expunge_all() - am = sess.query(AdminUser).get(am.id) - am.name = 'au1 changed' - sess.flush() - sess.delete(am) - sess.flush() - eq_(methods, - ['instrument_class', 'instrument_class', 'init_instance', - 'before_insert', 'after_insert', 'translate_row', - 'populate_instance', 'append_result', 'translate_row', - 'create_instance', 'populate_instance', 'reconstruct_instance', - 'append_result', 'before_update', 'after_update', 'before_delete', - 'after_delete']) - - 
@testing.resolve_artifact_names - def test_create_instance(self): - class CreateUserExt(sa.orm.MapperExtension): - def create_instance(self, mapper, selectcontext, row, class_): - return User.__new__(User) - - mapper(User, users, extension=CreateUserExt()) - sess = create_session() - u1 = User() - u1.name = 'ed' - sess.add(u1) - sess.flush() - sess.expunge_all() - assert sess.query(User).first() - - -class RequirementsTest(_base.MappedTest): - """Tests the contract for user classes.""" - - def define_tables(self, metadata): - Table('ht1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('value', String(10))) - Table('ht2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('ht1_id', Integer, ForeignKey('ht1.id')), - Column('value', String(10))) - Table('ht3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('value', String(10))) - Table('ht4', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht3_id', Integer, ForeignKey('ht3.id'), - primary_key=True)) - Table('ht5', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True)) - Table('ht6', metadata, - Column('ht1a_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht1b_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('value', String(10))) - - @testing.resolve_artifact_names - def test_baseclass(self): - class OldStyle: - pass - - self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1) - - self.assertRaises(sa.exc.ArgumentError, mapper, 123) - - class NoWeakrefSupport(str): - pass - - # TODO: is weakref support detectable without an instance? - #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2) - - @testing.resolve_artifact_names - def test_comparison_overrides(self): - """Simple tests to ensure users can supply comparison __methods__. 
- - The suite-level test --options are better suited to detect - problems- they add selected __methods__ across the board on all - ORM tests. This test simply shoves a variety of operations - through the ORM to catch basic regressions early in a standard - test run. - """ - - # adding these methods directly to each class to avoid decoration - # by the testlib decorators. - class _Base(object): - def __init__(self, value='abc'): - self.value = value - def __nonzero__(self): - return False - def __hash__(self): - return hash(self.value) - def __eq__(self, other): - if isinstance(other, type(self)): - return self.value == other.value - return False - - class H1(_Base): - pass - class H2(_Base): - pass - class H3(_Base): - pass - class H6(_Base): - pass - - mapper(H1, ht1, properties={ - 'h2s': relation(H2, backref='h1'), - 'h3s': relation(H3, secondary=ht4, backref='h1s'), - 'h1s': relation(H1, secondary=ht5, backref='parent_h1'), - 't6a': relation(H6, backref='h1a', - primaryjoin=ht1.c.id==ht6.c.ht1a_id), - 't6b': relation(H6, backref='h1b', - primaryjoin=ht1.c.id==ht6.c.ht1b_id), - }) - mapper(H2, ht2) - mapper(H3, ht3) - mapper(H6, ht6) - - s = create_session() - for i in range(3): - h1 = H1() - s.add(h1) - - h1.h2s.append(H2()) - h1.h3s.extend([H3(), H3()]) - h1.h1s.append(H1()) - - s.flush() - eq_(ht1.count().scalar(), 4) - - h6 = H6() - h6.h1a = h1 - h6.h1b = h1 - - h6 = H6() - h6.h1a = h1 - h6.h1b = x = H1() - assert x in s - - h6.h1b.h2s.append(H2()) - - s.flush() - - h1.h2s.extend([H2(), H2()]) - s.flush() - - h1s = s.query(H1).options(sa.orm.eagerload('h2s')).all() - eq_(len(h1s), 5) - - self.assert_unordered_result(h1s, H1, - {'h2s': []}, - {'h2s': []}, - {'h2s': (H2, [{'value': 'abc'}, - {'value': 'abc'}, - {'value': 'abc'}])}, - {'h2s': []}, - {'h2s': (H2, [{'value': 'abc'}])}) - - h1s = s.query(H1).options(sa.orm.eagerload('h3s')).all() - - eq_(len(h1s), 5) - h1s = s.query(H1).options(sa.orm.eagerload_all('t6a.h1b'), - sa.orm.eagerload('h2s'), - 
sa.orm.eagerload_all('h3s.h1s')).all() - eq_(len(h1s), 5) - - -class MagicNamesTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('cartographers', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('alias', String(50)), - Column('quip', String(100))) - Table('maps', metadata, - Column('id', Integer, primary_key=True), - Column('cart_id', Integer, - ForeignKey('cartographers.id')), - Column('state', String(2)), - Column('data', sa.Text)) - - def setup_classes(self): - class Cartographer(_base.BasicEntity): - pass - - class Map(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_mappish(self): - mapper(Cartographer, cartographers, properties=dict( - query=cartographers.c.quip)) - mapper(Map, maps, properties=dict( - mapper=relation(Cartographer, backref='maps'))) - - c = Cartographer(name='Lenny', alias='The Dude', - query='Where be dragons?') - m = Map(state='AK', mapper=c) - - sess = create_session() - sess.add(c) - sess.flush() - sess.expunge_all() - - for C, M in ((Cartographer, Map), - (sa.orm.aliased(Cartographer), sa.orm.aliased(Map))): - c1 = (sess.query(C). - filter(C.alias=='The Dude'). - filter(C.query=='Where be dragons?')).one() - m1 = sess.query(M).filter(M.mapper==c1).one() - - @testing.resolve_artifact_names - def test_direct_stateish(self): - for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR, - sa.orm.attributes.ClassManager.MANAGER_ATTR): - t = Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True), - Column(reserved, Integer)) - class T(object): - pass - - self.assertRaisesMessage( - KeyError, - ('%r: requested attribute name conflicts with ' - 'instrumentation attribute of the same name.' 
% reserved), - mapper, T, t) - - @testing.resolve_artifact_names - def test_indirect_stateish(self): - for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR, - sa.orm.attributes.ClassManager.MANAGER_ATTR): - class M(object): - pass - - self.assertRaisesMessage( - KeyError, - ('requested attribute name conflicts with ' - 'instrumentation attribute of the same name'), - mapper, M, maps, properties={ - reserved: maps.c.state}) - - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/merge.py b/test/orm/merge.py deleted file mode 100644 index fd553f2bf..000000000 --- a/test/orm/merge.py +++ /dev/null @@ -1,736 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa.util import OrderedSet -from testlib.sa.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property -from testlib.testing import eq_, ne_ -from orm import _base, _fixtures - - -class MergeTest(_fixtures.FixtureTest): - """Session..merge() functionality""" - - run_inserts = None - - def on_load_tracker(self, cls, canary=None): - if canary is None: - def canary(instance): - canary.called += 1 - canary.called = 0 - - manager = sa.orm.attributes.manager_of_class(cls) - manager.events.add_listener('on_load', canary) - - return canary - - @testing.resolve_artifact_names - def test_transient_to_pending(self): - mapper(User, users) - sess = create_session() - on_load = self.on_load_tracker(User) - - u = User(id=7, name='fred') - eq_(on_load.called, 0) - u2 = sess.merge(u) - eq_(on_load.called, 1) - assert u2 in sess - eq_(u2, User(id=7, name='fred')) - sess.flush() - sess.expunge_all() - eq_(sess.query(User).first(), User(id=7, name='fred')) - - @testing.resolve_artifact_names - def test_transient_to_pending_collection(self): - mapper(User, users, properties={ - 'addresses': relation(Address, backref='user', - collection_class=OrderedSet)}) - mapper(Address, addresses) - on_load = self.on_load_tracker(User) - 
self.on_load_tracker(Address, on_load) - - u = User(id=7, name='fred', addresses=OrderedSet([ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ])) - eq_(on_load.called, 0) - - sess = create_session() - sess.merge(u) - eq_(on_load.called, 3) - - merged_users = [e for e in sess if isinstance(e, User)] - eq_(len(merged_users), 1) - assert merged_users[0] is not u - - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).one(), - User(id=7, name='fred', addresses=OrderedSet([ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ])) - ) - - @testing.resolve_artifact_names - def test_transient_to_persistent(self): - mapper(User, users) - on_load = self.on_load_tracker(User) - - sess = create_session() - u = User(id=7, name='fred') - sess.add(u) - sess.flush() - sess.expunge_all() - - eq_(on_load.called, 0) - - _u2 = u2 = User(id=7, name='fred jones') - eq_(on_load.called, 0) - u2 = sess.merge(u2) - assert u2 is not _u2 - eq_(on_load.called, 1) - sess.flush() - sess.expunge_all() - eq_(sess.query(User).first(), User(id=7, name='fred jones')) - eq_(on_load.called, 2) - - @testing.resolve_artifact_names - def test_transient_to_persistent_collection(self): - mapper(User, users, properties={ - 'addresses':relation(Address, - backref='user', - collection_class=OrderedSet, - cascade="all, delete-orphan") - }) - mapper(Address, addresses) - - on_load = self.on_load_tracker(User) - self.on_load_tracker(Address, on_load) - - u = User(id=7, name='fred', addresses=OrderedSet([ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ])) - sess = create_session() - sess.add(u) - sess.flush() - sess.expunge_all() - - eq_(on_load.called, 0) - - u = User(id=7, name='fred', addresses=OrderedSet([ - Address(id=3, email_address='fred3'), - Address(id=4, email_address='fred4'), - ])) - - u = sess.merge(u) - - # 1. merges User object. updates into session. - # 2.,3. 
merges Address ids 3 & 4, saves into session. - # 4.,5. loads pre-existing elements in "addresses" collection, - # marks as deleted, Address ids 1 and 2. - eq_(on_load.called, 5) - - eq_(u, - User(id=7, name='fred', addresses=OrderedSet([ - Address(id=3, email_address='fred3'), - Address(id=4, email_address='fred4'), - ])) - ) - sess.flush() - sess.expunge_all() - eq_(sess.query(User).one(), - User(id=7, name='fred', addresses=OrderedSet([ - Address(id=3, email_address='fred3'), - Address(id=4, email_address='fred4'), - ])) - ) - - @testing.resolve_artifact_names - def test_detached_to_persistent_collection(self): - mapper(User, users, properties={ - 'addresses':relation(Address, - backref='user', - collection_class=OrderedSet)}) - mapper(Address, addresses) - on_load = self.on_load_tracker(User) - self.on_load_tracker(Address, on_load) - - a = Address(id=1, email_address='fred1') - u = User(id=7, name='fred', addresses=OrderedSet([ - a, - Address(id=2, email_address='fred2'), - ])) - sess = create_session() - sess.add(u) - sess.flush() - sess.expunge_all() - - u.name='fred jones' - u.addresses.add(Address(id=3, email_address='fred3')) - u.addresses.remove(a) - - eq_(on_load.called, 0) - u = sess.merge(u) - eq_(on_load.called, 4) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).first(), - User(id=7, name='fred jones', addresses=OrderedSet([ - Address(id=2, email_address='fred2'), - Address(id=3, email_address='fred3')]))) - - @testing.resolve_artifact_names - def test_unsaved_cascade(self): - """Merge of a transient entity with two child transient entities, with a bidirectional relation.""" - - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), - cascade="all", backref="user") - }) - on_load = self.on_load_tracker(User) - self.on_load_tracker(Address, on_load) - sess = create_session() - - u = User(id=7, name='fred') - a1 = Address(email_address='foo@bar.com') - a2 = Address(email_address='hoho@bar.com') - 
u.addresses.append(a1) - u.addresses.append(a2) - - u2 = sess.merge(u) - eq_(on_load.called, 3) - - eq_(u, - User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@bar.com')])) - eq_(u2, - User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@bar.com')])) - - sess.flush() - sess.expunge_all() - u2 = sess.query(User).get(7) - - eq_(u2, User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@bar.com')])) - eq_(on_load.called, 6) - - @testing.resolve_artifact_names - def test_merge_empty_attributes(self): - mapper(User, dingalings) - u1 = User(id=1) - sess = create_session() - sess.merge(u1) - sess.flush() - assert u1.address_id is u1.data is None - - @testing.resolve_artifact_names - def test_attribute_cascade(self): - """Merge of a persistent entity with two child persistent entities.""" - - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), backref='user') - }) - on_load = self.on_load_tracker(User) - self.on_load_tracker(Address, on_load) - - sess = create_session() - - # set up data and save - u = User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address = 'hoho@la.com')]) - sess.add(u) - sess.flush() - - # assert data was saved - sess2 = create_session() - u2 = sess2.query(User).get(7) - eq_(u2, - User(id=7, name='fred', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@la.com')])) - - # make local changes to data - u.name = 'fred2' - u.addresses[1].email_address = 'hoho@lalala.com' - - eq_(on_load.called, 3) - - # new session, merge modified data into session - sess3 = create_session() - u3 = sess3.merge(u) - eq_(on_load.called, 6) - - # ensure local changes are pending - eq_(u3, User(id=7, name='fred2', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@lalala.com')])) - - # 
save merged data - sess3.flush() - - # assert modified/merged data was saved - sess.expunge_all() - u = sess.query(User).get(7) - eq_(u, User(id=7, name='fred2', addresses=[ - Address(email_address='foo@bar.com'), - Address(email_address='hoho@lalala.com')])) - eq_(on_load.called, 9) - - # merge persistent object into another session - sess4 = create_session() - u = sess4.merge(u) - assert len(u.addresses) - for a in u.addresses: - assert a.user is u - def go(): - sess4.flush() - # no changes; therefore flush should do nothing - self.assert_sql_count(testing.db, go, 0) - eq_(on_load.called, 12) - - # test with "dontload" merge - sess5 = create_session() - u = sess5.merge(u, dont_load=True) - assert len(u.addresses) - for a in u.addresses: - assert a.user is u - def go(): - sess5.flush() - # no changes; therefore flush should do nothing - # but also, dont_load wipes out any difference in committed state, - # so no flush at all - self.assert_sql_count(testing.db, go, 0) - eq_(on_load.called, 15) - - sess4 = create_session() - u = sess4.merge(u, dont_load=True) - # post merge change - u.addresses[1].email_address='afafds' - def go(): - sess4.flush() - # afafds change flushes - self.assert_sql_count(testing.db, go, 1) - eq_(on_load.called, 18) - - sess5 = create_session() - u2 = sess5.query(User).get(u.id) - eq_(u2.name, 'fred2') - eq_(u2.addresses[1].email_address, 'afafds') - eq_(on_load.called, 21) - - @testing.resolve_artifact_names - def test_one_to_many_cascade(self): - - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses))}) - - on_load = self.on_load_tracker(User) - self.on_load_tracker(Address, on_load) - - sess = create_session() - u = User(name='fred') - a1 = Address(email_address='foo@bar') - a2 = Address(email_address='foo@quux') - u.addresses.extend([a1, a2]) - - sess.add(u) - sess.flush() - - eq_(on_load.called, 0) - - sess2 = create_session() - u2 = sess2.query(User).get(u.id) - eq_(on_load.called, 1) - - 
u.addresses[1].email_address = 'addr 2 modified' - sess2.merge(u) - eq_(u2.addresses[1].email_address, 'addr 2 modified') - eq_(on_load.called, 3) - - sess3 = create_session() - u3 = sess3.query(User).get(u.id) - eq_(on_load.called, 4) - - u.name = 'also fred' - sess3.merge(u) - eq_(on_load.called, 6) - eq_(u3.name, 'also fred') - - @testing.resolve_artifact_names - def test_many_to_many_cascade(self): - - mapper(Order, orders, properties={ - 'items':relation(mapper(Item, items), secondary=order_items)}) - - on_load = self.on_load_tracker(Order) - self.on_load_tracker(Item, on_load) - - sess = create_session() - - i1 = Item() - i1.description='item 1' - - i2 = Item() - i2.description = 'item 2' - - o = Order() - o.description = 'order description' - o.items.append(i1) - o.items.append(i2) - - sess.add(o) - sess.flush() - - eq_(on_load.called, 0) - - sess2 = create_session() - o2 = sess2.query(Order).get(o.id) - eq_(on_load.called, 1) - - o.items[1].description = 'item 2 modified' - sess2.merge(o) - eq_(o2.items[1].description, 'item 2 modified') - eq_(on_load.called, 3) - - sess3 = create_session() - o3 = sess3.query(Order).get(o.id) - eq_( on_load.called, 4) - - o.description = 'desc modified' - sess3.merge(o) - eq_(on_load.called, 6) - eq_(o3.description, 'desc modified') - - @testing.resolve_artifact_names - def test_one_to_one_cascade(self): - - mapper(User, users, properties={ - 'address':relation(mapper(Address, addresses),uselist = False) - }) - on_load = self.on_load_tracker(User) - self.on_load_tracker(Address, on_load) - sess = create_session() - - u = User() - u.id = 7 - u.name = "fred" - a1 = Address() - a1.email_address='foo@bar.com' - u.address = a1 - - sess.add(u) - sess.flush() - - eq_(on_load.called, 0) - - sess2 = create_session() - u2 = sess2.query(User).get(7) - eq_(on_load.called, 1) - u2.name = 'fred2' - u2.address.email_address = 'hoho@lalala.com' - eq_(on_load.called, 2) - - u3 = sess.merge(u2) - eq_(on_load.called, 2) - assert u3 is u - - 
@testing.resolve_artifact_names - def test_transient_dontload(self): - mapper(User, users) - - sess = create_session() - u = User() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) - - - @testing.resolve_artifact_names - def test_dontload_with_backrefs(self): - """dontload populates relations in both directions without requiring a load""" - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), backref='user') - }) - - u = User(id=7, name='fred', addresses=[ - Address(email_address='ad1'), - Address(email_address='ad2')]) - sess = create_session() - sess.add(u) - sess.flush() - sess.close() - assert 'user' in u.addresses[1].__dict__ - - sess = create_session() - u2 = sess.merge(u, dont_load=True) - assert 'user' in u2.addresses[1].__dict__ - eq_(u2.addresses[1].user, User(id=7, name='fred')) - - sess.expire(u2.addresses[1], ['user']) - assert 'user' not in u2.addresses[1].__dict__ - sess.close() - - sess = create_session() - u = sess.merge(u2, dont_load=True) - assert 'user' not in u.addresses[1].__dict__ - eq_(u.addresses[1].user, User(id=7, name='fred')) - - - @testing.resolve_artifact_names - def test_dontload_with_eager(self): - """ - - This test illustrates that with dont_load=True, we can't just copy the - committed_state of the merged instance over; since it references - collection objects which themselves are to be merged. This - committed_state would instead need to be piecemeal 'converted' to - represent the correct objects. However, at the moment I'd rather not - support this use case; if you are merging with dont_load=True, you're - typically dealing with caching and the merged objects shouldnt be - 'dirty'. 
- - """ - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses)) - }) - sess = create_session() - u = User() - u.id = 7 - u.name = "fred" - a1 = Address() - a1.email_address='foo@bar.com' - u.addresses.append(a1) - - sess.add(u) - sess.flush() - - sess2 = create_session() - u2 = sess2.query(User).options(sa.orm.eagerload('addresses')).get(7) - - sess3 = create_session() - u3 = sess3.merge(u2, dont_load=True) - def go(): - sess3.flush() - self.assert_sql_count(testing.db, go, 0) - - @testing.resolve_artifact_names - def test_dont_load_disallows_dirty(self): - """dont_load doesnt support 'dirty' objects right now - - (see test_dont_load_with_eager()). Therefore lets assert it. - - """ - mapper(User, users) - sess = create_session() - u = User() - u.id = 7 - u.name = "fred" - sess.add(u) - sess.flush() - - u.name = 'ed' - sess2 = create_session() - try: - sess2.merge(u, dont_load=True) - assert False - except sa.exc.InvalidRequestError, e: - assert ("merge() with dont_load=True option does not support " - "objects marked as 'dirty'. 
flush() all changes on mapped " - "instances before merging with dont_load=True.") in str(e) - - u2 = sess2.query(User).get(7) - - sess3 = create_session() - u3 = sess3.merge(u2, dont_load=True) - assert not sess3.dirty - def go(): - sess3.flush() - self.assert_sql_count(testing.db, go, 0) - - - @testing.resolve_artifact_names - def test_dont_load_sets_backrefs(self): - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses),backref='user')}) - - sess = create_session() - u = User() - u.id = 7 - u.name = "fred" - a1 = Address() - a1.email_address='foo@bar.com' - u.addresses.append(a1) - - sess.add(u) - sess.flush() - - assert u.addresses[0].user is u - - sess2 = create_session() - u2 = sess2.merge(u, dont_load=True) - assert not sess2.dirty - def go(): - assert u2.addresses[0].user is u2 - self.assert_sql_count(testing.db, go, 0) - - @testing.resolve_artifact_names - def test_dont_load_preserves_parents(self): - """Merge with dont_load does not trigger a 'delete-orphan' operation. - - merge with dont_load sets attributes without using events. this means - the 'hasparent' flag is not propagated to the newly merged instance. - in fact this works out OK, because the '_state.parents' collection on - the newly merged instance is empty; since the mapper doesn't see an - active 'False' setting in this collection when _is_orphan() is called, - it does not count as an orphan (i.e. this is the 'optimistic' logic in - mapper._is_orphan().) 
- - """ - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), - backref='user', cascade="all, delete-orphan")}) - sess = create_session() - u = User() - u.id = 7 - u.name = "fred" - a1 = Address() - a1.email_address='foo@bar.com' - u.addresses.append(a1) - sess.add(u) - sess.flush() - - assert u.addresses[0].user is u - - sess2 = create_session() - u2 = sess2.merge(u, dont_load=True) - assert not sess2.dirty - a2 = u2.addresses[0] - a2.email_address='somenewaddress' - assert not sa.orm.object_mapper(a2)._is_orphan( - sa.orm.attributes.instance_state(a2)) - sess2.flush() - sess2.expunge_all() - - eq_(sess2.query(User).get(u2.id).addresses[0].email_address, - 'somenewaddress') - - # this use case is not supported; this is with a pending Address on - # the pre-merged object, and we currently dont support 'dirty' objects - # being merged with dont_load=True. in this case, the empty - # '_state.parents' collection would be an issue, since the optimistic - # flag is False in _is_orphan() for pending instances. 
so if we start - # supporting 'dirty' with dont_load=True, this test will need to pass - sess = create_session() - u = sess.query(User).get(7) - u.addresses.append(Address()) - sess2 = create_session() - try: - u2 = sess2.merge(u, dont_load=True) - assert False - - # if dont_load is changed to support dirty objects, this code - # needs to pass - a2 = u2.addresses[0] - a2.email_address='somenewaddress' - assert not sa.orm.object_mapper(a2)._is_orphan( - sa.orm.attributes.instance_state(a2)) - sess2.flush() - sess2.expunge_all() - eq_(sess2.query(User).get(u2.id).addresses[0].email_address, - 'somenewaddress') - except sa.exc.InvalidRequestError, e: - assert "dont_load=True option does not support" in str(e) - - @testing.resolve_artifact_names - def test_synonym_comparable(self): - class User(object): - - class Comparator(PropComparator): - pass - - def _getValue(self): - return self._value - - def _setValue(self, value): - setattr(self, '_value', value) - - value = property(_getValue, _setValue) - - mapper(User, users, properties={ - 'uid':synonym('id'), - 'foobar':comparable_property(User.Comparator,User.value), - }) - - sess = create_session() - u = User() - u.name = 'ed' - sess.add(u) - sess.flush() - sess.expunge(u) - sess.merge(u) - - @testing.resolve_artifact_names - def test_cascade_doesnt_blowaway_manytoone(self): - """a merge test that was fixed by [ticket:1202]""" - - s = create_session(autoflush=True) - mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses),backref='user')}) - - a1 = Address(user=s.merge(User(id=1, name='ed')), email_address='x') - before_id = id(a1.user) - a2 = Address(user=s.merge(User(id=1, name='jack')), email_address='x') - after_id = id(a1.user) - other_id = id(a2.user) - eq_(before_id, other_id) - eq_(after_id, other_id) - eq_(before_id, after_id) - eq_(a1.user, a2.user) - - @testing.resolve_artifact_names - def test_cascades_dont_autoflush(self): - sess = create_session(autoflush=True) - m = 
mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses),backref='user')}) - user = User(id=8, name='fred', addresses=[Address(email_address='user')]) - merged_user = sess.merge(user) - assert merged_user in sess.new - sess.flush() - assert merged_user not in sess.new - - @testing.resolve_artifact_names - def test_cascades_dont_autoflush_2(self): - mapper(User, users, properties={ - 'addresses':relation(Address, - backref='user', - cascade="all, delete-orphan") - }) - mapper(Address, addresses) - - u = User(id=7, name='fred', addresses=[ - Address(id=1, email_address='fred1'), - ]) - sess = create_session(autoflush=True, autocommit=False) - sess.add(u) - sess.commit() - - sess.expunge_all() - - u = User(id=7, name='fred', addresses=[ - Address(id=1, email_address='fred1'), - Address(id=2, email_address='fred2'), - ]) - sess.merge(u) - assert sess.autoflush - sess.commit() - - - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/naturalpks.py b/test/orm/naturalpks.py deleted file mode 100644 index 8efce660c..000000000 --- a/test/orm/naturalpks.py +++ /dev/null @@ -1,475 +0,0 @@ -""" -Primary key changing capabilities and passive/non-passive cascading updates. 
- -""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base - -class NaturalPKTest(_base.MappedTest): - - def define_tables(self, metadata): - users = Table('users', metadata, - Column('username', String(50), primary_key=True), - Column('fullname', String(100)), - test_needs_fk=True) - - addresses = Table('addresses', metadata, - Column('email', String(50), primary_key=True), - Column('username', String(50), ForeignKey('users.username', onupdate="cascade")), - test_needs_fk=True) - - items = Table('items', metadata, - Column('itemname', String(50), primary_key=True), - Column('description', String(100)), - test_needs_fk=True) - - users_to_items = Table('users_to_items', metadata, - Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True), - Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), - test_needs_fk=True) - - def setup_classes(self): - class User(_base.ComparableEntity): - pass - class Address(_base.ComparableEntity): - pass - class Item(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_entity(self): - mapper(User, users) - - sess = create_session() - u1 = User(username='jack', fullname='jack') - - sess.add(u1) - sess.flush() - assert sess.query(User).get('jack') is u1 - - u1.username = 'ed' - sess.flush() - - def go(): - assert sess.query(User).get('ed') is u1 - self.assert_sql_count(testing.db, go, 0) - - assert sess.query(User).get('jack') is None - - sess.expunge_all() - u1 = sess.query(User).get('ed') - self.assertEquals(User(username='ed', fullname='jack'), u1) - - @testing.resolve_artifact_names - def test_load_after_expire(self): - mapper(User, users) - - sess = create_session() - u1 = User(username='jack', fullname='jack') - 
- sess.add(u1) - sess.flush() - assert sess.query(User).get('jack') is u1 - - users.update(values={User.username:'jack'}).execute(username='ed') - - # expire/refresh works off of primary key. the PK is gone - # in this case so theres no way to look it up. criterion- - # based session invalidation could solve this [ticket:911] - sess.expire(u1) - self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username') - - sess.expunge_all() - assert sess.query(User).get('jack') is None - assert sess.query(User).get('ed').fullname == 'jack' - - @testing.resolve_artifact_names - def test_flush_new_pk_after_expire(self): - mapper(User, users) - sess = create_session() - u1 = User(username='jack', fullname='jack') - - sess.add(u1) - sess.flush() - assert sess.query(User).get('jack') is u1 - - sess.expire(u1) - u1.username = 'ed' - sess.flush() - sess.expunge_all() - assert sess.query(User).get('ed').fullname == 'jack' - - - @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') - def test_onetomany_passive(self): - self._test_onetomany(True) - - def test_onetomany_nonpassive(self): - self._test_onetomany(False) - - @testing.resolve_artifact_names - def _test_onetomany(self, passive_updates): - mapper(User, users, properties={ - 'addresses':relation(Address, passive_updates=passive_updates) - }) - mapper(Address, addresses) - - sess = create_session() - u1 = User(username='jack', fullname='jack') - u1.addresses.append(Address(email='jack1')) - u1.addresses.append(Address(email='jack2')) - sess.add(u1) - sess.flush() - - assert sess.query(Address).get('jack1') is u1.addresses[0] - - u1.username = 'ed' - sess.flush() - assert u1.addresses[0].username == 'ed' - - sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) - - u1 = sess.query(User).get('ed') - u1.username = 'jack' - def go(): - sess.flush() - if not passive_updates: - self.assert_sql_count(testing.db, go, 4) # test 
passive_updates=False; load addresses, update user, update 2 addresses - else: - self.assert_sql_count(testing.db, go, 1) # test passive_updates=True; update user - sess.expunge_all() - assert User(username='jack', addresses=[Address(username='jack'), Address(username='jack')]) == sess.query(User).get('jack') - - u1 = sess.query(User).get('jack') - u1.addresses = [] - u1.username = 'fred' - sess.flush() - sess.expunge_all() - assert sess.query(Address).get('jack1').username is None - u1 = sess.query(User).get('fred') - self.assertEquals(User(username='fred', fullname='jack'), u1) - - - @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') - def test_manytoone_passive(self): - self._test_manytoone(True) - - def test_manytoone_nonpassive(self): - self._test_manytoone(False) - - @testing.resolve_artifact_names - def _test_manytoone(self, passive_updates): - mapper(User, users) - mapper(Address, addresses, properties={ - 'user':relation(User, passive_updates=passive_updates) - }) - - sess = create_session() - a1 = Address(email='jack1') - a2 = Address(email='jack2') - - u1 = User(username='jack', fullname='jack') - a1.user = u1 - a2.user = u1 - sess.add(a1) - sess.add(a2) - sess.flush() - - u1.username = 'ed' - - def go(): - sess.flush() - if passive_updates: - self.assert_sql_count(testing.db, go, 1) - else: - self.assert_sql_count(testing.db, go, 3) - - def go(): - sess.flush() - self.assert_sql_count(testing.db, go, 0) - - assert a1.username == a2.username == 'ed' - sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) - - @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') - def test_onetoone_passive(self): - self._test_onetoone(True) - - def test_onetoone_nonpassive(self): - self._test_onetoone(False) - - @testing.resolve_artifact_names - def _test_onetoone(self, passive_updates): - mapper(User, users, properties={ - "address":relation(Address, 
passive_updates=passive_updates, uselist=False) - }) - mapper(Address, addresses) - - sess = create_session() - u1 = User(username='jack', fullname='jack') - sess.add(u1) - sess.flush() - - a1 = Address(email='jack1') - u1.address = a1 - sess.add(a1) - sess.flush() - - u1.username = 'ed' - - def go(): - sess.flush() - if passive_updates: - sess.expire(u1, ['address']) - self.assert_sql_count(testing.db, go, 1) - else: - self.assert_sql_count(testing.db, go, 2) - - def go(): - sess.flush() - self.assert_sql_count(testing.db, go, 0) - - sess.expunge_all() - self.assertEquals([Address(username='ed')], sess.query(Address).all()) - - @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') - def test_bidirectional_passive(self): - self._test_bidirectional(True) - - def test_bidirectional_nonpassive(self): - self._test_bidirectional(False) - - @testing.resolve_artifact_names - def _test_bidirectional(self, passive_updates): - mapper(User, users) - mapper(Address, addresses, properties={ - 'user':relation(User, passive_updates=passive_updates, - backref='addresses')}) - - sess = create_session() - a1 = Address(email='jack1') - a2 = Address(email='jack2') - - u1 = User(username='jack', fullname='jack') - a1.user = u1 - a2.user = u1 - sess.add(a1) - sess.add(a2) - sess.flush() - - u1.username = 'ed' - (ad1, ad2) = sess.query(Address).all() - self.assertEquals([Address(username='jack'), Address(username='jack')], [ad1, ad2]) - def go(): - sess.flush() - if passive_updates: - sess.expire(u1, ['addresses']) - self.assert_sql_count(testing.db, go, 1) - else: - self.assert_sql_count(testing.db, go, 3) - self.assertEquals([Address(username='ed'), Address(username='ed')], [ad1, ad2]) - sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) - - u1 = sess.query(User).get('ed') - assert len(u1.addresses) == 2 # load addresses - u1.username = 'fred' - def go(): - sess.flush() - # check that the passive_updates 
is on on the other side - if passive_updates: - sess.expire(u1, ['addresses']) - self.assert_sql_count(testing.db, go, 1) - else: - self.assert_sql_count(testing.db, go, 3) - sess.expunge_all() - self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) - - - @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') - def test_manytomany_passive(self): - self._test_manytomany(True) - - @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count') - def test_manytomany_nonpassive(self): - self._test_manytomany(False) - - @testing.resolve_artifact_names - def _test_manytomany(self, passive_updates): - mapper(User, users, properties={ - 'items':relation(Item, secondary=users_to_items, backref='users', - passive_updates=passive_updates)}) - mapper(Item, items) - - sess = create_session() - u1 = User(username='jack') - u2 = User(username='fred') - i1 = Item(itemname='item1') - i2 = Item(itemname='item2') - - u1.items.append(i1) - u1.items.append(i2) - i2.users.append(u2) - sess.add(u1) - sess.add(u2) - sess.flush() - - r = sess.query(Item).all() - # ComparableEntity can't handle a comparison with the backrefs - # involved.... 
- self.assertEquals(Item(itemname='item1'), r[0]) - self.assertEquals(['jack'], [u.username for u in r[0].users]) - self.assertEquals(Item(itemname='item2'), r[1]) - self.assertEquals(['jack', 'fred'], [u.username for u in r[1].users]) - - u2.username='ed' - def go(): - sess.flush() - go() - def go(): - sess.flush() - self.assert_sql_count(testing.db, go, 0) - - sess.expunge_all() - r = sess.query(Item).all() - self.assertEquals(Item(itemname='item1'), r[0]) - self.assertEquals(['jack'], [u.username for u in r[0].users]) - self.assertEquals(Item(itemname='item2'), r[1]) - self.assertEquals(['ed', 'jack'], sorted([u.username for u in r[1].users])) - - sess.expunge_all() - u2 = sess.query(User).get(u2.username) - u2.username='wendy' - sess.flush() - r = sess.query(Item).with_parent(u2).all() - self.assertEquals(Item(itemname='item2'), r[0]) - - -class SelfRefTest(_base.MappedTest): - __unsupported_on__ = 'mssql' # mssql doesn't allow ON UPDATE on self-referential keys - - def define_tables(self, metadata): - Table('nodes', metadata, - Column('name', String(50), primary_key=True), - Column('parent', String(50), - ForeignKey('nodes.name', onupdate='cascade'))) - - def setup_classes(self): - class Node(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_onetomany(self): - mapper(Node, nodes, properties={ - 'children': relation(Node, - backref=sa.orm.backref('parentnode', - remote_side=nodes.c.name, - passive_updates=False), - passive_updates=False)}) - - sess = create_session() - n1 = Node(name='n1') - n1.children.append(Node(name='n11')) - n1.children.append(Node(name='n12')) - n1.children.append(Node(name='n13')) - sess.add(n1) - sess.flush() - - n1.name = 'new n1' - sess.flush() - eq_(n1.children[1].parent, 'new n1') - eq_(['new n1', 'new n1', 'new n1'], - [n.parent - for n in sess.query(Node).filter( - Node.name.in_(['n11', 'n12', 'n13']))]) - - -class NonPKCascadeTest(_base.MappedTest): - def define_tables(self, metadata): - 
Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('username', String(50), unique=True), - Column('fullname', String(100)), - test_needs_fk=True) - - Table('addresses', metadata, - Column('id', Integer, primary_key=True), - Column('email', String(50)), - Column('username', String(50), - ForeignKey('users.username', onupdate="cascade")), - test_needs_fk=True - ) - - def setup_classes(self): - class User(_base.ComparableEntity): - pass - class Address(_base.ComparableEntity): - pass - - @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') - def test_onetomany_passive(self): - self._test_onetomany(True) - - def test_onetomany_nonpassive(self): - self._test_onetomany(False) - - @testing.resolve_artifact_names - def _test_onetomany(self, passive_updates): - mapper(User, users, properties={ - 'addresses':relation(Address, passive_updates=passive_updates)}) - mapper(Address, addresses) - - sess = create_session() - u1 = User(username='jack', fullname='jack') - u1.addresses.append(Address(email='jack1')) - u1.addresses.append(Address(email='jack2')) - sess.add(u1) - sess.flush() - a1 = u1.addresses[0] - - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) - - assert sess.query(Address).get(a1.id) is u1.addresses[0] - - u1.username = 'ed' - sess.flush() - assert u1.addresses[0].username == 'ed' - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) - - sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) - - u1 = sess.query(User).get(u1.id) - u1.username = 'jack' - def go(): - sess.flush() - if not passive_updates: - self.assert_sql_count(testing.db, go, 4) # test passive_updates=False; load addresses, update user, update 2 addresses - else: - self.assert_sql_count(testing.db, go, 1) # test passive_updates=True; update user - sess.expunge_all() - assert User(username='jack', 
addresses=[Address(username='jack'), Address(username='jack')]) == sess.query(User).get(u1.id) - sess.expunge_all() - - u1 = sess.query(User).get(u1.id) - u1.addresses = [] - u1.username = 'fred' - sess.flush() - sess.expunge_all() - a1 = sess.query(Address).get(a1.id) - self.assertEquals(a1.username, None) - - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) - - u1 = sess.query(User).get(u1.id) - self.assertEquals(User(username='fred', fullname='jack'), u1) - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/onetoone.py b/test/orm/onetoone.py deleted file mode 100644 index be0375e48..000000000 --- a/test/orm/onetoone.py +++ /dev/null @@ -1,74 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base - - -class O2OTest(_base.MappedTest): - def define_tables(self, metadata): - Table('jack', metadata, - Column('id', Integer, primary_key=True), - Column('number', String(50)), - Column('status', String(20)), - Column('subroom', String(5))) - - Table('port', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30)), - Column('description', String(100)), - Column('jack_id', Integer, ForeignKey("jack.id"))) - - @testing.resolve_artifact_names - def setup_mappers(self): - class Jack(_base.BasicEntity): - pass - class Port(_base.BasicEntity): - pass - - - @testing.resolve_artifact_names - def test_basic(self): - mapper(Port, port) - mapper(Jack, jack, - order_by=[jack.c.number], - properties=dict( - port=relation(Port, backref='jack', - uselist=False, - )), - ) - - session = create_session() - - j = Jack(number='101') - session.add(j) - p = Port(name='fa0/1') - session.add(p) - - j.port=p - session.flush() - jid = j.id - pid = p.id - - j=session.query(Jack).get(jid) - p=session.query(Port).get(pid) - assert 
p.jack is not None - assert p.jack is j - assert j.port is not None - p.jack = None - assert j.port is None - - session.expunge_all() - - j = session.query(Jack).get(jid) - p = session.query(Port).get(pid) - - j.port=None - self.assert_(p.jack is None) - session.flush() - - session.delete(j) - session.flush() - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/pickled.py b/test/orm/pickled.py deleted file mode 100644 index 878fe931e..000000000 --- a/test/orm/pickled.py +++ /dev/null @@ -1,190 +0,0 @@ -import testenv; testenv.configure_for_tests() -import pickle -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session, attributes -from orm import _base, _fixtures - - -User, EmailUser = None, None - -class PickleTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_transient(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref="user") - }) - mapper(Address, addresses) - - sess = create_session() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) - - u2 = pickle.loads(pickle.dumps(u1)) - sess.add(u2) - sess.flush() - - sess.expunge_all() - - self.assertEquals(u1, sess.query(User).get(u2.id)) - - @testing.resolve_artifact_names - def test_class_deferred_cols(self): - mapper(User, users, properties={ - 'name':sa.orm.deferred(users.c.name), - 'addresses':relation(Address, backref="user") - }) - mapper(Address, addresses, properties={ - 'email_address':sa.orm.deferred(addresses.c.email_address) - }) - sess = create_session() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) - sess.add(u1) - sess.flush() - sess.expunge_all() - u1 = sess.query(User).get(u1.id) - assert 'name' not in u1.__dict__ - assert 'addresses' not in u1.__dict__ - - u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - sess2.add(u2) - 
self.assertEquals(u2.name, 'ed') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) - - u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - u2 = sess2.merge(u2, dont_load=True) - self.assertEquals(u2.name, 'ed') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) - - @testing.resolve_artifact_names - def test_instance_deferred_cols(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref="user") - }) - mapper(Address, addresses) - - sess = create_session() - u1 = User(name='ed') - u1.addresses.append(Address(email_address='ed@bar.com')) - sess.add(u1) - sess.flush() - sess.expunge_all() - - u1 = sess.query(User).options(sa.orm.defer('name'), sa.orm.defer('addresses.email_address')).get(u1.id) - assert 'name' not in u1.__dict__ - assert 'addresses' not in u1.__dict__ - - u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - sess2.add(u2) - self.assertEquals(u2.name, 'ed') - assert 'addresses' not in u2.__dict__ - ad = u2.addresses[0] - assert 'email_address' not in ad.__dict__ - self.assertEquals(ad.email_address, 'ed@bar.com') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) - - u2 = pickle.loads(pickle.dumps(u1)) - sess2 = create_session() - u2 = sess2.merge(u2, dont_load=True) - self.assertEquals(u2.name, 'ed') - assert 'addresses' not in u2.__dict__ - ad = u2.addresses[0] - assert 'email_address' in ad.__dict__ # mapper options dont transmit over merge() right now - self.assertEquals(ad.email_address, 'ed@bar.com') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) - - @testing.resolve_artifact_names - def test_options_with_descriptors(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref="user") - }) - mapper(Address, addresses) - sess = create_session() - u1 = User(name='ed') - 
u1.addresses.append(Address(email_address='ed@bar.com')) - sess.add(u1) - sess.flush() - sess.expunge_all() - - for opt in [ - sa.orm.eagerload(User.addresses), - sa.orm.eagerload("addresses"), - sa.orm.defer("name"), - sa.orm.defer(User.name), - sa.orm.defer([User.name]), - sa.orm.eagerload("addresses", User.addresses), - sa.orm.eagerload(["addresses", User.addresses]), - ]: - opt2 = pickle.loads(pickle.dumps(opt)) - self.assertEquals(opt.key, opt2.key) - - u1 = sess.query(User).options(opt).first() - - u2 = pickle.loads(pickle.dumps(u1)) - - -class PolymorphicDeferredTest(_base.MappedTest): - def define_tables(self, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30)), - Column('type', String(30))) - Table('email_users', metadata, - Column('id', Integer, ForeignKey('users.id'), primary_key=True), - Column('email_address', String(30))) - - def setup_classes(self): - global User, EmailUser - class User(_base.BasicEntity): - pass - - class EmailUser(User): - pass - - def tearDownAll(self): - global User, EmailUser - User, EmailUser = None, None - _base.MappedTest.tearDownAll(self) - - @testing.resolve_artifact_names - def test_polymorphic_deferred(self): - mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type) - mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser') - - eu = EmailUser(name="user1", email_address='foo@bar.com') - sess = create_session() - sess.add(eu) - sess.flush() - sess.expunge_all() - - eu = sess.query(User).first() - eu2 = pickle.loads(pickle.dumps(eu)) - sess2 = create_session() - sess2.add(eu2) - assert 'email_address' not in eu2.__dict__ - self.assertEquals(eu2.email_address, 'foo@bar.com') - -class CustomSetupTeardowntest(_fixtures.FixtureTest): - @testing.resolve_artifact_names - def test_rebuild_state(self): - """not much of a 'test', but illustrate how to - remove instance-level state before pickling. 
- - """ - mapper(User, users) - - u1 = User() - attributes.manager_of_class(User).teardown_instance(u1) - assert not u1.__dict__ - u2 = pickle.loads(pickle.dumps(u1)) - attributes.manager_of_class(User).setup_instance(u2) - assert attributes.instance_state(u2) - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/query.py b/test/orm/query.py deleted file mode 100644 index 33c3e39d7..000000000 --- a/test/orm/query.py +++ /dev/null @@ -1,3012 +0,0 @@ -import testenv; testenv.configure_for_tests() -import operator -from sqlalchemy import * -from sqlalchemy import exc as sa_exc, util -from sqlalchemy.sql import compiler, table, column -from sqlalchemy.engine import default -from sqlalchemy.orm import * -from sqlalchemy.orm import attributes - -from testlib.testing import eq_ - -from testlib import sa, testing, AssertsCompiledSQL, Column, engines - -from orm import _fixtures -from orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \ - Dingaling, item_keywords, dingalings, User, items,\ - orders, Address, users, nodes, \ - order_items, Item, Order, Node - -from orm import _base - -from sqlalchemy.orm.util import join, outerjoin, with_parent - -class QueryTest(_fixtures.FixtureTest): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - - def setup_mappers(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref='user', order_by=addresses.c.id), - 'orders':relation(Order, backref='user', order_by=orders.c.id), # o2m, m2o - }) - mapper(Address, addresses, properties={ - 'dingaling':relation(Dingaling, uselist=False, backref="address") #o2o - }) - mapper(Dingaling, dingalings) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m - 'address':relation(Address), # m2o - }) - mapper(Item, items, properties={ - 'keywords':relation(Keyword, secondary=item_keywords) #m2m - }) - mapper(Keyword, keywords) - - mapper(Node, nodes, properties={ - 
'children':relation(Node, - backref=backref('parent', remote_side=[nodes.c.id]) - ) - }) - - compile_mappers() - -class RowTupleTest(QueryTest): - run_setup_mappers = None - - def test_custom_names(self): - mapper(User, users, properties={ - 'uname':users.c.name - }) - - row = create_session().query(User.id, User.uname).filter(User.id==7).first() - assert row.id == 7 - assert row.uname == 'jack' - -class GetTest(QueryTest): - def test_get(self): - s = create_session() - assert s.query(User).get(19) is None - u = s.query(User).get(7) - u2 = s.query(User).get(7) - assert u is u2 - s.expunge_all() - u2 = s.query(User).get(7) - assert u is not u2 - - def test_no_criterion(self): - """test that get()/load() does not use preexisting filter/etc. criterion""" - - s = create_session() - - q = s.query(User).join('addresses').filter(Address.user_id==8) - self.assertRaises(sa_exc.InvalidRequestError, q.get, 7) - self.assertRaises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19) - - # order_by()/get() doesn't raise - s.query(User).order_by(User.id).get(8) - - def test_unique_param_names(self): - class SomeUser(object): - pass - s = users.select(users.c.id!=12).alias('users') - m = mapper(SomeUser, s) - print s.primary_key - print m.primary_key - assert s.primary_key == m.primary_key - - row = s.select(use_labels=True).execute().fetchone() - print row[s.primary_key[0]] - - sess = create_session() - assert sess.query(SomeUser).get(7).name == 'jack' - - def test_load(self): - s = create_session() - - assert s.query(User).populate_existing().get(19) is None - - u = s.query(User).populate_existing().get(7) - u2 = s.query(User).populate_existing().get(7) - assert u is u2 - s.expunge_all() - u2 = s.query(User).populate_existing().get(7) - assert u is not u2 - - u2.name = 'some name' - a = Address(email_address='some other name') - u2.addresses.append(a) - assert u2 in s.dirty - assert a in u2.addresses - - s.query(User).populate_existing().get(7) - assert u2 not 
in s.dirty - assert u2.name =='jack' - assert a not in u2.addresses - - @testing.requires.unicode_connections - def test_unicode(self): - """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail - on postgres, mysql and oracle unless it is converted to an encoded string""" - - metadata = MetaData(engines.utf8_engine()) - table = Table('unicode_data', metadata, - Column('id', Unicode(40), primary_key=True), - Column('data', Unicode(40))) - try: - metadata.create_all() - ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8') - table.insert().execute(id=ustring, data=ustring) - class LocalFoo(Base): - pass - mapper(LocalFoo, table) - self.assertEquals(create_session().query(LocalFoo).get(ustring), - LocalFoo(id=ustring, data=ustring)) - finally: - metadata.drop_all() - - def test_populate_existing(self): - s = create_session() - - userlist = s.query(User).all() - - u = userlist[0] - u.name = 'foo' - a = Address(name='ed') - u.addresses.append(a) - - self.assert_(a in u.addresses) - - s.query(User).populate_existing().all() - - self.assert_(u not in s.dirty) - - self.assert_(u.name == 'jack') - - self.assert_(a not in u.addresses) - - u.addresses[0].email_address = 'lala' - u.orders[1].items[2].description = 'item 12' - # test that lazy load doesnt change child items - s.query(User).populate_existing().all() - assert u.addresses[0].email_address == 'lala' - assert u.orders[1].items[2].description == 'item 12' - - # eager load does - s.query(User).options(eagerload('addresses'), eagerload_all('orders.items')).populate_existing().all() - assert u.addresses[0].email_address == 'jack@bean.com' - assert u.orders[1].items[2].description == 'item 5' - - @testing.fails_on_everything_except('sqlite', 'mssql') - def test_query_str(self): - s = create_session() - q = s.query(User).filter(User.id==1) - self.assertEquals( - str(q).replace('\n',''), - 'SELECT users.id AS users_id, users.name AS users_name FROM users WHERE users.id = 
?' - ) - -class InvalidGenerationsTest(QueryTest): - def test_no_limit_offset(self): - s = create_session() - - for q in ( - s.query(User).limit(2), - s.query(User).offset(2), - s.query(User).limit(2).offset(2) - ): - self.assertRaises(sa_exc.InvalidRequestError, q.join, "addresses") - - self.assertRaises(sa_exc.InvalidRequestError, q.filter, User.name=='ed') - - self.assertRaises(sa_exc.InvalidRequestError, q.filter_by, name='ed') - - self.assertRaises(sa_exc.InvalidRequestError, q.order_by, 'foo') - - self.assertRaises(sa_exc.InvalidRequestError, q.group_by, 'foo') - - self.assertRaises(sa_exc.InvalidRequestError, q.having, 'foo') - - def test_no_from(self): - s = create_session() - - q = s.query(User).select_from(users) - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) - - q = s.query(User).join('addresses') - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) - - q = s.query(User).order_by(User.id) - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) - - # this is fine, however - q.from_self() - - def test_invalid_select_from(self): - s = create_session() - q = s.query(User) - self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id==5) - self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id) - - def test_invalid_from_statement(self): - s = create_session() - q = s.query(User) - self.assertRaises(sa_exc.ArgumentError, q.from_statement, User.id==5) - self.assertRaises(sa_exc.ArgumentError, q.from_statement, users.join(addresses)) - - def test_invalid_column(self): - s = create_session() - q = s.query(User) - self.assertRaises(sa_exc.InvalidRequestError, q.add_column, object()) - - def test_mapper_zero(self): - s = create_session() - - q = s.query(User, Address) - self.assertRaises(sa_exc.InvalidRequestError, q.get, 5) - - def test_from_statement(self): - s = create_session() - - q = s.query(User).filter(User.id==5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") - - q = 
s.query(User).filter_by(id=5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") - - q = s.query(User).limit(5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") - - q = s.query(User).group_by(User.name) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") - - q = s.query(User).order_by(User.name) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") - -class OperatorTest(QueryTest, AssertsCompiledSQL): - """test sql.Comparator implementation for MapperProperties""" - - def _test(self, clause, expected): - self.assert_compile(clause, expected, dialect=default.DefaultDialect()) - - def test_arithmetic(self): - create_session().query(User) - for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), - (operator.sub, '-'), (operator.div, '/'), - ): - for (lhs, rhs, res) in ( - (5, User.id, ':id_1 %s users.id'), - (5, literal(6), ':param_1 %s :param_2'), - (User.id, 5, 'users.id %s :id_1'), - (User.id, literal('b'), 'users.id %s :param_1'), - (User.id, User.id, 'users.id %s users.id'), - (literal(5), 'b', ':param_1 %s :param_2'), - (literal(5), User.id, ':param_1 %s users.id'), - (literal(5), literal(6), ':param_1 %s :param_2'), - ): - self._test(py_op(lhs, rhs), res % sql_op) - - def test_comparison(self): - create_session().query(User) - ualias = aliased(User) - - for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'), - (operator.gt, '>', '<'), - (operator.eq, '=', '='), - (operator.ne, '!=', '!='), - (operator.le, '<=', '>='), - (operator.ge, '>=', '<=')): - for (lhs, rhs, l_sql, r_sql) in ( - ('a', User.id, ':id_1', 'users.id'), - ('a', literal('b'), ':param_2', ':param_1'), # note swap! 
- (User.id, 'b', 'users.id', ':id_1'), - (User.id, literal('b'), 'users.id', ':param_1'), - (User.id, User.id, 'users.id', 'users.id'), - (literal('a'), 'b', ':param_1', ':param_2'), - (literal('a'), User.id, ':param_1', 'users.id'), - (literal('a'), literal('b'), ':param_1', ':param_2'), - (ualias.id, literal('b'), 'users_1.id', ':param_1'), - (User.id, ualias.name, 'users.id', 'users_1.name'), - (User.name, ualias.name, 'users.name', 'users_1.name'), - (ualias.name, User.name, 'users_1.name', 'users.name'), - ): - - # the compiled clause should match either (e.g.): - # 'a' < 'b' -or- 'b' > 'a'. - compiled = str(py_op(lhs, rhs).compile(dialect=default.DefaultDialect())) - fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) - rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) - - self.assert_(compiled == fwd_sql or compiled == rev_sql, - "\n'" + compiled + "'\n does not match\n'" + - fwd_sql + "'\n or\n'" + rev_sql + "'") - - def test_negated_null(self): - self._test(User.id == None, "users.id IS NULL") - self._test(~(User.id==None), "users.id IS NOT NULL") - self._test(None == User.id, "users.id IS NULL") - self._test(~(None == User.id), "users.id IS NOT NULL") - self._test(Address.user == None, "addresses.user_id IS NULL") - self._test(~(Address.user==None), "addresses.user_id IS NOT NULL") - self._test(None == Address.user, "addresses.user_id IS NULL") - self._test(~(None == Address.user), "addresses.user_id IS NOT NULL") - - def test_relation(self): - self._test(User.addresses.any(Address.id==17), - "EXISTS (SELECT 1 " - "FROM addresses " - "WHERE users.id = addresses.user_id AND addresses.id = :id_1)" - ) - - u7 = User(id=7) - attributes.instance_state(u7).commit_all(attributes.instance_dict(u7)) - - self._test(Address.user == u7, ":param_1 = addresses.user_id") - - self._test(Address.user != u7, "addresses.user_id != :user_id_1 OR addresses.user_id IS NULL") - - self._test(Address.user == None, "addresses.user_id IS NULL") - - self._test(Address.user != None, 
"addresses.user_id IS NOT NULL") - - def test_selfref_relation(self): - nalias = aliased(Node) - - # auto self-referential aliasing - self._test( - Node.children.any(Node.data=='n1'), - "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " - "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" - ) - - # needs autoaliasing - self._test( - Node.children==None, - "NOT (EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id))" - ) - - self._test( - Node.parent==None, - "nodes.parent_id IS NULL" - ) - - self._test( - nalias.parent==None, - "nodes_1.parent_id IS NULL" - ) - - self._test( - nalias.children==None, - "NOT (EXISTS (SELECT 1 FROM nodes WHERE nodes_1.id = nodes.parent_id))" - ) - - self._test( - nalias.children.any(Node.data=='some data'), - "EXISTS (SELECT 1 FROM nodes WHERE " - "nodes_1.id = nodes.parent_id AND nodes.data = :data_1)") - - # fails, but I think I want this to fail - #self._test( - # Node.children.any(nalias.data=='some data'), - # "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " - # "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" - # ) - - self._test( - nalias.parent.has(Node.data=='some data'), - "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.data = :data_1)" - ) - - self._test( - Node.parent.has(Node.data=='some data'), - "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes_1.id = nodes.parent_id AND nodes_1.data = :data_1)" - ) - - self._test( - Node.parent == Node(id=7), - ":param_1 = nodes.parent_id" - ) - - self._test( - nalias.parent == Node(id=7), - ":param_1 = nodes_1.parent_id" - ) - - self._test( - nalias.parent != Node(id=7), - 'nodes_1.parent_id != :parent_id_1 OR nodes_1.parent_id IS NULL' - ) - - self._test( - nalias.children.contains(Node(id=7)), "nodes_1.id = :param_1" - ) - - def test_op(self): - self._test(User.name.op('ilike')('17'), "users.name ilike :name_1") - - def test_in(self): - self._test(User.id.in_(['a', 'b']), - "users.id IN (:id_1, :id_2)") - - def 
test_in_on_relation_not_supported(self): - self.assertRaises(NotImplementedError, Address.user.in_, [User(id=5)]) - - def test_between(self): - self._test(User.id.between('a', 'b'), - "users.id BETWEEN :id_1 AND :id_2") - - def test_selfref_between(self): - ualias = aliased(User) - self._test(User.id.between(ualias.id, ualias.id), "users.id BETWEEN users_1.id AND users_1.id") - self._test(ualias.id.between(User.id, User.id), "users_1.id BETWEEN users.id AND users.id") - - def test_clauses(self): - for (expr, compare) in ( - (func.max(User.id), "max(users.id)"), - (User.id.desc(), "users.id DESC"), - (between(5, User.id, Address.id), ":param_1 BETWEEN users.id AND addresses.id"), - # this one would require adding compile() to InstrumentedScalarAttribute. do we want this ? - #(User.id, "users.id") - ): - c = expr.compile(dialect=default.DefaultDialect()) - assert str(c) == compare, "%s != %s" % (str(c), compare) - - -class RawSelectTest(QueryTest, AssertsCompiledSQL): - """compare a bunch of select() tests with the equivalent Query using straight table/columns. - - Results should be the same as Query should act as a select() pass-thru for ClauseElement entities. - - """ - def test_select(self): - sess = create_session() - - self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement, - "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1") - - self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement, - "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users") - - # a little tedious here, adding labels to work around Query's auto-labelling. - # also correlate needed explicitly. hmmm..... - # TODO: can we detect only one table in the "froms" and then turn off use_labels ? 
- s = sess.query(addresses.c.id.label('id'), addresses.c.email_address.label('email')).\ - filter(addresses.c.user_id==users.c.id).correlate(users).statement.alias() - - self.assert_compile(sess.query(users, s.c.email).select_from(users.join(s, s.c.id==users.c.id)).with_labels().statement, - "SELECT users.id AS users_id, users.name AS users_name, anon_1.email AS anon_1_email " - "FROM users JOIN (SELECT addresses.id AS id, addresses.email_address AS email FROM addresses " - "WHERE addresses.user_id = users.id) AS anon_1 ON anon_1.id = users.id", - dialect=default.DefaultDialect() - ) - - x = func.lala(users.c.id).label('foo') - self.assert_compile(sess.query(x).filter(x==5).statement, - "SELECT lala(users.id) AS foo FROM users WHERE lala(users.id) = :param_1", dialect=default.DefaultDialect()) - -class ExpressionTest(QueryTest, AssertsCompiledSQL): - - def test_deferred_instances(self): - session = create_session() - s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).statement - - l = list(session.query(User).instances(s.execute(emailad = 'jack@bean.com'))) - eq_([User(id=7)], l) - - def test_scalar_subquery(self): - session = create_session() - - q = session.query(User.id).filter(User.id==7).subquery() - - q = session.query(User).filter(User.id==q) - - eq_(User(id=7), q.one()) - - - def test_in(self): - session = create_session() - s = session.query(User.id).join(User.addresses).group_by(User.id).having(func.count(Address.id) > 2) - eq_( - session.query(User).filter(User.id.in_(s)).all(), - [User(id=8)] - ) - - def test_union(self): - s = create_session() - - q1 = s.query(User).filter(User.name=='ed').with_labels() - q2 = s.query(User).filter(User.name=='fred').with_labels() - eq_( - s.query(User).from_statement(union(q1, q2).order_by('users_name')).all(), - [User(name='ed'), User(name='fred')] - ) - - def test_select(self): - s = create_session() - - # this is actually not legal on most DBs since the 
subquery has no alias - q1 = s.query(User).filter(User.name=='ed') - self.assert_compile( - select([q1]), - "SELECT id, name FROM (SELECT users.id AS id, users.name AS name FROM users WHERE users.name = :name_1)", - dialect=default.DefaultDialect() - ) - - def test_join(self): - s = create_session() - - # TODO: do we want aliased() to detect a query and convert to subquery() - # automatically ? - q1 = s.query(Address).filter(Address.email_address=='jack@bean.com') - adalias = aliased(Address, q1.subquery()) - eq_( - s.query(User, adalias).join((adalias, User.id==adalias.user_id)).all(), - [(User(id=7,name=u'jack'), Address(email_address=u'jack@bean.com',user_id=7,id=1))] - ) - -# more slice tests are available in test/orm/generative.py -class SliceTest(QueryTest): - def test_first(self): - assert User(id=7) == create_session().query(User).first() - - assert create_session().query(User).filter(User.id==27).first() is None - - @testing.fails_on_everything_except('sqlite') - def test_limit_offset_applies(self): - """Test that the expected LIMIT/OFFSET is applied for slices. - - The LIMIT/OFFSET syntax differs slightly on all databases, and - query[x:y] executes immediately, so we are asserting against - SQL strings using sqlite's syntax. 
- - """ - sess = create_session() - q = sess.query(User) - - self.assert_sql(testing.db, lambda: q[10:20], [ - ("SELECT users.id AS users_id, users.name AS users_name FROM users LIMIT 10 OFFSET 10", {}) - ]) - - self.assert_sql(testing.db, lambda: q[:20], [ - ("SELECT users.id AS users_id, users.name AS users_name FROM users LIMIT 20 OFFSET 0", {}) - ]) - - self.assert_sql(testing.db, lambda: q[5:], [ - ("SELECT users.id AS users_id, users.name AS users_name FROM users LIMIT -1 OFFSET 5", {}) - ]) - - self.assert_sql(testing.db, lambda: q[2:2], []) - - self.assert_sql(testing.db, lambda: q[-2:-5], []) - - self.assert_sql(testing.db, lambda: q[-5:-2], [ - ("SELECT users.id AS users_id, users.name AS users_name FROM users", {}) - ]) - - self.assert_sql(testing.db, lambda: q[-5:], [ - ("SELECT users.id AS users_id, users.name AS users_name FROM users", {}) - ]) - - self.assert_sql(testing.db, lambda: q[:], [ - ("SELECT users.id AS users_id, users.name AS users_name FROM users", {}) - ]) - - - -class FilterTest(QueryTest): - def test_basic(self): - assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).all() - - @testing.fails_on('maxdb', 'FIXME: unknown') - def test_limit(self): - assert [User(id=8), User(id=9)] == create_session().query(User).order_by(User.id).limit(2).offset(1).all() - - assert [User(id=8), User(id=9)] == list(create_session().query(User).order_by(User.id)[1:3]) - - assert User(id=8) == create_session().query(User).order_by(User.id)[1] - - assert [] == create_session().query(User).order_by(User.id)[3:3] - assert [] == create_session().query(User).order_by(User.id)[0:0] - - - def test_one_filter(self): - assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all() - - def test_contains(self): - """test comparing a collection to an object instance.""" - - sess = create_session() - address = sess.query(Address).get(3) - assert [User(id=8)] == 
sess.query(User).filter(User.addresses.contains(address)).all() - - try: - sess.query(User).filter(User.addresses == address) - assert False - except sa_exc.InvalidRequestError: - assert True - - assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all() - - try: - assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() - assert False - except sa_exc.InvalidRequestError: - assert True - - #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() - - def test_any(self): - sess = create_session() - - assert [User(id=8), User(id=9)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).all() - - assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'), id=4)).all() - - assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).\ - filter(User.addresses.any(id=4)).all() - - assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all() - - # test that any() doesn't overcorrelate - assert [User(id=7), User(id=8)] == sess.query(User).join("addresses").filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all() - - # test that the contents are not adapted by the aliased join - assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all() - - assert [User(id=10)] == sess.query(User).outerjoin("addresses", aliased=True).filter(~User.addresses.any()).all() - - @testing.crashes('maxdb', 'can dump core') - def test_has(self): - sess = create_session() - assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all() - - assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all() - - assert 
[Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() - - # test has() doesn't overcorrelate - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all() - - # test has() doesnt' get subquery contents adapted by aliased join - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() - - dingaling = sess.query(Dingaling).get(2) - assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all() - - def test_contains_m2m(self): - sess = create_session() - item = sess.query(Item).get(3) - assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).all() - - assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all() - - item2 = sess.query(Item).get(5) - assert [Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).filter(Order.items.contains(item2)).all() - - - def test_comparison(self): - """test scalar comparison to an object instance""" - - sess = create_session() - user = sess.query(User).get(8) - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user==user).all() - - assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all() - - # generates an IS NULL - assert [] == sess.query(Address).filter(Address.user == None).all() - - assert [Order(id=5)] == sess.query(Order).filter(Order.address == None).all() - - # o2o - dingaling = sess.query(Dingaling).get(2) - assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all() - - # m2m - self.assertEquals(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) - 
self.assertEquals(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) - - def test_filter_by(self): - sess = create_session() - user = sess.query(User).get(8) - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter_by(user=user).all() - - # many to one generates IS NULL - assert [] == sess.query(Address).filter_by(user = None).all() - - # one to many generates WHERE NOT EXISTS - assert [User(name='chuck')] == sess.query(User).filter_by(addresses = None).all() - - def test_none_comparison(self): - sess = create_session() - - # o2o - self.assertEquals([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) - self.assertEquals([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) - - # m2o - self.assertEquals([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) - self.assertEquals([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all()) - - # o2m - self.assertEquals([User(id=10)], sess.query(User).filter(User.addresses==None).all()) - self.assertEquals([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all()) - - -class FromSelfTest(QueryTest, AssertsCompiledSQL): - def test_filter(self): - - assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().all() - - assert [User(id=8), User(id=9)] == create_session().query(User).order_by(User.id).slice(1,3)._from_self().all() - assert [User(id=8)] == list(create_session().query(User).filter(User.id.in_([8,9]))._from_self().order_by(User.id)[0:1]) - - def test_join(self): - assert [ - (User(id=8), Address(id=2)), - (User(id=8), Address(id=3)), - (User(id=8), Address(id=4)), - (User(id=9), Address(id=5)) - ] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().\ - 
join('addresses').add_entity(Address).order_by(User.id, Address.id).all() - - def test_group_by(self): - eq_( - create_session().query(Address.user_id, func.count(Address.id).label('count')).\ - group_by(Address.user_id).order_by(Address.user_id).all(), - [(7, 1), (8, 3), (9, 1)] - ) - - eq_( - create_session().query(Address.user_id, Address.id).\ - from_self(Address.user_id, func.count(Address.id)).\ - group_by(Address.user_id).order_by(Address.user_id).all(), - [(7, 1), (8, 3), (9, 1)] - ) - - def test_no_eagerload(self): - """test that eagerloads are pushed outwards and not rendered in subqueries.""" - - s = create_session() - - self.assert_compile( - s.query(User).options(eagerload(User.addresses)).from_self().statement, - "SELECT anon_1.users_id, anon_1.users_name, addresses_1.id, addresses_1.user_id, "\ - "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) AS anon_1 "\ - "LEFT OUTER JOIN addresses AS addresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" - ) - - def test_aliases(self): - """test that aliased objects are accessible externally to a from_self() call.""" - - s = create_session() - - ualias = aliased(User) - eq_( - s.query(User, ualias).filter(User.id > ualias.id).from_self(User.name, ualias.name). - order_by(User.name, ualias.name).all(), - [ - (u'chuck', u'ed'), - (u'chuck', u'fred'), - (u'chuck', u'jack'), - (u'ed', u'jack'), - (u'fred', u'ed'), - (u'fred', u'jack') - ] - ) - - eq_( - s.query(User, ualias).filter(User.id > ualias.id).from_self(User.name, ualias.name).filter(ualias.name=='ed')\ - .order_by(User.name, ualias.name).all(), - [(u'chuck', u'ed'), (u'fred', u'ed')] - ) - - eq_( - s.query(User, ualias).filter(User.id > ualias.id).from_self(ualias.name, Address.email_address). 
- join(ualias.addresses).order_by(ualias.name, Address.email_address).all(), - [ - (u'ed', u'fred@fred.com'), - (u'jack', u'ed@bettyboop.com'), - (u'jack', u'ed@lala.com'), - (u'jack', u'ed@wood.com'), - (u'jack', u'fred@fred.com')] - ) - - - def test_multiple_entities(self): - sess = create_session() - - self.assertEquals( - sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(), - [ - (User(id=8), Address(id=2)), - (User(id=9), Address(id=5)) - ] - ) - - self.assertEquals( - sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(), - - # order_by(User.id, Address.id).first(), - (User(id=8, addresses=[Address(), Address(), Address()]), Address(id=2)), - ) - -class SetOpsTest(QueryTest, AssertsCompiledSQL): - - def test_union(self): - s = create_session() - - fred = s.query(User).filter(User.name=='fred') - ed = s.query(User).filter(User.name=='ed') - jack = s.query(User).filter(User.name=='jack') - - self.assertEquals(fred.union(ed).order_by(User.name).all(), - [User(name='ed'), User(name='fred')] - ) - - self.assertEquals(fred.union(ed, jack).order_by(User.name).all(), - [User(name='ed'), User(name='fred'), User(name='jack')] - ) - - @testing.fails_on('mysql', "mysql doesn't support intersect") - def test_intersect(self): - s = create_session() - - fred = s.query(User).filter(User.name=='fred') - ed = s.query(User).filter(User.name=='ed') - jack = s.query(User).filter(User.name=='jack') - self.assertEquals(fred.intersect(ed, jack).all(), - [] - ) - - self.assertEquals(fred.union(ed).intersect(ed.union(jack)).all(), - [User(name='ed')] - ) - - def test_eager_load(self): - s = create_session() - - fred = s.query(User).filter(User.name=='fred') - ed = s.query(User).filter(User.name=='ed') - jack = s.query(User).filter(User.name=='jack') - - def go(): - self.assertEquals( - 
fred.union(ed).order_by(User.name).options(eagerload(User.addresses)).all(), - [ - User(name='ed', addresses=[Address(), Address(), Address()]), - User(name='fred', addresses=[Address()]) - ] - ) - self.assert_sql_count(testing.db, go, 1) - - -class AggregateTest(QueryTest): - - def test_sum(self): - sess = create_session() - orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) - self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,)) - self.assertEquals(orders.value(func.sum(Order.user_id * Order.address_id)), 79) - - def test_apply(self): - sess = create_session() - assert sess.query(func.sum(Order.user_id * Order.address_id)).filter(Order.id.in_([2, 3, 4])).one() == (79,) - - def test_having(self): - sess = create_session() - assert [User(name=u'ed',id=8)] == sess.query(User).order_by(User.id).group_by(User).join('addresses').having(func.count(Address.id)> 2).all() - - assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).order_by(User.id).group_by(User).join('addresses').having(func.count(Address.id)< 2).all() - -class CountTest(QueryTest): - def test_basic(self): - s = create_session() - - eq_(s.query(User).count(), 4) - - eq_(s.query(User).filter(users.c.name.endswith('ed')).count(), 2) - - def test_multiple_entity(self): - s = create_session() - q = s.query(User, Address) - eq_(q.count(), 20) # cartesian product - - q = s.query(User, Address).join(User.addresses) - eq_(q.count(), 5) - - def test_nested(self): - s = create_session() - q = s.query(User, Address).limit(2) - eq_(q.count(), 2) - - q = s.query(User, Address).limit(100) - eq_(q.count(), 20) - - q = s.query(User, Address).join(User.addresses).limit(100) - eq_(q.count(), 5) - - def test_cols(self): - """test that column-based queries always nest.""" - - s = create_session() - - q = s.query(func.count(distinct(User.name))) - eq_(q.count(), 1) - - q = s.query(func.count(distinct(User.name))).distinct() - eq_(q.count(), 1) - - q = 
s.query(User.name) - eq_(q.count(), 4) - - q = s.query(User.name, Address) - eq_(q.count(), 20) - - q = s.query(Address.user_id) - eq_(q.count(), 5) - eq_(q.distinct().count(), 3) - - -class DistinctTest(QueryTest): - def test_basic(self): - assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all() - assert [User(id=7), User(id=9), User(id=8),User(id=10)] == create_session().query(User).distinct().order_by(desc(User.name)).all() - - def test_joined(self): - """test that orderbys from a joined table get placed into the columns clause when DISTINCT is used""" - - sess = create_session() - q = sess.query(User).join('addresses').distinct().order_by(desc(Address.email_address)) - - assert [User(id=7), User(id=9), User(id=8)] == q.all() - - sess.expunge_all() - - # test that it works on embedded eagerload/LIMIT subquery - q = sess.query(User).join('addresses').distinct().options(eagerload('addresses')).order_by(desc(Address.email_address)).limit(2) - - def go(): - assert [ - User(id=7, addresses=[ - Address(id=1) - ]), - User(id=9, addresses=[ - Address(id=5) - ]), - ] == q.all() - self.assert_sql_count(testing.db, go, 1) - - -class YieldTest(QueryTest): - def test_basic(self): - import gc - sess = create_session() - q = iter(sess.query(User).yield_per(1).from_statement("select * from users")) - - ret = [] - self.assertEquals(len(sess.identity_map), 0) - ret.append(q.next()) - ret.append(q.next()) - self.assertEquals(len(sess.identity_map), 2) - ret.append(q.next()) - ret.append(q.next()) - self.assertEquals(len(sess.identity_map), 4) - try: - q.next() - assert False - except StopIteration: - pass - -class TextTest(QueryTest): - def test_fulltext(self): - assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).from_statement("select * from users order by id").all() - - assert User(id=7) == create_session().query(User).from_statement("select * from users order by id").first() - assert None == 
create_session().query(User).from_statement("select * from users where name='nonexistent'").first() - - def test_fragment(self): - assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (8, 9)").all() - - assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all() - - assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all() - - def test_binds(self): - assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() - - def test_as_column(self): - s = create_session() - self.assertRaises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name")) - - eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')]) - -class ParentTest(QueryTest): - def test_o2m(self): - sess = create_session() - q = sess.query(User) - - u1 = q.filter_by(name='jack').one() - - # test auto-lookup of property - o = sess.query(Order).with_parent(u1).all() - assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o - - # test with explicit property - o = sess.query(Order).with_parent(u1, property='orders').all() - assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o - - o = sess.query(Order).filter(with_parent(u1, User.orders)).all() - assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o - - # test static method - @testing.uses_deprecated(".*Use sqlalchemy.orm.with_parent") - def go(): - o = Query.query_from_parent(u1, property='orders', session=sess).all() - assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o - go() - - # test generative criterion - o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all() - assert [Order(description="order 3"), Order(description="order 5")] == o - - 
# test against None for parent? this can't be done with the current API since we don't know - # what mapper to use - #assert sess.query(Order).with_parent(None, property='addresses').all() == [Order(description="order 5")] - - def test_noparent(self): - sess = create_session() - q = sess.query(User) - - u1 = q.filter_by(name='jack').one() - - try: - q = sess.query(Item).with_parent(u1) - assert False - except sa_exc.InvalidRequestError, e: - assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'" - - def test_m2m(self): - sess = create_session() - i1 = sess.query(Item).filter_by(id=2).one() - k = sess.query(Keyword).with_parent(i1).all() - assert [Keyword(name='red'), Keyword(name='small'), Keyword(name='square')] == k - - -class JoinTest(QueryTest): - - def test_overlapping_paths(self): - for aliased in (True,False): - # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack) - result = create_session().query(User).join(['orders', 'items'], aliased=aliased).filter_by(id=3).join(['orders','address'], aliased=aliased).filter_by(id=1).all() - assert [User(id=7, name='jack')] == result - - def test_overlapping_paths_outerjoin(self): - result = create_session().query(User).outerjoin(['orders', 'items']).filter_by(id=3).outerjoin(['orders','address']).filter_by(id=1).all() - assert [User(id=7, name='jack')] == result - - def test_from_joinpoint(self): - sess = create_session() - - for oalias,ialias in [(True, True), (False, False), (True, False), (False, True)]: - self.assertEquals( - sess.query(User).join('orders', aliased=oalias).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description == 'item 4').all(), - [User(name='jack')] - ) - - # use middle criterion - self.assertEquals( - sess.query(User).join('orders', aliased=oalias).filter(Order.user_id==9).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description=='item 4').all(), - [] - ) - 
- orderalias = aliased(Order) - itemalias = aliased(Item) - self.assertEquals( - sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(itemalias.description == 'item 4').all(), - [User(name='jack')] - ) - self.assertEquals( - sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(orderalias.user_id==9).filter(itemalias.description=='item 4').all(), - [] - ) - - def test_backwards_join(self): - # a more controversial feature. join from - # User->Address, but the onclause is Address.user. - - sess = create_session() - - self.assertEquals( - sess.query(User).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), - [User(id=8,name=u'ed')] - ) - - # its actually not so controversial if you view it in terms - # of multiple entities. - self.assertEquals( - sess.query(User, Address).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), - [(User(id=8,name=u'ed'), Address(email_address='ed@wood.com'))] - ) - - # this was the controversial part. now, raise an error if the feature is abused. - # before the error raise was added, this would silently work..... 
- self.assertRaises( - sa_exc.InvalidRequestError, - sess.query(User).join, (Address, Address.user), - ) - - # but this one would silently fail - adalias = aliased(Address) - self.assertRaises( - sa_exc.InvalidRequestError, - sess.query(User).join, (adalias, Address.user), - ) - - def test_multiple_with_aliases(self): - sess = create_session() - - ualias = aliased(User) - oalias1 = aliased(Order) - oalias2 = aliased(Order) - result = sess.query(ualias).join((oalias1, ualias.orders), (oalias2, ualias.orders)).\ - filter(or_(oalias1.user_id==9, oalias2.user_id==7)).all() - self.assertEquals(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')]) - - def test_orderby_arg_bug(self): - sess = create_session() - # no arg error - result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all() - - def test_no_onclause(self): - sess = create_session() - - self.assertEquals( - sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), - [User(name='jack')] - ) - - self.assertEquals( - sess.query(User.name).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), - [('jack',)] - ) - - self.assertEquals( - sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(), - [User(name='jack')] - ) - - def test_clause_onclause(self): - sess = create_session() - - self.assertEquals( - sess.query(User).join( - (Order, User.id==Order.user_id), - (order_items, Order.id==order_items.c.order_id), - (Item, order_items.c.item_id==Item.id) - ).filter(Item.description == 'item 4').all(), - [User(name='jack')] - ) - - self.assertEquals( - sess.query(User.name).join( - (Order, User.id==Order.user_id), - (order_items, Order.id==order_items.c.order_id), - (Item, order_items.c.item_id==Item.id) - ).filter(Item.description == 'item 4').all(), - [('jack',)] - ) - - ualias = aliased(User) - self.assertEquals( - 
sess.query(ualias.name).join( - (Order, ualias.id==Order.user_id), - (order_items, Order.id==order_items.c.order_id), - (Item, order_items.c.item_id==Item.id) - ).filter(Item.description == 'item 4').all(), - [('jack',)] - ) - - # explicit onclause with from_self(), means - # the onclause must be aliased against the query's custom - # FROM object - self.assertEquals( - sess.query(User).order_by(User.id).offset(2).from_self().join( - (Order, User.id==Order.user_id) - ).all(), - [User(name='fred')] - ) - - # same with an explicit select_from() - self.assertEquals( - sess.query(User).select_from(select([users]).order_by(User.id).offset(2).alias()).join( - (Order, User.id==Order.user_id) - ).all(), - [User(name='fred')] - ) - - - def test_aliased_classes(self): - sess = create_session() - - (user7, user8, user9, user10) = sess.query(User).all() - (address1, address2, address3, address4, address5) = sess.query(Address).all() - expected = [(user7, address1), - (user8, address2), - (user8, address3), - (user8, address4), - (user9, address5), - (user10, None)] - - q = sess.query(User) - AdAlias = aliased(Address) - q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias)) - l = q.order_by(User.id, AdAlias.id).all() - self.assertEquals(l, expected) - - sess.expunge_all() - - q = sess.query(User).add_entity(AdAlias) - l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) - - l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) - - l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) - - # this is the first test where we are joining "backwards" - from AdAlias to User even though - # the query is against User - q = sess.query(User, AdAlias) - l = 
q.join(AdAlias.user).filter(User.name=='ed') - self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) - - q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed') - self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) - - def test_implicit_joins_from_aliases(self): - sess = create_session() - OrderAlias = aliased(Order) - - self.assertEquals( - sess.query(OrderAlias).join('items').filter_by(description='item 3').\ - order_by(OrderAlias.id).all(), - [ - Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), - Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), - Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3) - ] - ) - - self.assertEquals( - sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').\ - order_by(User.id, OrderAlias.id).all(), - [ - (User(name=u'jack',id=7), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), u'item 3'), - (User(name=u'jack',id=7), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), u'item 3'), - (User(name=u'fred',id=9), Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), u'item 3') - ] - ) - - def test_aliased_classes_m2m(self): - sess = create_session() - - (order1, order2, order3, order4, order5) = sess.query(Order).all() - (item1, item2, item3, item4, item5) = sess.query(Item).all() - expected = [ - (order1, item1), - (order1, item2), - (order1, item3), - (order2, item1), - (order2, item2), - (order2, item3), - (order3, item3), - (order3, item4), - (order3, item5), - (order4, item1), - (order4, item5), - (order5, item5), - ] - - q = sess.query(Order) - q = q.add_entity(Item).select_from(join(Order, Item, 'items')).order_by(Order.id, Item.id) - l = q.all() - self.assertEquals(l, expected) - - IAlias = aliased(Item) - q = sess.query(Order, IAlias).select_from(join(Order, 
IAlias, 'items')).filter(IAlias.description=='item 3') - l = q.all() - self.assertEquals(l, - [ - (order1, item3), - (order2, item3), - (order3, item3), - ] - ) - - def test_reset_joinpoint(self): - for aliased in (True, False): - # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack) - result = create_session().query(User).join(['orders', 'items'], aliased=aliased).filter_by(id=3).reset_joinpoint().join(['orders','address'], aliased=aliased).filter_by(id=1).all() - assert [User(id=7, name='jack')] == result - - result = create_session().query(User).outerjoin(['orders', 'items'], aliased=aliased).filter_by(id=3).reset_joinpoint().outerjoin(['orders','address'], aliased=aliased).filter_by(id=1).all() - assert [User(id=7, name='jack')] == result - - def test_overlap_with_aliases(self): - oalias = orders.alias('oalias') - - result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_(["order 1", "order 2", "order 3"])).join(['orders', 'items']).order_by(User.id).all() - assert [User(id=7, name='jack'), User(id=9, name='fred')] == result - - result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_(["order 1", "order 2", "order 3"])).join(['orders', 'items']).filter_by(id=4).all() - assert [User(id=7, name='jack')] == result - - def test_aliased(self): - """test automatic generation of aliased joins.""" - - sess = create_session() - - # test a basic aliasized path - q = sess.query(User).join('addresses', aliased=True).filter_by(email_address='jack@bean.com') - assert [User(id=7)] == q.all() - - q = sess.query(User).join('addresses', aliased=True).filter(Address.email_address=='jack@bean.com') - assert [User(id=7)] == q.all() - - q = sess.query(User).join('addresses', aliased=True).filter(or_(Address.email_address=='jack@bean.com', Address.email_address=='fred@fred.com')) - assert [User(id=7), User(id=9)] == q.all() - - # test two aliasized 
paths, one to 'orders' and the other to 'orders','items'. - # one row is returned because user 7 has order 3 and also has order 1 which has item 1 - # this tests a o2m join and a m2m join. - q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join(['orders', 'items'], aliased=True).filter(Item.description=="item 1") - assert q.count() == 1 - assert [User(id=7)] == q.all() - - # test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1 - q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Item.description=="item 1") - assert [] == q.all() - assert q.count() == 0 - - # the left half of the join condition of the any() is aliased. - q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4')) - assert [User(id=7)] == q.all() - - # test that aliasing gets reset when join() is called - q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=="order 5") - assert q.count() == 1 - assert [User(id=7)] == q.all() - - def test_aliased_order_by(self): - sess = create_session() - - ualias = aliased(User) - self.assertEquals( - sess.query(User, ualias).filter(User.id > ualias.id).order_by(desc(ualias.id), User.name).all(), - [ - (User(id=10,name=u'chuck'), User(id=9,name=u'fred')), - (User(id=10,name=u'chuck'), User(id=8,name=u'ed')), - (User(id=9,name=u'fred'), User(id=8,name=u'ed')), - (User(id=10,name=u'chuck'), User(id=7,name=u'jack')), - (User(id=8,name=u'ed'), User(id=7,name=u'jack')), - (User(id=9,name=u'fred'), User(id=7,name=u'jack')) - ] - ) - - def test_plain_table(self): - - sess = create_session() - - self.assertEquals( - sess.query(User.name).join((addresses, User.id==addresses.c.user_id)).order_by(User.id).all(), - [(u'jack',), (u'ed',), (u'ed',), (u'ed',), (u'fred',)] - ) - - -class 
MultiplePathTest(_base.MappedTest): - def define_tables(self, metadata): - global t1, t2, t1t2_1, t1t2_2 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)) - ) - t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)) - ) - - t1t2_1 = Table('t1t2_1', metadata, - Column('t1id', Integer, ForeignKey('t1.id')), - Column('t2id', Integer, ForeignKey('t2.id')) - ) - - t1t2_2 = Table('t1t2_2', metadata, - Column('t1id', Integer, ForeignKey('t1.id')), - Column('t2id', Integer, ForeignKey('t2.id')) - ) - - def test_basic(self): - class T1(object):pass - class T2(object):pass - - mapper(T1, t1, properties={ - 't2s_1':relation(T2, secondary=t1t2_1), - 't2s_2':relation(T2, secondary=t1t2_2), - }) - mapper(T2, t2) - - q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint() - self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.", - q.join, 't2s_2' - ) - - create_session().query(T1).join('t2s_1', aliased=True).filter(t2.c.id==5).reset_joinpoint().join('t2s_2').all() - create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2', aliased=True).all() - -class SynonymTest(QueryTest): - - def setup_mappers(self): - mapper(User, users, properties={ - 'name_syn':synonym('name'), - 'addresses':relation(Address), - 'orders':relation(Order, backref='user'), # o2m, m2o - 'orders_syn':synonym('orders') - }) - mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items), #m2m - 'address':relation(Address), # m2o - 'items_syn':synonym('items') - }) - mapper(Item, items, properties={ - 'keywords':relation(Keyword, secondary=item_keywords) #m2m - }) - mapper(Keyword, keywords) - - def test_joins(self): - for j in ( - ['orders', 'items'], - ['orders_syn', 'items'], - ['orders', 'items_syn'], - ['orders_syn', 'items_syn'], - ): - 
result = create_session().query(User).join(j).filter_by(id=3).all() - assert [User(id=7, name='jack'), User(id=9, name='fred')] == result - - def test_with_parent(self): - for nameprop, orderprop in ( - ('name', 'orders'), - ('name_syn', 'orders'), - ('name', 'orders_syn'), - ('name_syn', 'orders_syn'), - ): - sess = create_session() - q = sess.query(User) - - u1 = q.filter_by(**{nameprop:'jack'}).one() - - o = sess.query(Order).with_parent(u1, property=orderprop).all() - assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o - -class InstancesTest(QueryTest, AssertsCompiledSQL): - - def test_from_alias(self): - - query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.id', addresses.c.id]) - sess =create_session() - q = sess.query(User) - - def go(): - l = list(q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())) - assert self.static.user_address_result == l - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - - def go(): - l = q.options(contains_alias('ulist'), contains_eager('addresses')).from_statement(query).all() - assert self.static.user_address_result == l - self.assert_sql_count(testing.db, go, 1) - - # better way. 
use select_from() - def go(): - l = sess.query(User).select_from(query).options(contains_eager('addresses')).all() - assert self.static.user_address_result == l - self.assert_sql_count(testing.db, go, 1) - - # same thing, but alias addresses, so that the adapter generated by select_from() is wrapped within - # the adapter created by contains_eager() - adalias = addresses.alias() - query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(adalias).select(use_labels=True,order_by=['ulist.id', adalias.c.id]) - def go(): - l = sess.query(User).select_from(query).options(contains_eager('addresses', alias=adalias)).all() - assert self.static.user_address_result == l - self.assert_sql_count(testing.db, go, 1) - - def test_contains_eager(self): - sess = create_session() - - # test that contains_eager suppresses the normal outer join rendering - q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses)).order_by(User.id) - self.assert_compile(q.with_labels().statement, - "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\ - "addresses.email_address AS addresses_email_address, users.id AS users_id, "\ - "users.name AS users_name FROM users LEFT OUTER JOIN addresses "\ - "ON users.id = addresses.user_id ORDER BY users.id" - , dialect=default.DefaultDialect()) - - def go(): - assert self.static.user_address_result == q.all() - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - adalias = addresses.alias() - q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias)) - def go(): - self.assertEquals(self.static.user_address_result, q.order_by(User.id).all()) - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id]) - q = sess.query(User) - - def go(): - l = 
list(q.options(contains_eager('addresses')).instances(selectquery.execute())) - assert self.static.user_address_result[0:3] == l - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - - def go(): - l = list(q.options(contains_eager(User.addresses)).instances(selectquery.execute())) - assert self.static.user_address_result[0:3] == l - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - def go(): - l = q.options(contains_eager('addresses')).from_statement(selectquery).all() - assert self.static.user_address_result[0:3] == l - self.assert_sql_count(testing.db, go, 1) - - def test_contains_eager_alias(self): - adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id]) - sess = create_session() - q = sess.query(User) - - # string alias name - def go(): - l = list(q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())) - assert self.static.user_address_result == l - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - # expression.Alias object - def go(): - l = list(q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())) - assert self.static.user_address_result == l - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - - # Aliased object - adalias = aliased(Address) - def go(): - l = q.options(contains_eager('addresses', alias=adalias)).outerjoin((adalias, User.addresses)).order_by(User.id, adalias.id) - assert self.static.user_address_result == l.all() - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - oalias = orders.alias('o1') - ialias = items.alias('i1') - query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id, oalias.c.id, ialias.c.id) - q = create_session().query(User) - # test using string alias with more than one level deep - def go(): - l = list(q.options(contains_eager('orders', 
alias='o1'), contains_eager('orders.items', alias='i1')).instances(query.execute())) - assert self.static.user_order_result == l - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - - # test using Alias with more than one level deep - def go(): - l = list(q.options(contains_eager('orders', alias=oalias), contains_eager('orders.items', alias=ialias)).instances(query.execute())) - assert self.static.user_order_result == l - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - # test using Aliased with more than one level deep - oalias = aliased(Order) - ialias = aliased(Item) - def go(): - l = q.options(contains_eager(User.orders, alias=oalias), contains_eager(User.orders, Order.items, alias=ialias)).\ - outerjoin((oalias, User.orders), (ialias, oalias.items)).order_by(User.id, oalias.id, ialias.id) - assert self.static.user_order_result == l.all() - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - def test_mixed_eager_contains_with_limit(self): - sess = create_session() - - q = sess.query(User) - def go(): - # outerjoin to User.orders, offset 1/limit 2 so we get user 7 + second two orders. - # then eagerload the addresses. User + Order columns go into the subquery, address - # left outer joins to the subquery, eagerloader for User.orders applies context.adapter - # to result rows. This was [ticket:1180]. 
- l = q.outerjoin(User.orders).options(eagerload(User.addresses), contains_eager(User.orders)).order_by(User.id, Order.id).offset(1).limit(2).all() - eq_(l, [User(id=7, - addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)], - name=u'jack', - orders=[ - Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3), - Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5) - ])]) - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - def go(): - # same as above, except Order is aliased, so two adapters are applied by the - # eager loader - oalias = aliased(Order) - l = q.outerjoin((User.orders, oalias)).options(eagerload(User.addresses), contains_eager(User.orders, alias=oalias)).order_by(User.id, oalias.id).offset(1).limit(2).all() - eq_(l, [User(id=7, - addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)], - name=u'jack', - orders=[ - Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3), - Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5) - ])]) - self.assert_sql_count(testing.db, go, 1) - - -class MixedEntitiesTest(QueryTest): - - def test_values(self): - sess = create_session() - - assert list(sess.query(User).values()) == list() - - sel = users.select(User.id.in_([7, 8])).alias() - q = sess.query(User) - q2 = q.select_from(sel).values(User.name) - self.assertEquals(list(q2), [(u'jack',), (u'ed',)]) - - q = sess.query(User) - q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String)) - self.assertEquals(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) - - q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) - - q2 = 
q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) - - adalias = aliased(Address) - q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) - - q2 = q.values(func.count(User.name)) - assert q2.next() == (4,) - - q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name) - self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')]) - - # using User.xxx is alised against "sel", so this query returns nothing - q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name) - self.assertEquals(list(q2), []) - - # whereas this uses users.c.xxx, is not aliased and creates a new join - q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name) - self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')]) - - @testing.fails_on('mssql', 'FIXME: unknown') - def test_values_specific_order_by(self): - sess = create_session() - - assert list(sess.query(User).values()) == list() - - sel = users.select(User.id.in_([7, 8])).alias() - q = sess.query(User) - u2 = aliased(User) - q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name) - self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) - - @testing.fails_on('mssql', 'FIXME: unknown') - def test_values_with_boolean_selects(self): - """Tests a values clause that works with select boolean evaluations""" - sess = 
create_session() - - q = sess.query(User) - q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%'))) - self.assertEquals(list(q2), [(True, 1), (False, 3)]) - - def test_correlated_subquery(self): - """test that a subquery constructed from ORM attributes doesn't leak out - those entities to the outermost query. - - """ - sess = create_session() - - subq = select([func.count()]).\ - where(User.id==Address.user_id).\ - correlate(users).\ - label('count') - - # we don't want Address to be outside of the subquery here - self.assertEquals( - list(sess.query(User, subq)[0:3]), - [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] - ) - - # same thing without the correlate, as it should - # not be needed - subq = select([func.count()]).\ - where(User.id==Address.user_id).\ - label('count') - - # we don't want Address to be outside of the subquery here - self.assertEquals( - list(sess.query(User, subq)[0:3]), - [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] - ) - - def test_tuple_labeling(self): - sess = create_session() - for row in sess.query(User, Address).join(User.addresses).all(): - self.assertEquals(set(row.keys()), set(['User', 'Address'])) - self.assertEquals(row.User, row[0]) - self.assertEquals(row.Address, row[1]) - - for row in sess.query(User.name, User.id.label('foobar')): - self.assertEquals(set(row.keys()), set(['name', 'foobar'])) - self.assertEquals(row.name, row[0]) - self.assertEquals(row.foobar, row[1]) - - for row in sess.query(User).values(User.name, User.id.label('foobar')): - self.assertEquals(set(row.keys()), set(['name', 'foobar'])) - self.assertEquals(row.name, row[0]) - self.assertEquals(row.foobar, row[1]) - - oalias = aliased(Order) - for row in sess.query(User, oalias).join(User.orders).all(): - self.assertEquals(set(row.keys()), set(['User'])) - self.assertEquals(row.User, 
row[0]) - - oalias = aliased(Order, name='orders') - for row in sess.query(User, oalias).join(User.orders).all(): - self.assertEquals(set(row.keys()), set(['User', 'orders'])) - self.assertEquals(row.User, row[0]) - self.assertEquals(row.orders, row[1]) - - - def test_column_queries(self): - sess = create_session() - - self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)]) - - sel = users.select(User.id.in_([7, 8])).alias() - q = sess.query(User.name) - q2 = q.select_from(sel).all() - self.assertEquals(list(q2), [(u'jack',), (u'ed',)]) - - self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [ - (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'), - (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), - (u'fred', u'fred@fred.com') - ]) - - self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), - [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)] - ) - - self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), - [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] - ) - - self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), - [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))] - ) - - adalias = aliased(Address) - self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), - [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] - ) - - self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, 
adalias)).group_by(User).order_by(User.id).all(), - [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))] - ) - - # select from aliasing + explicit aliasing - self.assertEquals( - sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(), - [ - (User(name=u'jack',id=7), u'jack@bean.com'), - (User(name=u'ed',id=8), u'ed@wood.com'), - (User(name=u'ed',id=8), u'ed@bettyboop.com'), - (User(name=u'ed',id=8), u'ed@lala.com'), - (User(name=u'fred',id=9), u'fred@fred.com'), - (User(name=u'chuck',id=10), None) - ] - ) - - # anon + select from aliasing - self.assertEquals( - sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(), - [ - User(name=u'ed',id=8), - User(name=u'fred',id=9), - ] - ) - - # test eager aliasing, with/without select_from aliasing - for q in [ - sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), - sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), - ]: - self.assertEquals( - - q.all(), - [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'), - (User(addresses=[ - Address(user_id=8,email_address=u'ed@wood.com',id=2), - Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), - Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@wood.com'), - (User(addresses=[ - Address(user_id=8,email_address=u'ed@wood.com',id=2), - Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), - Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@bettyboop.com'), - (User(addresses=[ - 
Address(user_id=8,email_address=u'ed@wood.com',id=2), - Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), - Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@lala.com'), - (User(addresses=[Address(user_id=9,email_address=u'fred@fred.com',id=5)],name=u'fred',id=9), u'fred@fred.com'), - - (User(addresses=[],name=u'chuck',id=10), None)] - ) - - def test_column_from_limited_eagerload(self): - sess = create_session() - - def go(): - results = sess.query(User).limit(1).options(eagerload('addresses')).add_column(User.name).all() - self.assertEquals(results, [(User(name='jack'), 'jack')]) - self.assert_sql_count(testing.db, go, 1) - - def test_self_referential(self): - - sess = create_session() - oalias = aliased(Order) - - for q in [ - sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id), - sess.query(Order, oalias)._from_self().filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id), - - # same thing, but reversed. 
- sess.query(oalias, Order)._from_self().filter(oalias.user_id==Order.user_id).filter(oalias.user_id==7).filter(Order.idoalias.id)._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)), - - # gratuitous four layers - sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self()._from_self()._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)), - - ]: - - self.assertEquals( - q.all(), - [ - (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), - (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), - (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3)) - ] - ) - - def test_multi_mappers(self): - - test_session = create_session() - - (user7, user8, user9, user10) = test_session.query(User).all() - (address1, address2, address3, address4, address5) = test_session.query(Address).all() - - expected = [(user7, address1), - (user8, address2), - (user8, address3), - (user8, address4), - (user9, address5), - (user10, None)] - - sess = create_session() - - selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id]) - self.assertEquals(list(sess.query(User, Address).instances(selectquery.execute())), expected) - sess.expunge_all() - - for address_entity in (Address, aliased(Address)): - q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id) - self.assertEquals(q.all(), expected) - sess.expunge_all() - - q = sess.query(User).add_entity(address_entity) - q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(q.all(), [(user8, 
address3)]) - sess.expunge_all() - - q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(q.all(), [(user8, address3)]) - sess.expunge_all() - - q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)]) - sess.expunge_all() - - def test_aliased_multi_mappers(self): - sess = create_session() - - (user7, user8, user9, user10) = sess.query(User).all() - (address1, address2, address3, address4, address5) = sess.query(Address).all() - - expected = [(user7, address1), - (user8, address2), - (user8, address3), - (user8, address4), - (user9, address5), - (user10, None)] - - q = sess.query(User) - adalias = addresses.alias('adalias') - q = q.add_entity(Address, alias=adalias).select_from(users.outerjoin(adalias)) - l = q.order_by(User.id, adalias.c.id).all() - assert l == expected - - sess.expunge_all() - - q = sess.query(User).add_entity(Address, alias=adalias) - l = q.select_from(users.outerjoin(adalias)).filter(adalias.c.email_address=='ed@bettyboop.com').all() - assert l == [(user8, address3)] - - def test_multi_columns(self): - sess = create_session() - - expected = [(u, u.name) for u in sess.query(User).all()] - - for add_col in (User.name, users.c.name): - assert sess.query(User).add_column(add_col).all() == expected - sess.expunge_all() - - self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object()) - - def test_add_multi_columns(self): - """test that add_column accepts a FROM clause.""" - - sess = create_session() - - eq_( - sess.query(User.id).add_column(users).all(), - [(7, 7, u'jack'), (8, 8, u'ed'), (9, 9, u'fred'), (10, 10, u'chuck')] - ) - - def test_multi_columns_2(self): - """test aliased/nonalised joins with the usage of add_column()""" - sess = create_session() - - (user7, user8, user9, 
user10) = sess.query(User).all() - expected = [(user7, 1), - (user8, 3), - (user9, 1), - (user10, 0) - ] - - q = sess.query(User) - q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count')) - self.assertEquals(q.all(), expected) - sess.expunge_all() - - adalias = aliased(Address) - q = sess.query(User) - q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count')) - self.assertEquals(q.all(), expected) - sess.expunge_all() - - s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id) - q = sess.query(User) - l = q.add_column("count").from_statement(s).all() - assert l == expected - - - def test_raw_columns(self): - sess = create_session() - (user7, user8, user9, user10) = sess.query(User).all() - expected = [ - (user7, 1, "Name:jack"), - (user8, 3, "Name:ed"), - (user9, 1, "Name:fred"), - (user10, 0, "Name:chuck")] - - adalias = addresses.alias() - q = create_session().query(User).add_column(func.count(adalias.c.id))\ - .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\ - .group_by([c for c in users.c]).order_by(users.c.id) - - assert q.all() == expected - - # test with a straight statement - s = select([users, func.count(addresses.c.id).label('count'), ("Name:" + users.c.name).label('concat')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.id]) - q = create_session().query(User) - l = q.add_column("count").add_column("concat").from_statement(s).all() - assert l == expected - - sess.expunge_all() - - # test with select_from() - q = create_session().query(User).add_column(func.count(addresses.c.id))\ - .add_column(("Name:" + users.c.name)).select_from(users.outerjoin(addresses))\ - .group_by([c for c in users.c]).order_by(users.c.id) - - assert q.all() == expected 
- sess.expunge_all() - - q = create_session().query(User).add_column(func.count(addresses.c.id))\ - .add_column(("Name:" + users.c.name)).outerjoin('addresses')\ - .group_by([c for c in users.c]).order_by(users.c.id) - - assert q.all() == expected - sess.expunge_all() - - q = create_session().query(User).add_column(func.count(adalias.c.id))\ - .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\ - .group_by([c for c in users.c]).order_by(users.c.id) - - assert q.all() == expected - sess.expunge_all() - -class ImmediateTest(_fixtures.FixtureTest): - run_inserts = 'once' - run_deletes = None - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Address, addresses) - - mapper(User, users, properties=dict( - addresses=relation(Address))) - - @testing.resolve_artifact_names - def test_one(self): - sess = create_session() - - self.assertRaises(sa.orm.exc.NoResultFound, - sess.query(User).filter(User.id == 99).one) - - eq_(sess.query(User).filter(User.id == 7).one().id, 7) - - self.assertRaises(sa.orm.exc.MultipleResultsFound, - sess.query(User).one) - - self.assertRaises( - sa.orm.exc.NoResultFound, - sess.query(User.id, User.name).filter(User.id == 99).one) - - eq_(sess.query(User.id, User.name).filter(User.id == 7).one(), - (7, 'jack')) - - self.assertRaises(sa.orm.exc.MultipleResultsFound, - sess.query(User.id, User.name).one) - - self.assertRaises(sa.orm.exc.NoResultFound, - (sess.query(User, Address). - join(User.addresses). - filter(Address.id == 99)).one) - - eq_((sess.query(User, Address). - join(User.addresses). 
- filter(Address.id == 4)).one(), - (User(id=8), Address(id=4))) - - self.assertRaises(sa.orm.exc.MultipleResultsFound, - sess.query(User, Address).join(User.addresses).one) - - @testing.future - def test_getslice(self): - assert False - - @testing.resolve_artifact_names - def test_scalar(self): - sess = create_session() - - eq_(sess.query(User.id).filter_by(id=7).scalar(), 7) - eq_(sess.query(User.id, User.name).filter_by(id=7).scalar(), 7) - eq_(sess.query(User.id).filter_by(id=0).scalar(), None) - eq_(sess.query(User).filter_by(id=7).scalar(), - sess.query(User).filter_by(id=7).one()) - - @testing.resolve_artifact_names - def test_value(self): - sess = create_session() - - eq_(sess.query(User).filter_by(id=7).value(User.id), 7) - eq_(sess.query(User.id, User.name).filter_by(id=7).value(User.id), 7) - eq_(sess.query(User).filter_by(id=0).value(User.id), None) - - sess.bind = sa.testing.db - eq_(sess.query().value(sa.literal_column('1').label('x')), 1) - - -class SelectFromTest(QueryTest): - run_setup_mappers = None - - def test_replace_with_select(self): - mapper(User, users, properties = { - 'addresses':relation(Address) - }) - mapper(Address, addresses) - - sel = users.select(users.c.id.in_([7, 8])).alias() - sess = create_session() - - self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)]) - - self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)]) - - self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [ - User(name='jack',id=7), User(name='ed',id=8) - ]) - - self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [ - User(name='ed',id=8), User(name='jack',id=7) - ]) - - self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), - User(name='jack', addresses=[Address(id=1)]) - ) - - def test_join_mapper_order_by(self): - """test that mapper-level order_by is adapted to a selectable.""" - - 
mapper(User, users, order_by=users.c.id) - - sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() - - self.assertEquals(sess.query(User).select_from(sel).all(), - [ - User(name='jack',id=7), User(name='ed',id=8) - ] - ) - - def test_join_no_order_by(self): - mapper(User, users) - - sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() - - self.assertEquals(sess.query(User).select_from(sel).all(), - [ - User(name='jack',id=7), User(name='ed',id=8) - ] - ) - - def test_join(self): - mapper(User, users, properties = { - 'addresses':relation(Address) - }) - mapper(Address, addresses) - - sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() - - self.assertEquals(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(), - [ - (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), - (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), - (User(name='ed',id=8), Address(user_id=8,email_address='ed@bettyboop.com',id=3)), - (User(name='ed',id=8), Address(user_id=8,email_address='ed@lala.com',id=4)) - ] - ) - - adalias = aliased(Address) - self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(), - [ - (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), - (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), - (User(name='ed',id=8), Address(user_id=8,email_address='ed@bettyboop.com',id=3)), - (User(name='ed',id=8), Address(user_id=8,email_address='ed@lala.com',id=4)) - ] - ) - - - def test_more_joins(self): - mapper(User, users, properties={ - 'orders':relation(Order, backref='user'), # o2m, m2o - }) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m - }) - mapper(Item, items, properties={ - 'keywords':relation(Keyword, 
secondary=item_keywords, order_by=keywords.c.id) #m2m - }) - mapper(Keyword, keywords) - - sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() - - # TODO: remove - sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all() - - self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ - User(name=u'jack',id=7) - ]) - - self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ - User(name=u'jack',id=7) - ]) - - def go(): - self.assertEquals(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ - User(name=u'jack',orders=[ - Order(description=u'order 1',items=[ - Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]), - Item(description=u'item 2',keywords=[Keyword(name=u'red',id=2), Keyword(name=u'small',id=5), Keyword(name=u'square')]), - Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]) - ]), - Order(description=u'order 3',items=[ - Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]), - Item(description=u'item 4',keywords=[],id=4), - Item(description=u'item 5',keywords=[],id=5) - ]), - Order(description=u'order 5',items=[Item(description=u'item 5',keywords=[])])]) - ]) - self.assert_sql_count(testing.db, go, 1) - - sess.expunge_all() - sel2 = orders.select(orders.c.id.in_([1,2,3])) - self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 
'red').order_by(Order.id).all(), [ - Order(description=u'order 1',id=1), - Order(description=u'order 2',id=2), - ]) - self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [ - Order(description=u'order 1',id=1), - Order(description=u'order 2',id=2), - ]) - - - def test_replace_with_eager(self): - mapper(User, users, properties = { - 'addresses':relation(Address, order_by=addresses.c.id) - }) - mapper(Address, addresses) - - sel = users.select(users.c.id.in_([7, 8])) - sess = create_session() - - def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(), - [ - User(id=7, addresses=[Address(id=1)]), - User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]) - ] - ) - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(), - [User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])] - ) - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - - def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) - self.assert_sql_count(testing.db, go, 1) - -class CustomJoinTest(QueryTest): - run_setup_mappers = None - - def test_double_same_mappers(self): - """test aliasing of joins with a custom join condition""" - mapper(Address, addresses) - mapper(Order, orders, properties={ - 'items':relation(Item, secondary=order_items, lazy=True, order_by=items.c.id), - }) - mapper(Item, items) - mapper(User, users, properties = dict( - addresses = relation(Address, lazy=True), - open_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 1, users.c.id==orders.c.user_id), lazy=True), - closed_orders = relation(Order, 
primaryjoin = and_(orders.c.isopen == 0, users.c.id==orders.c.user_id), lazy=True) - )) - q = create_session().query(User) - - assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all() - -class SelfReferentialTest(_base.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - global nodes - nodes = Table('nodes', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('nodes.id')), - Column('data', String(30))) - - def insert_data(self): - global Node - - class Node(Base): - def append(self, node): - self.children.append(node) - - mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=True, join_depth=3, - backref=backref('parent', remote_side=[nodes.c.id]) - ) - }) - sess = create_session() - n1 = Node(data='n1') - n1.append(Node(data='n11')) - n1.append(Node(data='n12')) - n1.append(Node(data='n13')) - n1.children[1].append(Node(data='n121')) - n1.children[1].append(Node(data='n122')) - n1.children[1].append(Node(data='n123')) - sess.add(n1) - sess.flush() - sess.close() - - def test_join(self): - sess = create_session() - - node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first() - assert node.data=='n12' - - ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all() - assert ret == [('n12',)] - - - node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first() - assert node.data=='n1' - - node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\ - join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first() - assert node.data == 'n122' - - def test_explicit_join(self): - sess = create_session() - - n1 = aliased(Node) - n2 = aliased(Node) - - node = 
sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data=='n122').first() - assert node.data=='n12' - - node = sess.query(Node).select_from(join(Node, n1, 'children').join(n2, 'children')).\ - filter(n2.data=='n122').first() - assert node.data=='n1' - - # mix explicit and named onclauses - node = sess.query(Node).select_from(join(Node, n1, Node.id==n1.parent_id).join(n2, 'children')).\ - filter(n2.data=='n122').first() - assert node.data=='n1' - - node = sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ - filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first() - assert node.data == 'n122' - - self.assertEquals( - list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ - filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)), - [('n122', 'n12', 'n1')]) - - def test_join_to_nonaliased(self): - sess = create_session() - - n1 = aliased(Node) - - # using 'n1.parent' implicitly joins to unaliased Node - self.assertEquals( - sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(), - [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] - ) - - # explicit (new syntax) - self.assertEquals( - sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(), - [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] - ) - - def test_multiple_explicit_entities(self): - sess = create_session() - - parent = aliased(Node) - grandparent = aliased(Node) - self.assertEquals( - sess.query(Node, parent, grandparent).\ - join((Node.parent, parent), (parent.parent, grandparent)).\ - filter(Node.data=='n122').filter(parent.data=='n12').\ - filter(grandparent.data=='n1').first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) - ) - - self.assertEquals( - sess.query(Node, parent, grandparent).\ - join((Node.parent, parent), (parent.parent, 
grandparent)).\ - filter(Node.data=='n122').filter(parent.data=='n12').\ - filter(grandparent.data=='n1')._from_self().first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) - ) - - # same, change order around - self.assertEquals( - sess.query(parent, grandparent, Node).\ - join((Node.parent, parent), (parent.parent, grandparent)).\ - filter(Node.data=='n122').filter(parent.data=='n12').\ - filter(grandparent.data=='n1')._from_self().first(), - (Node(data='n12'), Node(data='n1'), Node(data='n122')) - ) - - self.assertEquals( - sess.query(Node, parent, grandparent).\ - join((Node.parent, parent), (parent.parent, grandparent)).\ - filter(Node.data=='n122').filter(parent.data=='n12').\ - filter(grandparent.data=='n1').\ - options(eagerload(Node.children)).first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) - ) - - self.assertEquals( - sess.query(Node, parent, grandparent).\ - join((Node.parent, parent), (parent.parent, grandparent)).\ - filter(Node.data=='n122').filter(parent.data=='n12').\ - filter(grandparent.data=='n1')._from_self().\ - options(eagerload(Node.children)).first(), - (Node(data='n122'), Node(data='n12'), Node(data='n1')) - ) - - - def test_any(self): - sess = create_session() - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) - self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) - - def test_has(self): - sess = create_session() - - self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) - self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) - self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) - - def 
test_contains(self): - sess = create_session() - - n122 = sess.query(Node).filter(Node.data=='n122').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) - - n13 = sess.query(Node).filter(Node.data=='n13').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')]) - - def test_eq_ne(self): - sess = create_session() - - n12 = sess.query(Node).filter(Node.data=='n12').one() - self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) - - self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) - -class SelfReferentialM2MTest(_base.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - global nodes, node_to_nodes - nodes = Table('nodes', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('data', String(30))) - - node_to_nodes =Table('node_to_nodes', metadata, - Column('left_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), - Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), - ) - - def insert_data(self): - global Node - - class Node(Base): - pass - - mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=True, secondary=node_to_nodes, - primaryjoin=nodes.c.id==node_to_nodes.c.left_node_id, - secondaryjoin=nodes.c.id==node_to_nodes.c.right_node_id, - ) - }) - sess = create_session() - n1 = Node(data='n1') - n2 = Node(data='n2') - n3 = Node(data='n3') - n4 = Node(data='n4') - n5 = Node(data='n5') - n6 = Node(data='n6') - n7 = Node(data='n7') - - n1.children = [n2, n3, n4] - n2.children = [n3, n6, n7] - n3.children = [n5, n4] - - sess.add(n1) - sess.add(n2) - sess.add(n3) - sess.add(n4) - sess.flush() - sess.close() - - def test_any(self): - sess = create_session() - 
self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) - - def test_contains(self): - sess = create_session() - n4 = sess.query(Node).filter_by(data='n4').one() - - self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) - self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) - - def test_explicit_join(self): - sess = create_session() - - n1 = aliased(Node) - self.assertEquals( - sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data.in_(['n3', 'n7'])).order_by(Node.id).all(), - [Node(data='n1'), Node(data='n2')] - ) - -class ExternalColumnsTest(QueryTest): - """test mappers with SQL-expressions added as column properties.""" - - run_setup_mappers = None - - def test_external_columns_bad(self): - - self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={ - 'concat': (users.c.id * 2), - }) - clear_mappers() - - def test_external_columns(self): - """test querying mappings that reference external columns or selectables.""" - - mapper(User, users, properties={ - 'concat': column_property((users.c.id * 2)), - 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).as_scalar()) - }) - - mapper(Address, addresses, properties={ - 'user':relation(User) - }) - - sess = create_session() - - sess.query(Address).options(eagerload('user')).all() - - self.assertEquals(sess.query(User).all(), - [ - User(id=7, concat=14, count=1), - User(id=8, concat=16, count=3), - User(id=9, concat=18, count=1), - User(id=10, concat=20, count=0), - ] - ) - - address_result = [ - Address(id=1, user=User(id=7, concat=14, count=1)), - Address(id=2, user=User(id=8, concat=16, count=3)), - 
Address(id=3, user=User(id=8, concat=16, count=3)), - Address(id=4, user=User(id=8, concat=16, count=3)), - Address(id=5, user=User(id=9, concat=18, count=1)) - ] - self.assertEquals(sess.query(Address).all(), address_result) - - # run the eager version twice to test caching of aliased clauses - for x in range(2): - sess.expunge_all() - def go(): - self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result) - self.assert_sql_count(testing.db, go, 1) - - ualias = aliased(User) - self.assertEquals( - sess.query(Address, ualias).join(('user', ualias)).all(), - [(address, address.user) for address in address_result] - ) - - self.assertEquals( - sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), - [ - (Address(id=1), 1), - (Address(id=2), 3), - (Address(id=3), 3), - (Address(id=4), 3), - (Address(id=5), 1) - ] - ) - - self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), - [ - (Address(id=1), 14, 1), - (Address(id=2), 16, 3), - (Address(id=3), 16, 3), - (Address(id=4), 16, 3), - (Address(id=5), 18, 1) - ] - ) - - ua = aliased(User) - self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), - [ - (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1), - (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3), - (Address(id=3, user=User(id=8, concat=16, count=3)), 16, 3), - (Address(id=4, user=User(id=8, concat=16, count=3)), 16, 3), - (Address(id=5, user=User(id=9, concat=18, count=1)), 18, 1) - ] - ) - - self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), - [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] - ) - - self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, 
ua.id, ua.concat, ua.count)), - [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] - ) - - def test_external_columns_eagerload(self): - # in this test, we have a subquery on User that accesses "addresses", underneath - # an eagerload for "addresses". So the "addresses" alias adapter needs to *not* hit - # the "addresses" table within the "user" subquery, but "user" still needs to be adapted. - # therefore the long standing practice of eager adapters being "chained" has been removed - # since its unnecessary and breaks this exact condition. - mapper(User, users, properties={ - 'addresses':relation(Address, backref='user', order_by=addresses.c.id), - 'concat': column_property((users.c.id * 2)), - 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users)) - }) - mapper(Address, addresses) - mapper(Order, orders, properties={ - 'address':relation(Address), # m2o - }) - - sess = create_session() - def go(): - o1 = sess.query(Order).options(eagerload_all('address.user')).get(1) - self.assertEquals(o1.address.user.count, 1) - self.assert_sql_count(testing.db, go, 1) - - sess = create_session() - def go(): - o1 = sess.query(Order).options(eagerload_all('address.user')).first() - self.assertEquals(o1.address.user.count, 1) - self.assert_sql_count(testing.db, go, 1) - -class TestOverlyEagerEquivalentCols(_base.MappedTest): - def define_tables(self, metadata): - global base, sub1, sub2 - base = Table('base', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) - - sub1 = Table('sub1', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('data', String(50)) - ) - - sub2 = Table('sub2', metadata, - Column('id', Integer, ForeignKey('base.id'), ForeignKey('sub1.id'), primary_key=True), - Column('data', String(50)) - ) - - def test_equivs(self): - class Base(_base.ComparableEntity): - pass - class Sub1(_base.ComparableEntity): - pass - 
class Sub2(_base.ComparableEntity): - pass - - mapper(Base, base, properties={ - 'sub1':relation(Sub1), - 'sub2':relation(Sub2) - }) - - mapper(Sub1, sub1) - mapper(Sub2, sub2) - sess = create_session() - - s11 = Sub1(data='s11') - s12 = Sub1(data='s12') - s2 = Sub2(data='s2') - b1 = Base(data='b1', sub1=[s11], sub2=[]) - b2 = Base(data='b1', sub1=[s12], sub2=[]) - sess.add(b1) - sess.add(b2) - sess.flush() - - # theres an overlapping ForeignKey here, so not much option except - # to artifically control the flush order - b2.sub2 = [s2] - sess.flush() - - q = sess.query(Base).outerjoin('sub2', aliased=True) - assert sub1.c.id not in q._filter_aliases.equivalents - - self.assertEquals( - sess.query(Base).join('sub1').outerjoin('sub2', aliased=True).\ - filter(Sub1.id==1).one(), - b1 - ) - -class UpdateDeleteTest(_base.MappedTest): - def define_tables(self, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(32)), - Column('age', Integer)) - - Table('documents', metadata, - Column('id', Integer, primary_key=True), - Column('user_id', None, ForeignKey('users.id')), - Column('title', String(32))) - - def setup_classes(self): - class User(_base.ComparableEntity): - pass - - class Document(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def insert_data(self): - users.insert().execute([ - dict(id=1, name='john', age=25), - dict(id=2, name='jack', age=47), - dict(id=3, name='jill', age=29), - dict(id=4, name='jane', age=37), - ]) - - @testing.resolve_artifact_names - def insert_documents(self): - documents.insert().execute([ - dict(id=1, user_id=1, title='foo'), - dict(id=2, user_id=1, title='bar'), - dict(id=3, user_id=2, title='baz'), - ]) - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users) - mapper(Document, documents, properties={ - 'user': relation(User, lazy=False, backref=backref('documents', lazy=True)) - }) - - @testing.resolve_artifact_names - def 
test_delete(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete() - - assert john not in sess and jill not in sess - - eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) - - @testing.resolve_artifact_names - def test_delete_with_bindparams(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter('name = :name').params(name='john').delete() - assert john not in sess - - eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) - - @testing.resolve_artifact_names - def test_delete_rollback(self): - sess = sessionmaker()() - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='evaluate') - assert john not in sess and jill not in sess - sess.rollback() - assert john in sess and jill in sess - - @testing.resolve_artifact_names - def test_delete_rollback_with_fetch(self): - sess = sessionmaker()() - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch') - assert john not in sess and jill not in sess - sess.rollback() - assert john in sess and jill in sess - - @testing.resolve_artifact_names - def test_delete_without_session_sync(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session=False) - - assert john in sess and jill in sess - - eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) - - @testing.resolve_artifact_names - def test_delete_with_fetch_strategy(self): - sess = 
create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch') - - assert john not in sess and jill not in sess - - eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) - - @testing.fails_on('mysql', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_delete_fallback(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='evaluate') - - assert john not in sess - - eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) - - @testing.resolve_artifact_names - def test_update(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='evaluate') - - eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) - eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) - - sess.query(User).filter(User.age > 29).update({User.age: User.age - 10}, synchronize_session='evaluate') - eq_([john.age, jack.age, jill.age, jane.age], [25,27,29,27]) - eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,29,27])) - - sess.query(User).filter(User.age > 27).update({users.c.age: User.age - 10}, synchronize_session='evaluate') - eq_([john.age, jack.age, jill.age, jane.age], [25,27,19,27]) - eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,19,27])) - - - @testing.resolve_artifact_names - def test_update_with_bindparams(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - - sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, 
synchronize_session='evaluate') - - eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) - eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) - - @testing.resolve_artifact_names - def test_update_changes_resets_dirty(self): - sess = create_session(bind=testing.db, autocommit=False, autoflush=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - - john.age = 50 - jack.age = 37 - - # autoflush is false. therefore our '50' and '37' are getting blown away by this operation. - - sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='evaluate') - - for x in (john, jack, jill, jane): - assert not sess.is_modified(x) - - eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) - - john.age = 25 - assert john in sess.dirty - assert jack in sess.dirty - assert jill not in sess.dirty - assert not sess.is_modified(john) - assert not sess.is_modified(jack) - - @testing.resolve_artifact_names - def test_update_changes_with_autoflush(self): - sess = create_session(bind=testing.db, autocommit=False, autoflush=True) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - - john.age = 50 - jack.age = 37 - - sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='evaluate') - - for x in (john, jack, jill, jane): - assert not sess.is_modified(x) - - eq_([john.age, jack.age, jill.age, jane.age], [40, 27, 29, 27]) - - john.age = 25 - assert john in sess.dirty - assert jack not in sess.dirty - assert jill not in sess.dirty - assert sess.is_modified(john) - assert not sess.is_modified(jack) - - - - @testing.resolve_artifact_names - def test_update_with_expire_strategy(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') - - eq_([john.age, jack.age, jill.age, jane.age], 
[25,37,29,27]) - eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) - - @testing.resolve_artifact_names - def test_update_returns_rowcount(self): - sess = create_session(bind=testing.db, autocommit=False) - - rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0}) - self.assertEquals(rowcount, 2) - - rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10}) - self.assertEquals(rowcount, 2) - - @testing.resolve_artifact_names - def test_delete_returns_rowcount(self): - sess = create_session(bind=testing.db, autocommit=False) - - rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False) - self.assertEquals(rowcount, 3) - - @testing.resolve_artifact_names - def test_update_with_eager_relations(self): - self.insert_documents() - - sess = create_session(bind=testing.db, autocommit=False) - - foo,bar,baz = sess.query(Document).order_by(Document.id).all() - sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='expire') - - eq_([foo.title, bar.title, baz.title], ['foofoo','barbar', 'baz']) - eq_(sess.query(Document.title).order_by(Document.id).all(), zip(['foofoo','barbar', 'baz'])) - - @testing.resolve_artifact_names - def test_update_with_explicit_eagerload(self): - sess = create_session(bind=testing.db, autocommit=False) - - john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') - - eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) - eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) - - @testing.resolve_artifact_names - def test_delete_with_eager_relations(self): - self.insert_documents() - - sess = create_session(bind=testing.db, autocommit=False) - - sess.query(Document).filter(Document.user_id == 1).delete(synchronize_session=False) - - 
eq_(sess.query(Document.title).all(), zip(['baz'])) - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/relationships.py b/test/orm/relationships.py deleted file mode 100644 index a0a8900b2..000000000 --- a/test/orm/relationships.py +++ /dev/null @@ -1,1876 +0,0 @@ -import testenv; testenv.configure_for_tests() -import datetime -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_ -from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers -from testlib.testing import eq_, startswith_ -from orm import _base, _fixtures - - -class RelationTest(_base.MappedTest): - """An extended topological sort test - - This is essentially an extension of the "dependency.py" topological sort - test. In this test, a table is dependent on two other tables that are - otherwise unrelated to each other. The dependency sort must ensure that - this childmost table is below both parent tables in the outcome (a bug - existed where this was not always the case). - - While the straight topological sort tests should expose this, since the - sorting can be different due to subtle differences in program execution, - this test case was exposing the bug whereas the simpler tests were not. 
- - """ - - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - Table("tbl_a", metadata, - Column("id", Integer, primary_key=True), - Column("name", String(128))) - Table("tbl_b", metadata, - Column("id", Integer, primary_key=True), - Column("name", String(128))) - Table("tbl_c", metadata, - Column("id", Integer, primary_key=True), - Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False), - Column("name", String(128))) - Table("tbl_d", metadata, - Column("id", Integer, primary_key=True), - Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False), - Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), - Column("name", String(128))) - - def setup_classes(self): - class A(_base.Entity): - pass - class B(_base.Entity): - pass - class C(_base.Entity): - pass - class D(_base.Entity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(A, tbl_a, properties=dict( - c_rows=relation(C, cascade="all, delete-orphan", backref="a_row"))) - mapper(B, tbl_b) - mapper(C, tbl_c, properties=dict( - d_rows=relation(D, cascade="all, delete-orphan", backref="c_row"))) - mapper(D, tbl_d, properties=dict( - b_row=relation(B))) - - @testing.resolve_artifact_names - def insert_data(self): - session = create_session() - a = A(name='a1') - b = B(name='b1') - c = C(name='c1', a_row=a) - - d1 = D(name='d1', b_row=b, c_row=c) - d2 = D(name='d2', b_row=b, c_row=c) - d3 = D(name='d3', b_row=b, c_row=c) - session.add(a) - session.add(b) - session.flush() - - @testing.resolve_artifact_names - def testDeleteRootTable(self): - session = create_session() - a = session.query(A).filter_by(name='a1').one() - - session.delete(a) - session.flush() - - @testing.resolve_artifact_names - def testDeleteMiddleTable(self): - session = create_session() - c = session.query(C).filter_by(name='c1').one() - - session.delete(c) - session.flush() - - -class RelationTest2(_base.MappedTest): - """Tests a 
relationship on a column included in multiple foreign keys. - - This test tests a relationship on a column that is included in multiple - foreign keys, as well as a self-referential relationship on a composite - key where one column in the foreign key is 'joined to itself'. - - """ - def define_tables(self, metadata): - Table('company_t', metadata, - Column('company_id', Integer, primary_key=True), - Column('name', sa.Unicode(30))) - - Table('employee_t', metadata, - Column('company_id', Integer, primary_key=True), - Column('emp_id', Integer, primary_key=True), - Column('name', sa.Unicode(30)), - Column('reports_to_id', Integer), - sa.ForeignKeyConstraint( - ['company_id'], - ['company_t.company_id']), - sa.ForeignKeyConstraint( - ['company_id', 'reports_to_id'], - ['employee_t.company_id', 'employee_t.emp_id'])) - - @testing.resolve_artifact_names - def test_explicit(self): - """test with mappers that have fairly explicit join conditions""" - - class Company(_base.Entity): - pass - - class Employee(_base.Entity): - def __init__(self, name, company, emp_id, reports_to=None): - self.name = name - self.company = company - self.emp_id = emp_id - self.reports_to = reports_to - - mapper(Company, company_t) - mapper(Employee, employee_t, properties= { - 'company':relation(Company, primaryjoin=employee_t.c.company_id==company_t.c.company_id, backref='employees'), - 'reports_to':relation(Employee, primaryjoin= - sa.and_( - employee_t.c.emp_id==employee_t.c.reports_to_id, - employee_t.c.company_id==employee_t.c.company_id - ), - remote_side=[employee_t.c.emp_id, employee_t.c.company_id], - foreign_keys=[employee_t.c.reports_to_id], - backref='employees') - }) - - sess = create_session() - c1 = Company() - c2 = Company() - - e1 = Employee(u'emp1', c1, 1) - e2 = Employee(u'emp2', c1, 2, e1) - e3 = Employee(u'emp3', c1, 3, e1) - e4 = Employee(u'emp4', c1, 4, e3) - e5 = Employee(u'emp5', c2, 1) - e6 = Employee(u'emp6', c2, 2, e5) - e7 = Employee(u'emp7', c2, 3, e5) - - 
sess.add_all((c1, c2)) - sess.flush() - sess.expunge_all() - - test_c1 = sess.query(Company).get(c1.company_id) - test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) - assert test_e1.name == 'emp1', test_e1.name - test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) - assert test_e5.name == 'emp5', test_e5.name - assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] - eq_(sess.query(Employee).get([c1.company_id, 3]).reports_to.name, 'emp1') - eq_(sess.query(Employee).get([c2.company_id, 3]).reports_to.name, 'emp5') - - @testing.resolve_artifact_names - def test_implicit(self): - """test with mappers that have the most minimal arguments""" - class Company(_base.Entity): - pass - class Employee(_base.Entity): - def __init__(self, name, company, emp_id, reports_to=None): - self.name = name - self.company = company - self.emp_id = emp_id - self.reports_to = reports_to - - mapper(Company, company_t) - mapper(Employee, employee_t, properties= { - 'company':relation(Company, backref='employees'), - 'reports_to':relation(Employee, - remote_side=[employee_t.c.emp_id, employee_t.c.company_id], - foreign_keys=[employee_t.c.reports_to_id], - backref='employees') - }) - - sess = create_session() - c1 = Company() - c2 = Company() - - e1 = Employee(u'emp1', c1, 1) - e2 = Employee(u'emp2', c1, 2, e1) - e3 = Employee(u'emp3', c1, 3, e1) - e4 = Employee(u'emp4', c1, 4, e3) - e5 = Employee(u'emp5', c2, 1) - e6 = Employee(u'emp6', c2, 2, e5) - e7 = Employee(u'emp7', c2, 3, e5) - - sess.add_all((c1, c2)) - sess.flush() - sess.expunge_all() - - test_c1 = sess.query(Company).get(c1.company_id) - test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) - assert test_e1.name == 'emp1', test_e1.name - test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) - assert test_e5.name == 'emp5', test_e5.name - assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] - assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' - 
assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' - -class RelationTest3(_base.MappedTest): - def define_tables(self, metadata): - Table("jobs", metadata, - Column("jobno", sa.Unicode(15), primary_key=True), - Column("created", sa.DateTime, nullable=False, - default=datetime.datetime.now), - Column("deleted", sa.Boolean, nullable=False, default=False)) - - Table("pageversions", metadata, - Column("jobno", sa.Unicode(15), primary_key=True), - Column("pagename", sa.Unicode(30), primary_key=True), - Column("version", Integer, primary_key=True, default=1), - Column("created", sa.DateTime, nullable=False, - default=datetime.datetime.now), - Column("md5sum", String(32)), - Column("width", Integer, nullable=False, default=0), - Column("height", Integer, nullable=False, default=0), - sa.ForeignKeyConstraint( - ["jobno", "pagename"], - ["pages.jobno", "pages.pagename"])) - - Table("pages", metadata, - Column("jobno", sa.Unicode(15), ForeignKey("jobs.jobno"), - primary_key=True), - Column("pagename", sa.Unicode(30), primary_key=True), - Column("created", sa.DateTime, nullable=False, - default=datetime.datetime.now), - Column("deleted", sa.Boolean, nullable=False, default=False), - Column("current_version", Integer)) - - Table("pagecomments", metadata, - Column("jobno", sa.Unicode(15), primary_key=True), - Column("pagename", sa.Unicode(30), primary_key=True), - Column("comment_id", Integer, primary_key=True, - autoincrement=False), - Column("content", sa.UnicodeText), - sa.ForeignKeyConstraint( - ["jobno", "pagename"], - ["pages.jobno", "pages.pagename"])) - - @testing.resolve_artifact_names - def setup_mappers(self): - class Job(_base.Entity): - def create_page(self, pagename): - return Page(job=self, pagename=pagename) - class PageVersion(_base.Entity): - def __init__(self, page=None, version=None): - self.page = page - self.version = version - class Page(_base.Entity): - def __init__(self, job=None, pagename=None): - self.job = job - 
self.pagename = pagename - self.currentversion = PageVersion(self, 1) - def add_version(self): - self.currentversion = PageVersion( - page=self, version=self.currentversion.version+1) - comment = self.add_comment() - comment.closeable = False - comment.content = u'some content' - return self.currentversion - def add_comment(self): - nextnum = max([-1] + [c.comment_id for c in self.comments]) + 1 - newcomment = PageComment() - newcomment.comment_id = nextnum - self.comments.append(newcomment) - newcomment.created_version = self.currentversion.version - return newcomment - class PageComment(_base.Entity): - pass - - mapper(Job, jobs) - mapper(PageVersion, pageversions) - mapper(Page, pages, properties={ - 'job': relation( - Job, - backref=backref('pages', - cascade="all, delete-orphan", - order_by=pages.c.pagename)), - 'currentversion': relation( - PageVersion, - foreign_keys=[pages.c.current_version], - primaryjoin=sa.and_( - pages.c.jobno==pageversions.c.jobno, - pages.c.pagename==pageversions.c.pagename, - pages.c.current_version==pageversions.c.version), - post_update=True), - 'versions': relation( - PageVersion, - cascade="all, delete-orphan", - primaryjoin=sa.and_(pages.c.jobno==pageversions.c.jobno, - pages.c.pagename==pageversions.c.pagename), - order_by=pageversions.c.version, - backref=backref('page',lazy=False) - )}) - mapper(PageComment, pagecomments, properties={ - 'page': relation( - Page, - primaryjoin=sa.and_(pages.c.jobno==pagecomments.c.jobno, - pages.c.pagename==pagecomments.c.pagename), - backref=backref("comments", - cascade="all, delete-orphan", - order_by=pagecomments.c.comment_id))}) - - @testing.resolve_artifact_names - def testbasic(self): - """A combination of complicated join conditions with post_update.""" - - j1 = Job(jobno=u'somejob') - j1.create_page(u'page1') - j1.create_page(u'page2') - j1.create_page(u'page3') - - j2 = Job(jobno=u'somejob2') - j2.create_page(u'page1') - j2.create_page(u'page2') - j2.create_page(u'page3') - - 
j2.pages[0].add_version() - j2.pages[0].add_version() - j2.pages[1].add_version() - - s = create_session() - s.add_all((j1, j2)) - - s.flush() - - s.expunge_all() - j = s.query(Job).filter_by(jobno=u'somejob').one() - oldp = list(j.pages) - j.pages = [] - - s.flush() - - s.expunge_all() - j = s.query(Job).filter_by(jobno=u'somejob2').one() - j.pages[1].current_version = 12 - s.delete(j) - s.flush() - -class RelationTest4(_base.MappedTest): - """Syncrules on foreign keys that are also primary""" - - def define_tables(self, metadata): - Table("tableA", metadata, - Column("id",Integer,primary_key=True), - Column("foo",Integer,), - test_needs_fk=True) - Table("tableB",metadata, - Column("id",Integer,ForeignKey("tableA.id"),primary_key=True), - test_needs_fk=True) - - def setup_classes(self): - class A(_base.Entity): - pass - - class B(_base.Entity): - pass - - @testing.resolve_artifact_names - def test_no_delete_PK_AtoB(self): - """A cant be deleted without B because B would have no PK value.""" - mapper(A, tableA, properties={ - 'bs':relation(B, cascade="save-update")}) - mapper(B, tableB) - - a1 = A() - a1.bs.append(B()) - sess = create_session() - sess.add(a1) - sess.flush() - - sess.delete(a1) - try: - sess.flush() - assert False - except AssertionError, e: - startswith_(str(e), - "Dependency rule tried to blank-out " - "primary key column 'tableB.id' on instance ") - - @testing.resolve_artifact_names - def test_no_delete_PK_BtoA(self): - mapper(B, tableB, properties={ - 'a':relation(A, cascade="save-update")}) - mapper(A, tableA) - - b1 = B() - a1 = A() - b1.a = a1 - sess = create_session() - sess.add(b1) - sess.flush() - b1.a = None - try: - sess.flush() - assert False - except AssertionError, e: - startswith_(str(e), - "Dependency rule tried to blank-out " - "primary key column 'tableB.id' on instance ") - - @testing.fails_on_everything_except('sqlite', 'mysql') - @testing.resolve_artifact_names - def test_nullPKsOK_BtoA(self): - # postgres cant handle a 
nullable PK column...? - tableC = Table('tablec', tableA.metadata, - Column('id', Integer, primary_key=True), - Column('a_id', Integer, ForeignKey('tableA.id'), - primary_key=True, autoincrement=False, nullable=True)) - tableC.create() - - class C(_base.Entity): - pass - mapper(C, tableC, properties={ - 'a':relation(A, cascade="save-update") - }, allow_null_pks=True) - mapper(A, tableA) - - c1 = C() - c1.id = 5 - c1.a = None - sess = create_session() - sess.add(c1) - # test that no error is raised. - sess.flush() - - @testing.resolve_artifact_names - def test_delete_cascade_BtoA(self): - """No 'blank the PK' error when the child is to be deleted as part of a cascade""" - - for cascade in ("save-update, delete", - #"save-update, delete-orphan", - "save-update, delete, delete-orphan"): - mapper(B, tableB, properties={ - 'a':relation(A, cascade=cascade, single_parent=True) - }) - mapper(A, tableA) - - b1 = B() - a1 = A() - b1.a = a1 - sess = create_session() - sess.add(b1) - sess.flush() - sess.delete(b1) - sess.flush() - assert a1 not in sess - assert b1 not in sess - sess.expunge_all() - sa.orm.clear_mappers() - - @testing.resolve_artifact_names - def test_delete_cascade_AtoB(self): - """No 'blank the PK' error when the child is to be deleted as part of a cascade""" - for cascade in ("save-update, delete", - #"save-update, delete-orphan", - "save-update, delete, delete-orphan"): - mapper(A, tableA, properties={ - 'bs':relation(B, cascade=cascade) - }) - mapper(B, tableB) - - a1 = A() - b1 = B() - a1.bs.append(b1) - sess = create_session() - sess.add(a1) - sess.flush() - - sess.delete(a1) - sess.flush() - assert a1 not in sess - assert b1 not in sess - sess.expunge_all() - sa.orm.clear_mappers() - - @testing.resolve_artifact_names - def test_delete_manual_AtoB(self): - mapper(A, tableA, properties={ - 'bs':relation(B, cascade="none")}) - mapper(B, tableB) - - a1 = A() - b1 = B() - a1.bs.append(b1) - sess = create_session() - sess.add(a1) - sess.add(b1) - sess.flush() 
- - sess.delete(a1) - sess.delete(b1) - sess.flush() - assert a1 not in sess - assert b1 not in sess - sess.expunge_all() - - @testing.resolve_artifact_names - def test_delete_manual_BtoA(self): - mapper(B, tableB, properties={ - 'a':relation(A, cascade="none")}) - mapper(A, tableA) - - b1 = B() - a1 = A() - b1.a = a1 - sess = create_session() - sess.add(b1) - sess.add(a1) - sess.flush() - sess.delete(b1) - sess.delete(a1) - sess.flush() - assert a1 not in sess - assert b1 not in sess - -class RelationTest5(_base.MappedTest): - """Test a map to a select that relates to a map to the table.""" - - def define_tables(self, metadata): - Table('items', metadata, - Column('item_policy_num', String(10), primary_key=True, - key='policyNum'), - Column('item_policy_eff_date', sa.Date, primary_key=True, - key='policyEffDate'), - Column('item_type', String(20), primary_key=True, - key='type'), - Column('item_id', Integer, primary_key=True, - key='id', autoincrement=False)) - - @testing.resolve_artifact_names - def test_basic(self): - class Container(_base.Entity): - pass - class LineItem(_base.Entity): - pass - - container_select = sa.select( - [items.c.policyNum, items.c.policyEffDate, items.c.type], - distinct=True, - ).alias('container_select') - - mapper(LineItem, items) - - mapper(Container, - container_select, - order_by=sa.asc(container_select.c.type), - properties=dict( - lineItems=relation(LineItem, - lazy=True, - cascade='all, delete-orphan', - order_by=sa.asc(items.c.type), - primaryjoin=sa.and_( - container_select.c.policyNum==items.c.policyNum, - container_select.c.policyEffDate==items.c.policyEffDate, - container_select.c.type==items.c.type), - foreign_keys=[ - items.c.policyNum, - items.c.policyEffDate, - items.c.type]))) - - session = create_session() - con = Container() - con.policyNum = "99" - con.policyEffDate = datetime.date.today() - con.type = "TESTER" - session.add(con) - for i in range(0, 10): - li = LineItem() - li.id = i - con.lineItems.append(li) - 
session.add(li) - session.flush() - session.expunge_all() - newcon = session.query(Container).first() - assert con.policyNum == newcon.policyNum - assert len(newcon.lineItems) == 10 - for old, new in zip(con.lineItems, newcon.lineItems): - assert old.id == new.id - -class RelationTest6(_base.MappedTest): - """test a relation with a non-column entity in the primary join, - is not viewonly, and also has the non-column's clause mentioned in the - foreign keys list. - - """ - - def define_tables(self, metadata): - Table('tags', metadata, Column("id", Integer, primary_key=True), - Column("data", String(50)), - ) - - Table('tag_foo', metadata, - Column("id", Integer, primary_key=True), - Column('tagid', Integer), - Column("data", String(50)), - ) - - @testing.resolve_artifact_names - def test_basic(self): - class Tag(_base.ComparableEntity): - pass - class TagInstance(_base.ComparableEntity): - pass - - mapper(Tag, tags, properties={ - 'foo':relation(TagInstance, - primaryjoin=sa.and_(tag_foo.c.data=='iplc_case', - tag_foo.c.tagid==tags.c.id), - foreign_keys=[tag_foo.c.tagid, tag_foo.c.data], - ), - }) - - mapper(TagInstance, tag_foo) - - sess = create_session() - t1 = Tag(data='some tag') - t1.foo.append(TagInstance(data='iplc_case')) - t1.foo.append(TagInstance(data='not_iplc_case')) - sess.add(t1) - sess.flush() - sess.expunge_all() - - # relation works - eq_(sess.query(Tag).all(), [Tag(data='some tag', foo=[TagInstance(data='iplc_case')])]) - - # both TagInstances were persisted - eq_( - sess.query(TagInstance).order_by(TagInstance.data).all(), - [TagInstance(data='iplc_case'), TagInstance(data='not_iplc_case')] - ) - -class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): - """test ambiguous joins due to FKs on both sides treated as self-referential. - - this mapping is very similar to that of test/orm/inheritance/query.py - SelfReferentialTestJoinedToBase , except that inheritance is not used - here. 
- - """ - - def define_tables(self, metadata): - subscriber_table = Table('subscriber', metadata, - Column('id', Integer, primary_key=True), - Column('dummy', String(10)) # to appease older sqlite version - ) - - address_table = Table('address', - metadata, - Column('subscriber_id', Integer, ForeignKey('subscriber.id'), primary_key=True), - Column('type', String(1), primary_key=True), - ) - - @testing.resolve_artifact_names - def setup_mappers(self): - subscriber_and_address = subscriber.join(address, - and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C']))) - - class Address(_base.ComparableEntity): - pass - - class Subscriber(_base.ComparableEntity): - pass - - mapper(Address, address) - - mapper(Subscriber, subscriber_and_address, properties={ - 'id':[subscriber.c.id, address.c.subscriber_id], - 'addresses' : relation(Address, - backref=backref("customer")) - }) - - @testing.resolve_artifact_names - def test_mapping(self): - from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE - sess = create_session() - assert Subscriber.addresses.property.direction is ONETOMANY - assert Address.customer.property.direction is MANYTOONE - - s1 = Subscriber(type='A', - addresses = [ - Address(type='D'), - Address(type='E'), - ] - ) - a1 = Address(type='B', customer=Subscriber(type='C')) - - assert s1.addresses[0].customer is s1 - assert a1.customer.addresses[0] is a1 - - sess.add_all([s1, a1]) - - sess.flush() - sess.expunge_all() - - eq_( - sess.query(Subscriber).order_by(Subscriber.type).all(), - [ - Subscriber(id=1, type=u'A'), - Subscriber(id=2, type=u'B'), - Subscriber(id=2, type=u'C') - ] - ) - - -class ManualBackrefTest(_fixtures.FixtureTest): - """Test explicit relations that are backrefs to each other.""" - - run_inserts = None - - @testing.resolve_artifact_names - def test_o2m(self): - mapper(User, users, properties={ - 'addresses':relation(Address, back_populates='user') - }) - - mapper(Address, addresses, properties={ - 
'user':relation(User, back_populates='addresses') - }) - - sess = create_session() - - u1 = User(name='u1') - a1 = Address(email_address='foo') - u1.addresses.append(a1) - assert a1.user is u1 - - sess.add(u1) - sess.flush() - sess.expire_all() - assert sess.query(Address).one() is a1 - assert a1.user is u1 - assert a1 in u1.addresses - - @testing.resolve_artifact_names - def test_invalid_key(self): - mapper(User, users, properties={ - 'addresses':relation(Address, back_populates='userr') - }) - - mapper(Address, addresses, properties={ - 'user':relation(User, back_populates='addresses') - }) - - self.assertRaises(sa.exc.InvalidRequestError, compile_mappers) - - @testing.resolve_artifact_names - def test_invalid_target(self): - mapper(User, users, properties={ - 'addresses':relation(Address, back_populates='dingaling'), - }) - - mapper(Dingaling, dingalings) - mapper(Address, addresses, properties={ - 'dingaling':relation(Dingaling) - }) - - self.assertRaisesMessage(sa.exc.ArgumentError, - r"reverse_property 'dingaling' on relation User.addresses references " - "relation Address.dingaling, which does not reference mapper Mapper\|User\|users", - compile_mappers) - -class JoinConditionErrorTest(testing.TestBase): - - def test_clauseelement_pj(self): - from sqlalchemy.ext.declarative import declarative_base - Base = declarative_base() - class C1(Base): - __tablename__ = 'c1' - id = Column('id', Integer, primary_key=True) - class C2(Base): - __tablename__ = 'c2' - id = Column('id', Integer, primary_key=True) - c1id = Column('c1id', Integer, ForeignKey('c1.id')) - c2 = relation(C1, primaryjoin=C1.id) - - self.assertRaises(sa.exc.ArgumentError, compile_mappers) - - def test_clauseelement_pj_false(self): - from sqlalchemy.ext.declarative import declarative_base - Base = declarative_base() - class C1(Base): - __tablename__ = 'c1' - id = Column('id', Integer, primary_key=True) - class C2(Base): - __tablename__ = 'c2' - id = Column('id', Integer, primary_key=True) - c1id = 
Column('c1id', Integer, ForeignKey('c1.id')) - c2 = relation(C1, primaryjoin="x"=="y") - - self.assertRaises(sa.exc.ArgumentError, compile_mappers) - - - def test_fk_error_raised(self): - m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - Column('foo_id', Integer, ForeignKey('t2.nonexistent_id')), - ) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - ) - - t3 = Table('t3', m, - Column('id', Integer, primary_key=True), - Column('t1id', Integer, ForeignKey('t1.id')) - ) - - class C1(object): - pass - class C2(object): - pass - - mapper(C1, t1, properties={'c2':relation(C2)}) - mapper(C2, t3) - - self.assertRaises(sa.exc.NoReferencedColumnError, compile_mappers) - - def test_join_error_raised(self): - m = MetaData() - t1 = Table('t1', m, - Column('id', Integer, primary_key=True), - ) - t2 = Table('t2', m, - Column('id', Integer, primary_key=True), - ) - - t3 = Table('t3', m, - Column('id', Integer, primary_key=True), - Column('t1id', Integer) - ) - - class C1(object): - pass - class C2(object): - pass - - mapper(C1, t1, properties={'c2':relation(C2)}) - mapper(C2, t3) - - self.assertRaises(sa.exc.ArgumentError, compile_mappers) - - def tearDown(self): - clear_mappers() - -class TypeMatchTest(_base.MappedTest): - """test errors raised when trying to add items whose type is not handled by a relation""" - - def define_tables(self, metadata): - Table("a", metadata, - Column('aid', Integer, primary_key=True), - Column('data', String(30))) - Table("b", metadata, - Column('bid', Integer, primary_key=True), - Column("a_id", Integer, ForeignKey("a.aid")), - Column('data', String(30))) - Table("c", metadata, - Column('cid', Integer, primary_key=True), - Column("b_id", Integer, ForeignKey("b.bid")), - Column('data', String(30))) - Table("d", metadata, - Column('did', Integer, primary_key=True), - Column("a_id", Integer, ForeignKey("a.aid")), - Column('data', String(30))) - - @testing.resolve_artifact_names - def 
test_o2m_oncascade(self): - class A(_base.Entity): pass - class B(_base.Entity): pass - class C(_base.Entity): pass - mapper(A, a, properties={'bs':relation(B)}) - mapper(B, b) - mapper(C, c) - - a1 = A() - b1 = B() - c1 = C() - a1.bs.append(b1) - a1.bs.append(c1) - sess = create_session() - try: - sess.add(a1) - assert False - except AssertionError, err: - eq_(str(err), - "Attribute 'bs' on class '%s' doesn't handle " - "objects of type '%s'" % (A, C)) - - @testing.resolve_artifact_names - def test_o2m_onflush(self): - class A(_base.Entity): pass - class B(_base.Entity): pass - class C(_base.Entity): pass - mapper(A, a, properties={'bs':relation(B, cascade="none")}) - mapper(B, b) - mapper(C, c) - - a1 = A() - b1 = B() - c1 = C() - a1.bs.append(b1) - a1.bs.append(c1) - sess = create_session() - sess.add(a1) - sess.add(b1) - sess.add(c1) - self.assertRaisesMessage(sa.orm.exc.FlushError, - "Attempting to flush an item", sess.flush) - - @testing.resolve_artifact_names - def test_o2m_nopoly_onflush(self): - class A(_base.Entity): pass - class B(_base.Entity): pass - class C(B): pass - mapper(A, a, properties={'bs':relation(B, cascade="none")}) - mapper(B, b) - mapper(C, c, inherits=B) - - a1 = A() - b1 = B() - c1 = C() - a1.bs.append(b1) - a1.bs.append(c1) - sess = create_session() - sess.add(a1) - sess.add(b1) - sess.add(c1) - self.assertRaisesMessage(sa.orm.exc.FlushError, - "Attempting to flush an item", sess.flush) - - @testing.resolve_artifact_names - def test_m2o_nopoly_onflush(self): - class A(_base.Entity): pass - class B(A): pass - class D(_base.Entity): pass - mapper(A, a) - mapper(B, b, inherits=A) - mapper(D, d, properties={"a":relation(A, cascade="none")}) - b1 = B() - d1 = D() - d1.a = b1 - sess = create_session() - sess.add(b1) - sess.add(d1) - self.assertRaisesMessage(sa.orm.exc.FlushError, - "Attempting to flush an item", sess.flush) - - @testing.resolve_artifact_names - def test_m2o_oncascade(self): - class A(_base.Entity): pass - class 
B(_base.Entity): pass - class D(_base.Entity): pass - mapper(A, a) - mapper(B, b) - mapper(D, d, properties={"a":relation(A)}) - b1 = B() - d1 = D() - d1.a = b1 - sess = create_session() - self.assertRaisesMessage(AssertionError, - "doesn't handle objects of type", sess.add, d1) - -class TypedAssociationTable(_base.MappedTest): - - def define_tables(self, metadata): - class MySpecialType(sa.types.TypeDecorator): - impl = String - def process_bind_param(self, value, dialect): - return "lala" + value - def process_result_value(self, value, dialect): - return value[4:] - - Table('t1', metadata, - Column('col1', MySpecialType(30), primary_key=True), - Column('col2', String(30))) - Table('t2', metadata, - Column('col1', MySpecialType(30), primary_key=True), - Column('col2', String(30))) - Table('t3', metadata, - Column('t1c1', MySpecialType(30), ForeignKey('t1.col1')), - Column('t2c1', MySpecialType(30), ForeignKey('t2.col1'))) - - @testing.resolve_artifact_names - def testm2m(self): - """Many-to-many tables with special types for candidate keys.""" - - class T1(_base.Entity): pass - class T2(_base.Entity): pass - mapper(T2, t2) - mapper(T1, t1, properties={ - 't2s':relation(T2, secondary=t3, backref='t1s')}) - - a = T1() - a.col1 = "aid" - b = T2() - b.col1 = "bid" - c = T2() - c.col1 = "cid" - a.t2s.append(b) - a.t2s.append(c) - sess = create_session() - sess.add(a) - sess.flush() - - assert t3.count().scalar() == 2 - - a.t2s.remove(c) - sess.flush() - - assert t3.count().scalar() == 1 - - -class ViewOnlyOverlappingNames(_base.MappedTest): - """'viewonly' mappings with overlapping PK column names.""" - - def define_tables(self, metadata): - Table("t1", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(40))) - Table("t2", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(40)), - Column('t1id', Integer, ForeignKey('t1.id'))) - Table("t3", metadata, - Column('id', Integer, primary_key=True), - Column('data', 
String(40)), - Column('t2id', Integer, ForeignKey('t2.id'))) - - @testing.resolve_artifact_names - def test_three_table_view(self): - """A three table join with overlapping PK names. - - A third table is pulled into the primary join condition using - overlapping PK column names and should not produce 'conflicting column' - error. - - """ - class C1(_base.Entity): pass - class C2(_base.Entity): pass - class C3(_base.Entity): pass - - mapper(C1, t1, properties={ - 't2s':relation(C2), - 't2_view':relation(C2, - viewonly=True, - primaryjoin=sa.and_(t1.c.id==t2.c.t1id, - t3.c.t2id==t2.c.id, - t3.c.data==t1.c.data))}) - mapper(C2, t2) - mapper(C3, t3, properties={ - 't2':relation(C2)}) - - c1 = C1() - c1.data = 'c1data' - c2a = C2() - c1.t2s.append(c2a) - c2b = C2() - c1.t2s.append(c2b) - c3 = C3() - c3.data='c1data' - c3.t2 = c2b - sess = create_session() - sess.add(c1) - sess.add(c3) - sess.flush() - sess.expunge_all() - - c1 = sess.query(C1).get(c1.id) - assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id]) - assert set([x.id for x in c1.t2_view]) == set([c2b.id]) - -class ViewOnlyUniqueNames(_base.MappedTest): - """'viewonly' mappings with unique PK column names.""" - - def define_tables(self, metadata): - Table("t1", metadata, - Column('t1id', Integer, primary_key=True), - Column('data', String(40))) - Table("t2", metadata, - Column('t2id', Integer, primary_key=True), - Column('data', String(40)), - Column('t1id_ref', Integer, ForeignKey('t1.t1id'))) - Table("t3", metadata, - Column('t3id', Integer, primary_key=True), - Column('data', String(40)), - Column('t2id_ref', Integer, ForeignKey('t2.t2id'))) - - @testing.resolve_artifact_names - def test_three_table_view(self): - """A three table join with overlapping PK names. - - A third table is pulled into the primary join condition using unique - PK column names and should not produce 'mapper has no columnX' error. 
- - """ - class C1(_base.Entity): pass - class C2(_base.Entity): pass - class C3(_base.Entity): pass - - mapper(C1, t1, properties={ - 't2s':relation(C2), - 't2_view':relation(C2, - viewonly=True, - primaryjoin=sa.and_(t1.c.t1id==t2.c.t1id_ref, - t3.c.t2id_ref==t2.c.t2id, - t3.c.data==t1.c.data))}) - mapper(C2, t2) - mapper(C3, t3, properties={ - 't2':relation(C2)}) - - c1 = C1() - c1.data = 'c1data' - c2a = C2() - c1.t2s.append(c2a) - c2b = C2() - c1.t2s.append(c2b) - c3 = C3() - c3.data='c1data' - c3.t2 = c2b - sess = create_session() - - sess.add_all((c1, c3)) - sess.flush() - sess.expunge_all() - - c1 = sess.query(C1).get(c1.t1id) - assert set([x.t2id for x in c1.t2s]) == set([c2a.t2id, c2b.t2id]) - assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id]) - -class ViewOnlyLocalRemoteM2M(testing.TestBase): - """test that local-remote is correctly determined for m2m""" - - def test_local_remote(self): - meta = MetaData() - - t1 = Table('t1', meta, - Column('id', Integer, primary_key=True), - ) - t2 = Table('t2', meta, - Column('id', Integer, primary_key=True), - ) - t12 = Table('tab', meta, - Column('t1_id', Integer, ForeignKey('t1.id',)), - Column('t2_id', Integer, ForeignKey('t2.id',)), - ) - - class A(object): pass - class B(object): pass - mapper( B, t2, ) - m = mapper( A, t1, properties=dict( - b_view = relation( B, secondary=t12, viewonly=True), - b_plain= relation( B, secondary=t12), - ) - ) - compile_mappers() - assert m.get_property('b_view').local_remote_pairs == \ - m.get_property('b_plain').local_remote_pairs == \ - [(t1.c.id, t12.c.t1_id), (t12.c.t2_id, t2.c.id)] - - - -class ViewOnlyNonEquijoin(_base.MappedTest): - """'viewonly' mappings based on non-equijoins.""" - - def define_tables(self, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True)) - Table('bars', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer)) - - @testing.resolve_artifact_names - def test_viewonly_join(self): - class 
Foo(_base.ComparableEntity): - pass - class Bar(_base.ComparableEntity): - pass - - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id > bars.c.fid, - foreign_keys=[bars.c.fid], - viewonly=True)}) - - mapper(Bar, bars) - - sess = create_session() - sess.add_all((Foo(id=4), - Foo(id=9), - Bar(id=1, fid=2), - Bar(id=2, fid=3), - Bar(id=3, fid=6), - Bar(id=4, fid=7))) - sess.flush() - - sess = create_session() - eq_(sess.query(Foo).filter_by(id=4).one(), - Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)])) - eq_(sess.query(Foo).filter_by(id=9).one(), - Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)])) - - -class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): - """'viewonly' mappings that contain the same 'remote' column twice""" - - def define_tables(self, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True), - Column('bid1', Integer,ForeignKey('bars.id')), - Column('bid2', Integer,ForeignKey('bars.id'))) - - Table('bars', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - - @testing.resolve_artifact_names - def test_relation_on_or(self): - class Foo(_base.ComparableEntity): - pass - class Bar(_base.ComparableEntity): - pass - - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=sa.or_(bars.c.id == foos.c.bid1, - bars.c.id == foos.c.bid2), - uselist=True, - viewonly=True)}) - mapper(Bar, bars) - - sess = create_session() - b1 = Bar(id=1, data='b1') - b2 = Bar(id=2, data='b2') - b3 = Bar(id=3, data='b3') - f1 = Foo(bid1=1, bid2=2) - f2 = Foo(bid1=3, bid2=None) - - sess.add_all((b1, b2, b3)) - sess.flush() - - sess.add_all((f1, f2)) - sess.flush() - - sess.expunge_all() - eq_(sess.query(Foo).filter_by(id=f1.id).one(), - Foo(bars=[Bar(data='b1'), Bar(data='b2')])) - eq_(sess.query(Foo).filter_by(id=f2.id).one(), - Foo(bars=[Bar(data='b3')])) - -class ViewOnlyRepeatedLocalColumn(_base.MappedTest): - """'viewonly' mappings that contain the same 'local' column 
twice""" - - def define_tables(self, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - - Table('bars', metadata, Column('id', Integer, primary_key=True), - Column('fid1', Integer, ForeignKey('foos.id')), - Column('fid2', Integer, ForeignKey('foos.id')), - Column('data', String(50))) - - @testing.resolve_artifact_names - def test_relation_on_or(self): - class Foo(_base.ComparableEntity): - pass - class Bar(_base.ComparableEntity): - pass - - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=sa.or_(bars.c.fid1 == foos.c.id, - bars.c.fid2 == foos.c.id), - viewonly=True)}) - mapper(Bar, bars) - - sess = create_session() - f1 = Foo(id=1, data='f1') - f2 = Foo(id=2, data='f2') - b1 = Bar(fid1=1, data='b1') - b2 = Bar(fid2=1, data='b2') - b3 = Bar(fid1=2, data='b3') - b4 = Bar(fid1=1, fid2=2, data='b4') - - sess.add_all((f1, f2)) - sess.flush() - - sess.add_all((b1, b2, b3, b4)) - sess.flush() - - sess.expunge_all() - eq_(sess.query(Foo).filter_by(id=f1.id).one(), - Foo(bars=[Bar(data='b1'), Bar(data='b2'), Bar(data='b4')])) - eq_(sess.query(Foo).filter_by(id=f2.id).one(), - Foo(bars=[Bar(data='b3'), Bar(data='b4')])) - -class ViewOnlyComplexJoin(_base.MappedTest): - """'viewonly' mappings with a complex join condition.""" - - def define_tables(self, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t1id', Integer, ForeignKey('t1.id'))) - Table('t3', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - Table('t2tot3', metadata, - Column('t2id', Integer, ForeignKey('t2.id')), - Column('t3id', Integer, ForeignKey('t3.id'))) - - def setup_classes(self): - class T1(_base.ComparableEntity): - pass - class T2(_base.ComparableEntity): - pass - class T3(_base.ComparableEntity): - pass - - 
@testing.resolve_artifact_names - def test_basic(self): - mapper(T1, t1, properties={ - 't3s':relation(T3, primaryjoin=sa.and_( - t1.c.id==t2.c.t1id, - t2.c.id==t2tot3.c.t2id, - t3.c.id==t2tot3.c.t3id), - viewonly=True, - foreign_keys=t3.c.id, remote_side=t2.c.t1id) - }) - mapper(T2, t2, properties={ - 't1':relation(T1), - 't3s':relation(T3, secondary=t2tot3) - }) - mapper(T3, t3) - - sess = create_session() - sess.add(T2(data='t2', t1=T1(data='t1'), t3s=[T3(data='t3')])) - sess.flush() - sess.expunge_all() - - a = sess.query(T1).first() - eq_(a.t3s, [T3(data='t3')]) - - - @testing.resolve_artifact_names - def test_remote_side_escalation(self): - mapper(T1, t1, properties={ - 't3s':relation(T3, - primaryjoin=sa.and_(t1.c.id==t2.c.t1id, - t2.c.id==t2tot3.c.t2id, - t3.c.id==t2tot3.c.t3id - ), - viewonly=True, - foreign_keys=t3.c.id)}) - mapper(T2, t2, properties={ - 't1':relation(T1), - 't3s':relation(T3, secondary=t2tot3)}) - mapper(T3, t3) - self.assertRaisesMessage(sa.exc.ArgumentError, - "Specify remote_side argument", - sa.orm.compile_mappers) - - -class ExplicitLocalRemoteTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('t1', metadata, - Column('id', String(50), primary_key=True), - Column('data', String(50))) - Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t1id', String(50))) - - @testing.resolve_artifact_names - def setup_classes(self): - class T1(_base.ComparableEntity): - pass - class T2(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_onetomany_funcfk(self): - # use a function within join condition. but specifying - # local_remote_pairs overrides all parsing of the join condition. 
- mapper(T1, t1, properties={ - 't2s':relation(T2, - primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t1.c.id, t2.c.t1id)], - foreign_keys=[t2.c.t1id])}) - mapper(T2, t2) - - sess = create_session() - a1 = T1(id='number1', data='a1') - a2 = T1(id='number2', data='a2') - b1 = T2(data='b1', t1id='NuMbEr1') - b2 = T2(data='b2', t1id='Number1') - b3 = T2(data='b3', t1id='Number2') - sess.add_all((a1, a2, b1, b2, b3)) - sess.flush() - sess.expunge_all() - - eq_(sess.query(T1).first(), - T1(id='number1', data='a1', t2s=[ - T2(data='b1', t1id='NuMbEr1'), - T2(data='b2', t1id='Number1')])) - - @testing.resolve_artifact_names - def test_manytoone_funcfk(self): - mapper(T1, t1) - mapper(T2, t2, properties={ - 't1':relation(T1, - primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t2.c.t1id, t1.c.id)], - foreign_keys=[t2.c.t1id], - uselist=True)}) - - sess = create_session() - a1 = T1(id='number1', data='a1') - a2 = T1(id='number2', data='a2') - b1 = T2(data='b1', t1id='NuMbEr1') - b2 = T2(data='b2', t1id='Number1') - b3 = T2(data='b3', t1id='Number2') - sess.add_all((a1, a2, b1, b2, b3)) - sess.flush() - sess.expunge_all() - - eq_(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), - [T2(data='b1', t1=[T1(id='number1', data='a1')]), - T2(data='b2', t1=[T1(id='number1', data='a1')])]) - - @testing.resolve_artifact_names - def test_onetomany_func_referent(self): - mapper(T1, t1, properties={ - 't2s':relation(T2, - primaryjoin=sa.func.lower(t1.c.id)==t2.c.t1id, - _local_remote_pairs=[(t1.c.id, t2.c.t1id)], - foreign_keys=[t2.c.t1id])}) - mapper(T2, t2) - - sess = create_session() - a1 = T1(id='NuMbeR1', data='a1') - a2 = T1(id='NuMbeR2', data='a2') - b1 = T2(data='b1', t1id='number1') - b2 = T2(data='b2', t1id='number1') - b3 = T2(data='b2', t1id='number2') - sess.add_all((a1, a2, b1, b2, b3)) - sess.flush() - sess.expunge_all() - - eq_(sess.query(T1).first(), - T1(id='NuMbeR1', data='a1', t2s=[ - T2(data='b1', t1id='number1'), - 
T2(data='b2', t1id='number1')])) - - @testing.resolve_artifact_names - def test_manytoone_func_referent(self): - mapper(T1, t1) - mapper(T2, t2, properties={ - 't1':relation(T1, - primaryjoin=sa.func.lower(t1.c.id)==t2.c.t1id, - _local_remote_pairs=[(t2.c.t1id, t1.c.id)], - foreign_keys=[t2.c.t1id], uselist=True)}) - - sess = create_session() - a1 = T1(id='NuMbeR1', data='a1') - a2 = T1(id='NuMbeR2', data='a2') - b1 = T2(data='b1', t1id='number1') - b2 = T2(data='b2', t1id='number1') - b3 = T2(data='b3', t1id='number2') - sess.add_all((a1, a2, b1, b2, b3)) - sess.flush() - sess.expunge_all() - - eq_(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), - [T2(data='b1', t1=[T1(id='NuMbeR1', data='a1')]), - T2(data='b2', t1=[T1(id='NuMbeR1', data='a1')])]) - - @testing.resolve_artifact_names - def test_escalation_1(self): - mapper(T1, t1, properties={ - 't2s':relation(T2, - primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t1.c.id, t2.c.t1id)], - foreign_keys=[t2.c.t1id], - remote_side=[t2.c.t1id])}) - mapper(T2, t2) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_escalation_2(self): - mapper(T1, t1, properties={ - 't2s':relation(T2, - primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), - _local_remote_pairs=[(t1.c.id, t2.c.t1id)])}) - mapper(T2, t2) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) - -class InvalidRemoteSideTest(_base.MappedTest): - def define_tables(self, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column('t_id', Integer, ForeignKey('t1.id')) - ) - - @testing.resolve_artifact_names - def setup_classes(self): - class T1(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_o2m_backref(self): - mapper(T1, t1, properties={ - 't1s':relation(T1, backref='parent') - }) - - self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " - 
"both of the same direction . Did you " - "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_m2o_backref(self): - mapper(T1, t1, properties={ - 't1s':relation(T1, backref=backref('parent', remote_side=t1.c.id), remote_side=t1.c.id) - }) - - self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " - "both of the same direction . Did you " - "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_o2m_explicit(self): - mapper(T1, t1, properties={ - 't1s':relation(T1, back_populates='parent'), - 'parent':relation(T1, back_populates='t1s'), - }) - - # can't be sure of ordering here - self.assertRaisesMessage(sa.exc.ArgumentError, - "both of the same direction . Did you " - "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_m2o_explicit(self): - mapper(T1, t1, properties={ - 't1s':relation(T1, back_populates='parent', remote_side=t1.c.id), - 'parent':relation(T1, back_populates='t1s', remote_side=t1.c.id) - }) - - # can't be sure of ordering here - self.assertRaisesMessage(sa.exc.ArgumentError, - "both of the same direction . 
Did you " - "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) - - -class InvalidRelationEscalationTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('foos', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer)) - Table('bars', metadata, - Column('id', Integer, primary_key=True), - Column('fid', Integer)) - - def setup_classes(self): - class Foo(_base.Entity): - pass - class Bar(_base.Entity): - pass - - @testing.resolve_artifact_names - def test_no_join(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine join condition between parent/child " - "tables on relation", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_join_self_ref(self): - mapper(Foo, foos, properties={ - 'foos':relation(Foo)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine join condition between parent/child " - "tables on relation", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_equated(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id>bars.c.fid)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for primaryjoin condition", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_equated_fks(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id>bars.c.fid, - foreign_keys=bars.c.fid)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not locate any equated, locally mapped column pairs " - "for primaryjoin condition", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_ambiguous_fks(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id==bars.c.fid, - foreign_keys=[foos.c.id, 
bars.c.fid])}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Do the columns in 'foreign_keys' represent only the " - "'foreign' columns in this join condition ?", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_ambiguous_remoteside_o2m(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id==bars.c.fid, - foreign_keys=[bars.c.fid], - remote_side=[foos.c.id, bars.c.fid], - viewonly=True - )}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "could not determine any local/remote column pairs", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_ambiguous_remoteside_m2o(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id==bars.c.fid, - foreign_keys=[foos.c.id], - remote_side=[foos.c.id, bars.c.fid], - viewonly=True - )}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "could not determine any local/remote column pairs", - sa.orm.compile_mappers) - - - @testing.resolve_artifact_names - def test_no_equated_self_ref(self): - mapper(Foo, foos, properties={ - 'foos':relation(Foo, - primaryjoin=foos.c.id>foos.c.fid)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for primaryjoin condition", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_equated_self_ref(self): - mapper(Foo, foos, properties={ - 'foos':relation(Foo, - primaryjoin=foos.c.id>foos.c.fid, - foreign_keys=[foos.c.fid])}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not locate any equated, locally mapped column pairs " - "for primaryjoin condition", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_equated_viewonly(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id>bars.c.fid, - viewonly=True)}) - mapper(Bar, bars) - - 
self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for primaryjoin condition", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_equated_self_ref_viewonly(self): - mapper(Foo, foos, properties={ - 'foos':relation(Foo, - primaryjoin=foos.c.id>foos.c.fid, - viewonly=True)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Specify the 'foreign_keys' argument to indicate which columns " - "on the relation are foreign.", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_equated_self_ref_viewonly_fks(self): - mapper(Foo, foos, properties={ - 'foos':relation(Foo, - primaryjoin=foos.c.id>foos.c.fid, - viewonly=True, - foreign_keys=[foos.c.fid])}) - - sa.orm.compile_mappers() - eq_(Foo.foos.property.local_remote_pairs, [(foos.c.id, foos.c.fid)]) - - @testing.resolve_artifact_names - def test_equated(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - primaryjoin=foos.c.id==bars.c.fid)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for primaryjoin condition", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_equated_self_ref(self): - mapper(Foo, foos, properties={ - 'foos':relation(Foo, - primaryjoin=foos.c.id==foos.c.fid)}) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for primaryjoin condition", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_equated_self_ref_wrong_fks(self): - mapper(Foo, foos, properties={ - 'foos':relation(Foo, - primaryjoin=foos.c.id==foos.c.fid, - foreign_keys=[bars.c.id])}) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for primaryjoin condition", - sa.orm.compile_mappers) - - -class InvalidRelationEscalationTestM2M(_base.MappedTest): - - def define_tables(self, metadata): - Table('foos', metadata, - 
Column('id', Integer, primary_key=True)) - Table('foobars', metadata, - Column('fid', Integer), Column('bid', Integer)) - Table('bars', metadata, - Column('id', Integer, primary_key=True)) - - @testing.resolve_artifact_names - def setup_classes(self): - class Foo(_base.Entity): - pass - class Bar(_base.Entity): - pass - - @testing.resolve_artifact_names - def test_no_join(self): - mapper(Foo, foos, properties={ - 'bars': relation(Bar, secondary=foobars)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine join condition between parent/child tables " - "on relation", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_secondaryjoin(self): - mapper(Foo, foos, properties={ - 'bars': relation(Bar, - secondary=foobars, - primaryjoin=foos.c.id > foobars.c.fid)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine join condition between parent/child tables " - "on relation", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_bad_primaryjoin(self): - mapper(Foo, foos, properties={ - 'bars': relation(Bar, - secondary=foobars, - primaryjoin=foos.c.id > foobars.c.fid, - secondaryjoin=foobars.c.bid<=bars.c.id)}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for primaryjoin condition", - sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_bad_secondaryjoin(self): - mapper(Foo, foos, properties={ - 'bars':relation(Bar, - secondary=foobars, - primaryjoin=foos.c.id == foobars.c.fid, - secondaryjoin=foobars.c.bid <= bars.c.id, - foreign_keys=[foobars.c.fid])}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not determine relation direction for secondaryjoin " - "condition", sa.orm.compile_mappers) - - @testing.resolve_artifact_names - def test_no_equated_secondaryjoin(self): - mapper(Foo, foos, properties={ - 
'bars':relation(Bar, - secondary=foobars, - primaryjoin=foos.c.id == foobars.c.fid, - secondaryjoin=foobars.c.bid <= bars.c.id, - foreign_keys=[foobars.c.fid, foobars.c.bid])}) - mapper(Bar, bars) - - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Could not locate any equated, locally mapped column pairs for " - "secondaryjoin condition", sa.orm.compile_mappers) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/scoping.py b/test/orm/scoping.py deleted file mode 100644 index bdfc5a9d5..000000000 --- a/test/orm/scoping.py +++ /dev/null @@ -1,238 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from sqlalchemy.orm import scoped_session -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, query -from testlib.testing import eq_ -from orm import _base - - -class _ScopedTest(_base.MappedTest): - """Adds another lookup bucket to emulate Session globals.""" - - run_setup_mappers = 'once' - - _artifact_registries = ( - _base.MappedTest._artifact_registries + ('scoping',)) - - def setUpAll(self): - type(self).scoping = _base.adict() - _base.MappedTest.setUpAll(self) - - def tearDownAll(self): - self.scoping.clear() - _base.MappedTest.tearDownAll(self) - - -class ScopedSessionTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('table1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30))) - Table('table2', metadata, - Column('id', Integer, primary_key=True), - Column('someid', None, ForeignKey('table1.id'))) - - @testing.resolve_artifact_names - def test_basic(self): - Session = scoped_session(sa.orm.sessionmaker()) - - class CustomQuery(query.Query): - pass - - class SomeObject(_base.ComparableEntity): - query = Session.query_property() - class SomeOtherObject(_base.ComparableEntity): - query = Session.query_property() - custom_query = Session.query_property(query_cls=CustomQuery) - - mapper(SomeObject, 
table1, properties={ - 'options':relation(SomeOtherObject)}) - mapper(SomeOtherObject, table2) - - s = SomeObject(id=1, data="hello") - sso = SomeOtherObject() - s.options.append(sso) - Session.add(s) - Session.commit() - Session.refresh(sso) - Session.remove() - - eq_(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), - Session.query(SomeObject).one()) - eq_(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), - SomeObject.query.one()) - eq_(SomeOtherObject(someid=1), - SomeOtherObject.query.filter( - SomeOtherObject.someid == sso.someid).one()) - assert isinstance(SomeOtherObject.query, query.Query) - assert not isinstance(SomeOtherObject.query, CustomQuery) - assert isinstance(SomeOtherObject.custom_query, query.Query) - - -class ScopedMapperTest(_ScopedTest): - - def define_tables(self, metadata): - Table('table1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30))) - Table('table2', metadata, - Column('id', Integer, primary_key=True), - Column('someid', None, ForeignKey('table1.id'))) - - def setup_classes(self): - class SomeObject(_base.ComparableEntity): - pass - class SomeOtherObject(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - Session = scoped_session(sa.orm.create_session) - Session.mapper(SomeObject, table1, properties={ - 'options':relation(SomeOtherObject) - }) - Session.mapper(SomeOtherObject, table2) - - self.scoping['Session'] = Session - - @testing.resolve_artifact_names - def insert_data(self): - s = SomeObject() - s.id = 1 - s.data = 'hello' - sso = SomeOtherObject() - s.options.append(sso) - Session.flush() - Session.expunge_all() - - @testing.resolve_artifact_names - def test_query(self): - sso = SomeOtherObject.query().first() - assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id - - @testing.resolve_artifact_names - def test_query_compiles(self): - class Foo(object): - pass - Session.mapper(Foo, table2) - 
assert hasattr(Foo, 'query') - - ext = sa.orm.MapperExtension() - - class Bar(object): - pass - Session.mapper(Bar, table2, extension=[ext]) - assert hasattr(Bar, 'query') - - class Baz(object): - pass - Session.mapper(Baz, table2, extension=ext) - assert hasattr(Baz, 'query') - - @testing.resolve_artifact_names - def test_default_constructor_state_not_shared(self): - scope = scoped_session(sa.orm.sessionmaker()) - - class A(object): - pass - class B(object): - def __init__(self): - pass - - scope.mapper(A, table1) - scope.mapper(B, table2) - - A(foo='bar') - self.assertRaises(TypeError, B, foo='bar') - - scope = scoped_session(sa.orm.sessionmaker()) - - class C(object): - def __init__(self): - pass - class D(object): - pass - - scope.mapper(C, table1) - scope.mapper(D, table2) - - self.assertRaises(TypeError, C, foo='bar') - D(foo='bar') - - @testing.resolve_artifact_names - def test_validating_constructor(self): - s2 = SomeObject(someid=12) - s3 = SomeOtherObject(someid=123, bogus=345) - - class ValidatedOtherObject(object): pass - Session.mapper(ValidatedOtherObject, table2, validate=True) - - v1 = ValidatedOtherObject(someid=12) - self.assertRaises(sa.exc.ArgumentError, ValidatedOtherObject, - someid=12, bogus=345) - - @testing.resolve_artifact_names - def test_dont_clobber_methods(self): - class MyClass(object): - def expunge(self): - return "an expunge !" - - Session.mapper(MyClass, table2) - - assert MyClass().expunge() == "an expunge !" 
- - -class ScopedMapperTest2(_ScopedTest): - - def define_tables(self, metadata): - Table('table1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30)), - Column('type', String(30))) - Table('table2', metadata, - Column('id', Integer, primary_key=True), - Column('someid', None, ForeignKey('table1.id')), - Column('somedata', String(30))) - - def setup_classes(self): - class BaseClass(_base.ComparableEntity): - pass - class SubClass(BaseClass): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - Session = scoped_session(sa.orm.sessionmaker()) - - Session.mapper(BaseClass, table1, - polymorphic_identity='base', - polymorphic_on=table1.c.type) - Session.mapper(SubClass, table2, - polymorphic_identity='sub', - inherits=BaseClass) - - self.scoping['Session'] = Session - - @testing.resolve_artifact_names - def test_inheritance(self): - def expunge_list(l): - for x in l: - Session.expunge(x) - return l - - b = BaseClass(data='b1') - s = SubClass(data='s1', somedata='somedata') - Session.commit() - Session.expunge_all() - - eq_(expunge_list([BaseClass(data='b1'), - SubClass(data='s1', somedata='somedata')]), - BaseClass.query.all()) - eq_(expunge_list([SubClass(data='s1', somedata='somedata')]), - SubClass.query.all()) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/selectable.py b/test/orm/selectable.py deleted file mode 100644 index 74c41c852..000000000 --- a/test/orm/selectable.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Generic mapping to Select statements""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, String, Integer, select -from testlib.sa.orm import mapper, create_session -from testlib.testing import eq_ -from orm import _base - - -# TODO: more tests mapping to selects - -class SelectableNoFromsTest(_base.MappedTest): - def define_tables(self, metadata): - Table('common', metadata, - Column('id', Integer, primary_key=True), - 
Column('data', Integer), - Column('extra', String(45))) - - def setup_classes(self): - class Subset(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_no_tables(self): - - selectable = select(["x", "y", "z"]) - self.assertRaisesMessage(sa.exc.InvalidRequestError, - "Could not find any Table objects", - mapper, Subset, selectable) - - @testing.emits_warning('.*creating an Alias.*') - @testing.resolve_artifact_names - def test_basic(self): - subset_select = select([common.c.id, common.c.data]) - subset_mapper = mapper(Subset, subset_select) - - sess = create_session(bind=testing.db) - sess.add(Subset(data=1)) - sess.flush() - sess.expunge_all() - - eq_(sess.query(Subset).all(), [Subset(data=1)]) - eq_(sess.query(Subset).filter(Subset.data==1).one(), Subset(data=1)) - eq_(sess.query(Subset).filter(Subset.data!=1).first(), None) - - subset_select = sa.orm.class_mapper(Subset).mapped_table - eq_(sess.query(Subset).filter(subset_select.c.data==1).one(), - Subset(data=1)) - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/session.py b/test/orm/session.py deleted file mode 100644 index 6cbd62a50..000000000 --- a/test/orm/session.py +++ /dev/null @@ -1,1436 +0,0 @@ -import testenv; testenv.configure_for_tests() -import gc -import inspect -import pickle -from sqlalchemy.orm import create_session, sessionmaker, attributes -from testlib import engines, sa, testing, config -from testlib.sa import Table, Column, Integer, String, Sequence -from testlib.sa.orm import mapper, relation, backref, eagerload -from testlib.testing import eq_ -from engine import _base as engine_base -from orm import _base, _fixtures - - -class SessionTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_no_close_on_flush(self): - """Flush() doesn't close a connection the session didn't open""" - c = testing.db.connect() - c.execute("select * from users") - - mapper(User, users) - s = create_session(bind=c) - 
s.add(User(name='first')) - s.flush() - c.execute("select * from users") - - @testing.resolve_artifact_names - def test_close(self): - """close() doesn't close a connection the session didn't open""" - c = testing.db.connect() - c.execute("select * from users") - - mapper(User, users) - s = create_session(bind=c) - s.add(User(name='first')) - s.flush() - c.execute("select * from users") - s.close() - c.execute("select * from users") - - @testing.resolve_artifact_names - def test_no_close_transaction_on_flulsh(self): - c = testing.db.connect() - try: - mapper(User, users) - s = create_session(bind=c) - s.begin() - tran = s.transaction - s.add(User(name='first')) - s.flush() - c.execute("select * from users") - u = User(name='two') - s.add(u) - s.flush() - u = User(name='third') - s.add(u) - s.flush() - assert s.transaction is tran - tran.close() - finally: - c.close() - - @testing.requires.sequences - def test_sequence_execute(self): - seq = Sequence("some_sequence") - seq.create(testing.db) - try: - sess = create_session(bind=testing.db) - eq_(sess.execute(seq), 1) - finally: - seq.drop(testing.db) - - - @testing.resolve_artifact_names - def test_expunge_cascade(self): - mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses':relation(Address, - backref=backref("user", cascade="all"), - cascade="all")}) - - _fixtures.run_inserts_for(users) - _fixtures.run_inserts_for(addresses) - - session = create_session() - u = session.query(User).filter_by(id=7).one() - - # get everything to load in both directions - print [a.user for a in u.addresses] - - # then see if expunge fails - session.expunge(u) - - assert sa.orm.object_session(u) is None - assert sa.orm.attributes.instance_state(u).session_id is None - for a in u.addresses: - assert sa.orm.object_session(a) is None - assert sa.orm.attributes.instance_state(a).session_id is None - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_table_binds_from_expression(self): - 
"""Session can extract Table objects from ClauseElements and match them to tables.""" - - mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses':relation(Address, - backref=backref("user", cascade="all"), - cascade="all")}) - - Session = sessionmaker(binds={users: self.metadata.bind, - addresses: self.metadata.bind}) - sess = Session() - - sess.execute(users.insert(), params=dict(id=1, name='ed')) - eq_(sess.execute(users.select(users.c.id == 1)).fetchall(), - [(1, 'ed')]) - - eq_(sess.execute(users.select(User.id == 1)).fetchall(), - [(1, 'ed')]) - - sess.close() - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_mapped_binds_from_expression(self): - """Session can extract Table objects from ClauseElements and match them to tables.""" - - mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses':relation(Address, - backref=backref("user", cascade="all"), - cascade="all")}) - - Session = sessionmaker(binds={User: self.metadata.bind, - Address: self.metadata.bind}) - sess = Session() - - sess.execute(users.insert(), params=dict(id=1, name='ed')) - eq_(sess.execute(users.select(users.c.id == 1)).fetchall(), - [(1, 'ed')]) - - eq_(sess.execute(users.select(User.id == 1)).fetchall(), - [(1, 'ed')]) - - sess.close() - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_bind_from_metadata(self): - mapper(User, users) - - session = create_session() - session.execute(users.insert(), dict(name='Johnny')) - - assert len(session.query(User).filter_by(name='Johnny').all()) == 1 - - session.execute(users.delete()) - - assert len(session.query(User).filter_by(name='Johnny').all()) == 0 - session.close() - - @testing.requires.independent_connections - @engines.close_open_connections - @testing.resolve_artifact_names - def test_transaction(self): - mapper(User, users) - conn1 = testing.db.connect() - conn2 = testing.db.connect() - - sess = create_session(autocommit=False, bind=conn1) - 
u = User(name='x') - sess.add(u) - sess.flush() - assert conn1.execute("select count(1) from users").scalar() == 1 - assert conn2.execute("select count(1) from users").scalar() == 0 - sess.commit() - assert conn1.execute("select count(1) from users").scalar() == 1 - assert testing.db.connect().execute("select count(1) from users").scalar() == 1 - sess.close() - - @testing.requires.independent_connections - @engines.close_open_connections - @testing.resolve_artifact_names - def test_autoflush(self): - bind = self.metadata.bind - mapper(User, users) - conn1 = bind.connect() - conn2 = bind.connect() - - sess = create_session(bind=conn1, autocommit=False, autoflush=True) - u = User() - u.name='ed' - sess.add(u) - u2 = sess.query(User).filter_by(name='ed').one() - assert u2 is u - eq_(conn1.execute("select count(1) from users").scalar(), 1) - eq_(conn2.execute("select count(1) from users").scalar(), 0) - sess.commit() - eq_(conn1.execute("select count(1) from users").scalar(), 1) - eq_(bind.connect().execute("select count(1) from users").scalar(), 1) - sess.close() - - @testing.resolve_artifact_names - def test_autoflush_expressions(self): - """test that an expression which is dependent on object state is - evaluated after the session autoflushes. This is the lambda - inside of strategies.py lazy_clause. 
- - """ - mapper(User, users, properties={ - 'addresses':relation(Address, backref="user")}) - mapper(Address, addresses) - - sess = create_session(autoflush=True, autocommit=False) - u = User(name='ed', addresses=[Address(email_address='foo')]) - sess.add(u) - eq_(sess.query(Address).filter(Address.user==u).one(), - Address(email_address='foo')) - - # still works after "u" is garbage collected - sess.commit() - sess.close() - u = sess.query(User).get(u.id) - q = sess.query(Address).filter(Address.user==u) - del u - gc.collect() - eq_(q.one(), Address(email_address='foo')) - - - @testing.requires.independent_connections - @engines.close_open_connections - @testing.resolve_artifact_names - def test_autoflush_unbound(self): - mapper(User, users) - - try: - sess = create_session(autocommit=False, autoflush=True) - u = User() - u.name='ed' - sess.add(u) - u2 = sess.query(User).filter_by(name='ed').one() - assert u2 is u - assert sess.execute("select count(1) from users", mapper=User).scalar() == 1 - assert testing.db.connect().execute("select count(1) from users").scalar() == 0 - sess.commit() - assert sess.execute("select count(1) from users", mapper=User).scalar() == 1 - assert testing.db.connect().execute("select count(1) from users").scalar() == 1 - sess.close() - except: - sess.rollback() - raise - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_autoflush_2(self): - mapper(User, users) - conn1 = testing.db.connect() - conn2 = testing.db.connect() - - sess = create_session(bind=conn1, autocommit=False, autoflush=True) - u = User() - u.name='ed' - sess.add(u) - sess.commit() - assert conn1.execute("select count(1) from users").scalar() == 1 - assert testing.db.connect().execute("select count(1) from users").scalar() == 1 - sess.commit() - - @testing.resolve_artifact_names - def test_autoflush_rollback(self): - mapper(Address, addresses) - mapper(User, users, properties={ - 'addresses':relation(Address)}) - - 
_fixtures.run_inserts_for(users) - _fixtures.run_inserts_for(addresses) - - sess = create_session(autocommit=False, autoflush=True) - u = sess.query(User).get(8) - newad = Address(email_address='a new address') - u.addresses.append(newad) - u.name = 'some new name' - assert u.name == 'some new name' - assert len(u.addresses) == 4 - assert newad in u.addresses - sess.rollback() - assert u.name == 'ed' - assert len(u.addresses) == 3 - - assert newad not in u.addresses - # pending objects dont get expired - assert newad.email_address == 'a new address' - - @testing.resolve_artifact_names - def test_autocommit_doesnt_raise_on_pending(self): - mapper(User, users) - session = create_session(autocommit=True) - - session.add(User(name='ed')) - - session.begin() - session.flush() - session.commit() - - def test_active_flag(self): - sess = create_session(bind=config.db, autocommit=True) - assert not sess.is_active - sess.begin() - assert sess.is_active - sess.rollback() - assert not sess.is_active - - @testing.resolve_artifact_names - def test_textual_execute(self): - """test that Session.execute() converts to text()""" - - sess = create_session(bind=self.metadata.bind) - users.insert().execute(id=7, name='jack') - - # use :bindparam style - eq_(sess.execute("select * from users where id=:id", - {'id':7}).fetchall(), - [(7, u'jack')]) - - - # use :bindparam style - eq_(sess.scalar("select id from users where id=:id", - {'id':7}), - 7) - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_subtransaction_on_external(self): - mapper(User, users) - conn = testing.db.connect() - trans = conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) - sess.begin(subtransactions=True) - u = User(name='ed') - sess.add(u) - sess.flush() - sess.commit() # commit does nothing - trans.rollback() # rolls back - assert len(sess.query(User).all()) == 0 - sess.close() - - @testing.requires.savepoints - @engines.close_open_connections - 
@testing.resolve_artifact_names - def test_external_nested_transaction(self): - mapper(User, users) - try: - conn = testing.db.connect() - trans = conn.begin() - sess = create_session(bind=conn, autocommit=False, autoflush=True) - u1 = User(name='u1') - sess.add(u1) - sess.flush() - - sess.begin_nested() - u2 = User(name='u2') - sess.add(u2) - sess.flush() - sess.rollback() - - trans.commit() - assert len(sess.query(User).all()) == 1 - except: - conn.close() - raise - - @testing.requires.savepoints - @testing.resolve_artifact_names - def test_heavy_nesting(self): - session = create_session(bind=testing.db) - - session.begin() - session.connection().execute("insert into users (name) values ('user1')") - - session.begin(subtransactions=True) - - session.begin_nested() - - session.connection().execute("insert into users (name) values ('user2')") - assert session.connection().execute("select count(1) from users").scalar() == 2 - - session.rollback() - assert session.connection().execute("select count(1) from users").scalar() == 1 - session.connection().execute("insert into users (name) values ('user3')") - - session.commit() - assert session.connection().execute("select count(1) from users").scalar() == 2 - - @testing.fails_on('sqlite', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_transactions_isolated(self): - mapper(User, users) - users.delete().execute() - - s1 = create_session(bind=testing.db, autocommit=False) - s2 = create_session(bind=testing.db, autocommit=False) - u1 = User(name='u1') - s1.add(u1) - s1.flush() - - assert s2.query(User).all() == [] - - @testing.requires.two_phase_transactions - @testing.resolve_artifact_names - def test_twophase(self): - # TODO: mock up a failure condition here - # to ensure a rollback succeeds - mapper(User, users) - mapper(Address, addresses) - - engine2 = engines.testing_engine() - sess = create_session(autocommit=True, autoflush=False, twophase=True) - sess.bind_mapper(User, testing.db) - 
sess.bind_mapper(Address, engine2) - sess.begin() - u1 = User(name='u1') - a1 = Address(email_address='u1@e') - sess.add_all((u1, a1)) - sess.commit() - sess.close() - engine2.dispose() - assert users.count().scalar() == 1 - assert addresses.count().scalar() == 1 - - @testing.resolve_artifact_names - def test_subtransaction_on_noautocommit(self): - mapper(User, users) - sess = create_session(autocommit=False, autoflush=True) - sess.begin(subtransactions=True) - u = User(name='u1') - sess.add(u) - sess.flush() - sess.commit() # commit does nothing - sess.rollback() # rolls back - assert len(sess.query(User).all()) == 0 - sess.close() - - @testing.requires.savepoints - @testing.resolve_artifact_names - def test_nested_transaction(self): - mapper(User, users) - sess = create_session() - sess.begin() - - u = User(name='u1') - sess.add(u) - sess.flush() - - sess.begin_nested() # nested transaction - - u2 = User(name='u2') - sess.add(u2) - sess.flush() - - sess.rollback() - - sess.commit() - assert len(sess.query(User).all()) == 1 - sess.close() - - @testing.requires.savepoints - @testing.resolve_artifact_names - def test_nested_autotrans(self): - mapper(User, users) - sess = create_session(autocommit=False) - u = User(name='u1') - sess.add(u) - sess.flush() - - sess.begin_nested() # nested transaction - - u2 = User(name='u2') - sess.add(u2) - sess.flush() - - sess.rollback() - - sess.commit() - assert len(sess.query(User).all()) == 1 - sess.close() - - @testing.requires.savepoints - @testing.resolve_artifact_names - def test_nested_transaction_connection_add(self): - mapper(User, users) - - sess = create_session(autocommit=True) - - sess.begin() - sess.begin_nested() - - u1 = User(name='u1') - sess.add(u1) - sess.flush() - - sess.rollback() - - u2 = User(name='u2') - sess.add(u2) - - sess.commit() - - self.assertEquals(set(sess.query(User).all()), set([u2])) - - sess.begin() - sess.begin_nested() - - u3 = User(name='u3') - sess.add(u3) - sess.commit() # commit the 
nested transaction - sess.rollback() - - self.assertEquals(set(sess.query(User).all()), set([u2])) - - sess.close() - - @testing.requires.savepoints - @testing.resolve_artifact_names - def test_mixed_transaction_control(self): - mapper(User, users) - - sess = create_session(autocommit=True) - - sess.begin() - sess.begin_nested() - transaction = sess.begin(subtransactions=True) - - sess.add(User(name='u1')) - - transaction.commit() - sess.commit() - sess.commit() - - sess.close() - - self.assertEquals(len(sess.query(User).all()), 1) - - t1 = sess.begin() - t2 = sess.begin_nested() - - sess.add(User(name='u2')) - - t2.commit() - assert sess.transaction is t1 - - sess.close() - - @testing.requires.savepoints - @testing.resolve_artifact_names - def test_mixed_transaction_close(self): - mapper(User, users) - - sess = create_session(autocommit=False) - - sess.begin_nested() - - sess.add(User(name='u1')) - sess.flush() - - sess.close() - - sess.add(User(name='u2')) - sess.commit() - - sess.close() - - self.assertEquals(len(sess.query(User).all()), 1) - - @testing.resolve_artifact_names - def test_error_on_using_inactive_session(self): - mapper(User, users) - - sess = create_session(autocommit=True) - - sess.begin() - sess.begin(subtransactions=True) - - sess.add(User(name='u1')) - sess.flush() - - sess.rollback() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True) - sess.close() - - @testing.resolve_artifact_names - def test_no_autocommit_with_explicit_commit(self): - mapper(User, users) - session = create_session(autocommit=False) - - session.add(User(name='ed')) - session.transaction.commit() - assert session.transaction is not None, "autocommit=False should start a new transaction" - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_bound_connection(self): - mapper(User, users) - c = testing.db.connect() - sess = create_session(bind=c) - sess.begin() - 
transaction = sess.transaction - u = User(name='u1') - sess.add(u) - sess.flush() - assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c - - self.assertRaisesMessage(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect()) - - transaction.rollback() - assert len(sess.query(User).all()) == 0 - sess.close() - - @testing.resolve_artifact_names - def test_bound_connection_transactional(self): - mapper(User, users) - c = testing.db.connect() - - sess = create_session(bind=c, autocommit=False) - u = User(name='u1') - sess.add(u) - sess.flush() - sess.close() - assert not c.in_transaction() - assert c.scalar("select count(1) from users") == 0 - - sess = create_session(bind=c, autocommit=False) - u = User(name='u2') - sess.add(u) - sess.flush() - sess.commit() - assert not c.in_transaction() - assert c.scalar("select count(1) from users") == 1 - c.execute("delete from users") - assert c.scalar("select count(1) from users") == 0 - - c = testing.db.connect() - - trans = c.begin() - sess = create_session(bind=c, autocommit=True) - u = User(name='u3') - sess.add(u) - sess.flush() - assert c.in_transaction() - trans.commit() - assert not c.in_transaction() - assert c.scalar("select count(1) from users") == 1 - - - @testing.uses_deprecated() - @engines.close_open_connections - @testing.resolve_artifact_names - def test_save_update_delete(self): - - s = create_session() - mapper(User, users, properties={ - 'addresses':relation(Address, cascade="all, delete") - }) - mapper(Address, addresses) - - user = User(name='u1') - - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.update, user) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.delete, user) - - s.add(user) - s.flush() - user = s.query(User).one() - s.expunge(user) - assert user not in s - - # modify outside of session, assert changes remain/get saved - 
user.name = "fred" - s.add(user) - assert user in s - assert user in s.dirty - s.flush() - s.expunge_all() - assert s.query(User).count() == 1 - user = s.query(User).one() - assert user.name == 'fred' - - # ensure its not dirty if no changes occur - s.expunge_all() - assert user not in s - s.add(user) - assert user in s - assert user not in s.dirty - - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already persistent", s.save, user) - - s2 = create_session() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user) - - u2 = s2.query(User).get(user.id) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2) - - s.expire(user) - s.expunge(user) - assert user not in s - s.delete(user) - assert user in s - - s.flush() - assert user not in s - assert s.query(User).count() == 0 - - @testing.resolve_artifact_names - def test_is_modified(self): - s = create_session() - - mapper(User, users, properties={'addresses':relation(Address)}) - mapper(Address, addresses) - - # save user - u = User(name='fred') - s.add(u) - s.flush() - s.expunge_all() - - user = s.query(User).one() - assert user not in s.dirty - assert not s.is_modified(user) - user.name = 'fred' - assert user in s.dirty - assert not s.is_modified(user) - user.name = 'ed' - assert user in s.dirty - assert s.is_modified(user) - s.flush() - assert user not in s.dirty - assert not s.is_modified(user) - - a = Address() - user.addresses.append(a) - assert user in s.dirty - assert s.is_modified(user) - assert not s.is_modified(user, include_collections=False) - - - @testing.resolve_artifact_names - def test_weak_ref(self): - """test the weak-referencing identity map, which strongly-references modified items.""" - - s = create_session() - mapper(User, users) - - s.add(User(name='ed')) - s.flush() - assert not s.dirty - - user = s.query(User).one() - del user - gc.collect() - assert len(s.identity_map) == 0 - - user = 
s.query(User).one() - user.name = 'fred' - del user - gc.collect() - assert len(s.identity_map) == 1 - assert len(s.dirty) == 1 - assert None not in s.dirty - s.flush() - gc.collect() - assert not s.dirty - assert not s.identity_map - - user = s.query(User).one() - assert user.name == 'fred' - assert s.identity_map - - @testing.resolve_artifact_names - def test_weakref_with_cycles_o2m(self): - s = sessionmaker()() - mapper(User, users, properties={ - "addresses":relation(Address, backref="user") - }) - mapper(Address, addresses) - s.add(User(name="ed", addresses=[Address(email_address="ed1")])) - s.commit() - - user = s.query(User).options(eagerload(User.addresses)).one() - user.addresses[0].user # lazyload - eq_(user, User(name="ed", addresses=[Address(email_address="ed1")])) - - del user - gc.collect() - assert len(s.identity_map) == 0 - - user = s.query(User).options(eagerload(User.addresses)).one() - user.addresses[0].email_address='ed2' - user.addresses[0].user # lazyload - del user - gc.collect() - assert len(s.identity_map) == 2 - - s.commit() - user = s.query(User).options(eagerload(User.addresses)).one() - eq_(user, User(name="ed", addresses=[Address(email_address="ed2")])) - - @testing.resolve_artifact_names - def test_weakref_with_cycles_o2o(self): - s = sessionmaker()() - mapper(User, users, properties={ - "address":relation(Address, backref="user", uselist=False) - }) - mapper(Address, addresses) - s.add(User(name="ed", address=Address(email_address="ed1"))) - s.commit() - - user = s.query(User).options(eagerload(User.address)).one() - user.address.user - eq_(user, User(name="ed", address=Address(email_address="ed1"))) - - del user - gc.collect() - assert len(s.identity_map) == 0 - - user = s.query(User).options(eagerload(User.address)).one() - user.address.email_address='ed2' - user.address.user # lazyload - - del user - gc.collect() - assert len(s.identity_map) == 2 - - s.commit() - user = s.query(User).options(eagerload(User.address)).one() - 
eq_(user, User(name="ed", address=Address(email_address="ed2"))) - - @testing.resolve_artifact_names - def test_strong_ref(self): - s = create_session(weak_identity_map=False) - mapper(User, users) - - # save user - s.add(User(name='u1')) - s.flush() - user = s.query(User).one() - user = None - print s.identity_map - import gc - gc.collect() - assert len(s.identity_map) == 1 - - user = s.query(User).one() - assert not s.identity_map._modified - user.name = 'u2' - assert s.identity_map._modified - s.flush() - eq_(users.select().execute().fetchall(), [(user.id, 'u2')]) - - - @testing.resolve_artifact_names - def test_prune(self): - s = create_session(weak_identity_map=False) - mapper(User, users) - - for o in [User(name='u%s' % x) for x in xrange(10)]: - s.add(o) - # o is still live after this loop... - - self.assert_(len(s.identity_map) == 0) - self.assert_(s.prune() == 0) - s.flush() - import gc - gc.collect() - self.assert_(s.prune() == 9) - self.assert_(len(s.identity_map) == 1) - - id = o.id - del o - self.assert_(s.prune() == 1) - self.assert_(len(s.identity_map) == 0) - - u = s.query(User).get(id) - self.assert_(s.prune() == 0) - self.assert_(len(s.identity_map) == 1) - u.name = 'squiznart' - del u - self.assert_(s.prune() == 0) - self.assert_(len(s.identity_map) == 1) - s.flush() - self.assert_(s.prune() == 1) - self.assert_(len(s.identity_map) == 0) - - s.add(User(name='x')) - self.assert_(s.prune() == 0) - self.assert_(len(s.identity_map) == 0) - s.flush() - self.assert_(len(s.identity_map) == 1) - self.assert_(s.prune() == 1) - self.assert_(len(s.identity_map) == 0) - - u = s.query(User).get(id) - s.delete(u) - del u - self.assert_(s.prune() == 0) - self.assert_(len(s.identity_map) == 1) - s.flush() - self.assert_(s.prune() == 0) - self.assert_(len(s.identity_map) == 0) - - @testing.resolve_artifact_names - def test_no_save_cascade_1(self): - mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relation(Address, cascade="none", 
backref="user"))) - s = create_session() - - u = User(name='u1') - s.add(u) - a = Address(email_address='u1@e') - u.addresses.append(a) - assert u in s - assert a not in s - s.flush() - print "\n".join([repr(x.__dict__) for x in s]) - s.expunge_all() - assert s.query(User).one().id == u.id - assert s.query(Address).first() is None - - @testing.resolve_artifact_names - def test_no_save_cascade_2(self): - mapper(Address, addresses) - mapper(User, users, properties=dict( - addresses=relation(Address, - cascade="all", - backref=backref("user", cascade="none")))) - - s = create_session() - u = User(name='u1') - a = Address(email_address='u1@e') - a.user = u - s.add(a) - assert u not in s - assert a in s - s.flush() - s.expunge_all() - assert s.query(Address).one().id == a.id - assert s.query(User).first() is None - - @testing.resolve_artifact_names - def test_extension(self): - mapper(User, users) - log = [] - class MyExt(sa.orm.session.SessionExtension): - def before_commit(self, session): - log.append('before_commit') - def after_commit(self, session): - log.append('after_commit') - def after_rollback(self, session): - log.append('after_rollback') - def before_flush(self, session, flush_context, objects): - log.append('before_flush') - def after_flush(self, session, flush_context): - log.append('after_flush') - def after_flush_postexec(self, session, flush_context): - log.append('after_flush_postexec') - def after_begin(self, session, transaction, connection): - log.append('after_begin') - def after_attach(self, session, instance): - log.append('after_attach') - def after_bulk_update(self, session, query, query_context, result): - log.append('after_bulk_update') - def after_bulk_delete(self, session, query, query_context, result): - log.append('after_bulk_delete') - - sess = create_session(extension = MyExt()) - u = User(name='u1') - sess.add(u) - sess.flush() - assert log == ['after_attach', 'before_flush', 'after_begin', 'after_flush', 'before_commit', 
'after_commit', 'after_flush_postexec'] - - log = [] - sess = create_session(autocommit=False, extension=MyExt()) - u = User(name='u1') - sess.add(u) - sess.flush() - assert log == ['after_attach', 'before_flush', 'after_begin', 'after_flush', 'after_flush_postexec'] - - log = [] - u.name = 'ed' - sess.commit() - assert log == ['before_commit', 'before_flush', 'after_flush', 'after_flush_postexec', 'after_commit'] - - log = [] - sess.commit() - assert log == ['before_commit', 'after_commit'] - - log = [] - sess.query(User).delete() - assert log == ['after_begin', 'after_bulk_delete'] - - log = [] - sess.query(User).update({'name': 'foo'}) - assert log == ['after_bulk_update'] - - log = [] - sess = create_session(autocommit=False, extension=MyExt(), bind=testing.db) - conn = sess.connection() - assert log == ['after_begin'] - - @testing.resolve_artifact_names - def test_before_flush(self): - """test that the flush plan can be affected during before_flush()""" - - mapper(User, users) - - class MyExt(sa.orm.session.SessionExtension): - def before_flush(self, session, flush_context, objects): - for obj in list(session.new) + list(session.dirty): - if isinstance(obj, User): - session.add(User(name='another %s' % obj.name)) - for obj in list(session.deleted): - if isinstance(obj, User): - x = session.query(User).filter(User.name=='another %s' % obj.name).one() - session.delete(x) - - sess = create_session(extension = MyExt(), autoflush=True) - u = User(name='u1') - sess.add(u) - sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), - [ - User(name='another u1'), - User(name='u1') - ] - ) - - sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), - [ - User(name='another u1'), - User(name='u1') - ] - ) - - u.name='u2' - sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), - [ - User(name='another u1'), - User(name='another u2'), - User(name='u2') - ] - ) - - sess.delete(u) - sess.flush() - 
self.assertEquals(sess.query(User).order_by(User.name).all(), - [ - User(name='another u1'), - ] - ) - - @testing.resolve_artifact_names - def test_before_flush_affects_dirty(self): - mapper(User, users) - - class MyExt(sa.orm.session.SessionExtension): - def before_flush(self, session, flush_context, objects): - for obj in list(session.identity_map.values()): - obj.name += " modified" - - sess = create_session(extension = MyExt(), autoflush=True) - u = User(name='u1') - sess.add(u) - sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), - [ - User(name='u1') - ] - ) - - sess.add(User(name='u2')) - sess.flush() - sess.expunge_all() - self.assertEquals(sess.query(User).order_by(User.name).all(), - [ - User(name='u1 modified'), - User(name='u2') - ] - ) - - @testing.resolve_artifact_names - def test_reentrant_flush(self): - - mapper(User, users) - - class MyExt(sa.orm.session.SessionExtension): - def before_flush(s, session, flush_context, objects): - session.flush() - - sess = create_session(extension=MyExt()) - sess.add(User(name='foo')) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush) - - @testing.resolve_artifact_names - def test_pickled_update(self): - mapper(User, users) - sess1 = create_session() - sess2 = create_session() - - u1 = User(name='u1') - sess1.add(u1) - - self.assertRaisesMessage(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1) - - u2 = pickle.loads(pickle.dumps(u1)) - - sess2.add(u2) - - @testing.resolve_artifact_names - def test_duplicate_update(self): - mapper(User, users) - Session = sessionmaker() - sess = Session() - - u1 = User(name='u1') - sess.add(u1) - sess.flush() - assert u1.id is not None - - sess.expunge(u1) - - assert u1 not in sess - assert Session.object_session(u1) is None - - u2 = sess.query(User).get(u1.id) - assert u2 is not None and u2 is not u1 - assert u2 in sess - - self.assertRaises(Exception, lambda: sess.add(u1)) - - sess.expunge(u2) - 
assert u2 not in sess - assert Session.object_session(u2) is None - - u1.name = "John" - u2.name = "Doe" - - sess.add(u1) - assert u1 in sess - assert Session.object_session(u1) is sess - - sess.flush() - - sess.expunge_all() - - u3 = sess.query(User).get(u1.id) - assert u3 is not u1 and u3 is not u2 and u3.name == u1.name - - @testing.resolve_artifact_names - def test_no_double_save(self): - sess = create_session() - class Foo(object): - def __init__(self): - sess.add(self) - class Bar(Foo): - def __init__(self): - sess.add(self) - Foo.__init__(self) - mapper(Foo, users) - mapper(Bar, users) - - b = Bar() - assert b in sess - assert len(list(sess)) == 1 - -class DisposedStates(_base.MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - global t1 - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) - - def setup_mappers(self): - global T - class T(object): - def __init__(self, data): - self.data = data - mapper(T, t1) - - def tearDown(self): - from sqlalchemy.orm.session import _sessions - _sessions.clear() - super(DisposedStates, self).tearDown() - - def _set_imap_in_disposal(self, sess, *objs): - """remove selected objects from the given session, as though they - were dereferenced and removed from WeakIdentityMap. - - Hardcodes the identity map's "all_states()" method to return the full list - of states. This simulates the all_states() method returning results, afterwhich - some of the states get garbage collected (this normally only happens during - asynchronous gc). The Session now has one or more - InstanceState's which have been removed from the identity map and disposed. - - Will the Session not trip over this ??? Stay tuned. 
- - """ - all_states = sess.identity_map.all_states() - sess.identity_map.all_states = lambda: all_states - for obj in objs: - state = attributes.instance_state(obj) - sess.identity_map.remove(state) - state.dispose() - - def _test_session(self, **kwargs): - global sess - sess = create_session(**kwargs) - - data = o1, o2, o3, o4, o5 = [T('t1'), T('t2'), T('t3'), T('t4'), T('t5')] - - sess.add_all(data) - - sess.flush() - - o1.data = 't1modified' - o5.data = 't5modified' - - self._set_imap_in_disposal(sess, o2, o4, o5) - return sess - - def test_flush(self): - self._test_session().flush() - - def test_clear(self): - self._test_session().expunge_all() - - def test_close(self): - self._test_session().close() - - def test_expunge_all(self): - self._test_session().expunge_all() - - def test_expire_all(self): - self._test_session().expire_all() - - def test_rollback(self): - sess = self._test_session(autocommit=False, expire_on_commit=True) - sess.commit() - - sess.rollback() - - -class SessionInterface(testing.TestBase): - """Bogus args to Session methods produce actionable exceptions.""" - - # TODO: expand with message body assertions. 
- - _class_methods = set(( - 'connection', 'execute', 'get_bind', 'scalar')) - - def _public_session_methods(self): - Session = sa.orm.session.Session - - blacklist = set(('begin', 'query')) - - ok = set() - for meth in Session.public_methods: - if meth in blacklist: - continue - spec = inspect.getargspec(getattr(Session, meth)) - if len(spec[0]) > 1 or spec[1]: - ok.add(meth) - return ok - - def _map_it(self, cls): - return mapper(cls, Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True))) - - @testing.uses_deprecated() - def _test_instance_guards(self, user_arg): - watchdog = set() - - def x_raises_(obj, method, *args, **kw): - watchdog.add(method) - callable_ = getattr(obj, method) - self.assertRaises(sa.orm.exc.UnmappedInstanceError, - callable_, *args, **kw) - - def raises_(method, *args, **kw): - x_raises_(create_session(), method, *args, **kw) - - raises_('__contains__', user_arg) - - raises_('add', user_arg) - - raises_('add_all', (user_arg,)) - - raises_('delete', user_arg) - - raises_('expire', user_arg) - - raises_('expunge', user_arg) - - # flush will no-op without something in the unit of work - def _(): - class OK(object): - pass - self._map_it(OK) - - s = create_session() - s.add(OK()) - x_raises_(s, 'flush', (user_arg,)) - _() - - raises_('is_modified', user_arg) - - raises_('merge', user_arg) - - raises_('refresh', user_arg) - - raises_('save', user_arg) - - raises_('save_or_update', user_arg) - - raises_('update', user_arg) - - instance_methods = self._public_session_methods() - self._class_methods - - eq_(watchdog, instance_methods, - watchdog.symmetric_difference(instance_methods)) - - def _test_class_guards(self, user_arg): - watchdog = set() - - def raises_(method, *args, **kw): - watchdog.add(method) - callable_ = getattr(create_session(), method) - self.assertRaises(sa.orm.exc.UnmappedClassError, - callable_, *args, **kw) - - raises_('connection', mapper=user_arg) - - raises_('execute', 'SELECT 1', mapper=user_arg) - - 
raises_('get_bind', mapper=user_arg) - - raises_('scalar', 'SELECT 1', mapper=user_arg) - - eq_(watchdog, self._class_methods, - watchdog.symmetric_difference(self._class_methods)) - - def test_unmapped_instance(self): - class Unmapped(object): - pass - - self._test_instance_guards(Unmapped()) - self._test_class_guards(Unmapped) - - def test_unmapped_primitives(self): - for prim in ('doh', 123, ('t', 'u', 'p', 'l', 'e')): - self._test_instance_guards(prim) - self._test_class_guards(prim) - - def test_unmapped_class_for_instance(self): - class Unmapped(object): - pass - - self._test_instance_guards(Unmapped) - self._test_class_guards(Unmapped) - - def test_mapped_class_for_instance(self): - class Mapped(object): - pass - self._map_it(Mapped) - - self._test_instance_guards(Mapped) - # no class guards- it would pass. - - def test_missing_state(self): - class Mapped(object): - pass - early = Mapped() - self._map_it(Mapped) - - self._test_instance_guards(early) - self._test_class_guards(early) - - -class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): - def create_engine(self): - return engines.testing_engine(options=dict(strategy='threadlocal')) - - def define_tables(self, metadata): - Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(20)), - test_needs_acid=True) - - def setup_classes(self): - class User(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users) - - def setUpAll(self): - engine_base.AltEngineTest.setUpAll(self) - _base.MappedTest.setUpAll(self) - - - def tearDownAll(self): - _base.MappedTest.tearDownAll(self) - engine_base.AltEngineTest.tearDownAll(self) - - @testing.exclude('mysql', '<', (5, 0, 3), 'FIXME: unknown') - @testing.resolve_artifact_names - def test_session_nesting(self): - sess = create_session(bind=self.engine) - self.engine.begin() - u = User(name='ed') - sess.add(u) - sess.flush() - self.engine.commit() - - -if __name__ == 
"__main__": - testenv.main() diff --git a/test/orm/sharding/alltests.py b/test/orm/sharding/alltests.py deleted file mode 100644 index 09fa86212..000000000 --- a/test/orm/sharding/alltests.py +++ /dev/null @@ -1,18 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - 'orm.sharding.shard', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py deleted file mode 100644 index 10aaee131..000000000 --- a/test/orm/sharding/shard.py +++ /dev/null @@ -1,163 +0,0 @@ -import testenv; testenv.configure_for_tests() -import datetime, os -from sqlalchemy import * -from sqlalchemy import sql -from sqlalchemy.orm import * -from sqlalchemy.orm.shard import ShardedSession -from sqlalchemy.sql import operators -from testlib import * -from testlib.testing import eq_ - -# TODO: ShardTest can be turned into a base for further subclasses - -class ShardTest(TestBase): - def setUpAll(self): - global db1, db2, db3, db4, weather_locations, weather_reports - - db1 = create_engine('sqlite:///shard1.db') - db2 = create_engine('sqlite:///shard2.db') - db3 = create_engine('sqlite:///shard3.db') - db4 = create_engine('sqlite:///shard4.db') - - meta = MetaData() - ids = Table('ids', meta, - Column('nextid', Integer, nullable=False)) - - def id_generator(ctx): - # in reality, might want to use a separate transaction for this. 
- c = db1.connect() - nextid = c.execute(ids.select(for_update=True)).scalar() - c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1})) - return nextid - - weather_locations = Table("weather_locations", meta, - Column('id', Integer, primary_key=True, default=id_generator), - Column('continent', String(30), nullable=False), - Column('city', String(50), nullable=False) - ) - - weather_reports = Table("weather_reports", meta, - Column('id', Integer, primary_key=True), - Column('location_id', Integer, ForeignKey('weather_locations.id')), - Column('temperature', Float), - Column('report_time', DateTime, default=datetime.datetime.now), - ) - - for db in (db1, db2, db3, db4): - meta.create_all(db) - - db1.execute(ids.insert(), nextid=1) - - self.setup_session() - self.setup_mappers() - - def tearDownAll(self): - for db in (db1, db2, db3, db4): - db.connect().invalidate() - for i in range(1,5): - os.remove("shard%d.db" % i) - - def setup_session(self): - global create_session - - shard_lookup = { - 'North America':'north_america', - 'Asia':'asia', - 'Europe':'europe', - 'South America':'south_america' - } - - def shard_chooser(mapper, instance, clause=None): - if isinstance(instance, WeatherLocation): - return shard_lookup[instance.continent] - else: - return shard_chooser(mapper, instance.location) - - def id_chooser(query, ident): - return ['north_america', 'asia', 'europe', 'south_america'] - - def query_chooser(query): - ids = [] - - class FindContinent(sql.ClauseVisitor): - def visit_binary(self, binary): - if binary.left is weather_locations.c.continent: - if binary.operator == operators.eq: - ids.append(shard_lookup[binary.right.value]) - elif binary.operator == operators.in_op: - for bind in binary.right.clauses: - ids.append(shard_lookup[bind.value]) - - FindContinent().traverse(query._criterion) - if len(ids) == 0: - return ['north_america', 'asia', 'europe', 'south_america'] - else: - return ids - - create_session = sessionmaker(class_=ShardedSession, 
autoflush=True, autocommit=False) - - create_session.configure(shards={ - 'north_america':db1, - 'asia':db2, - 'europe':db3, - 'south_america':db4 - }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) - - - def setup_mappers(self): - global WeatherLocation, Report - - class WeatherLocation(object): - def __init__(self, continent, city): - self.continent = continent - self.city = city - - class Report(object): - def __init__(self, temperature): - self.temperature = temperature - - mapper(WeatherLocation, weather_locations, properties={ - 'reports':relation(Report, backref='location'), - 'city': deferred(weather_locations.c.city), - }) - - mapper(Report, weather_reports) - - def test_roundtrip(self): - tokyo = WeatherLocation('Asia', 'Tokyo') - newyork = WeatherLocation('North America', 'New York') - toronto = WeatherLocation('North America', 'Toronto') - london = WeatherLocation('Europe', 'London') - dublin = WeatherLocation('Europe', 'Dublin') - brasilia = WeatherLocation('South America', 'Brasila') - quito = WeatherLocation('South America', 'Quito') - - tokyo.reports.append(Report(80.0)) - newyork.reports.append(Report(75)) - quito.reports.append(Report(85)) - - sess = create_session() - for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: - sess.add(c) - sess.commit() - tokyo.city # reload 'city' attribute on tokyo - sess.expunge_all() - - eq_(db2.execute(weather_locations.select()).fetchall(), [(1, 'Asia', 'Tokyo')]) - eq_(db1.execute(weather_locations.select()).fetchall(), [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')]) - eq_(sess.execute(weather_locations.select(), shard_id='asia').fetchall(), [(1, 'Asia', 'Tokyo')]) - - t = sess.query(WeatherLocation).get(tokyo.id) - eq_(t.city, tokyo.city) - eq_(t.reports[0].temperature, 80.0) - - north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America') - eq_(set([c.city for c in north_american_cities]), 
set(['New York', 'Toronto'])) - - asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_(['Europe', 'Asia'])) - eq_(set([c.city for c in asia_and_europe]), set(['Tokyo', 'London', 'Dublin'])) - - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/sharding/test_shard.py b/test/orm/sharding/test_shard.py new file mode 100644 index 000000000..89e23fb75 --- /dev/null +++ b/test/orm/sharding/test_shard.py @@ -0,0 +1,164 @@ +import datetime, os +from sqlalchemy import * +from sqlalchemy import sql +from sqlalchemy.orm import * +from sqlalchemy.orm.shard import ShardedSession +from sqlalchemy.sql import operators +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ + +# TODO: ShardTest can be turned into a base for further subclasses + +class ShardTest(TestBase): + @classmethod + def setup_class(cls): + global db1, db2, db3, db4, weather_locations, weather_reports + + db1 = create_engine('sqlite:///shard1.db') + db2 = create_engine('sqlite:///shard2.db') + db3 = create_engine('sqlite:///shard3.db') + db4 = create_engine('sqlite:///shard4.db') + + meta = MetaData() + ids = Table('ids', meta, + Column('nextid', Integer, nullable=False)) + + def id_generator(ctx): + # in reality, might want to use a separate transaction for this. 
+ c = db1.connect() + nextid = c.execute(ids.select(for_update=True)).scalar() + c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1})) + return nextid + + weather_locations = Table("weather_locations", meta, + Column('id', Integer, primary_key=True, default=id_generator), + Column('continent', String(30), nullable=False), + Column('city', String(50), nullable=False) + ) + + weather_reports = Table("weather_reports", meta, + Column('id', Integer, primary_key=True), + Column('location_id', Integer, ForeignKey('weather_locations.id')), + Column('temperature', Float), + Column('report_time', DateTime, default=datetime.datetime.now), + ) + + for db in (db1, db2, db3, db4): + meta.create_all(db) + + db1.execute(ids.insert(), nextid=1) + + cls.setup_session() + cls.setup_mappers() + + @classmethod + def teardown_class(cls): + for db in (db1, db2, db3, db4): + db.connect().invalidate() + for i in range(1,5): + os.remove("shard%d.db" % i) + + @classmethod + def setup_session(cls): + global create_session + + shard_lookup = { + 'North America':'north_america', + 'Asia':'asia', + 'Europe':'europe', + 'South America':'south_america' + } + + def shard_chooser(mapper, instance, clause=None): + if isinstance(instance, WeatherLocation): + return shard_lookup[instance.continent] + else: + return shard_chooser(mapper, instance.location) + + def id_chooser(query, ident): + return ['north_america', 'asia', 'europe', 'south_america'] + + def query_chooser(query): + ids = [] + + class FindContinent(sql.ClauseVisitor): + def visit_binary(self, binary): + if binary.left is weather_locations.c.continent: + if binary.operator == operators.eq: + ids.append(shard_lookup[binary.right.value]) + elif binary.operator == operators.in_op: + for bind in binary.right.clauses: + ids.append(shard_lookup[bind.value]) + + FindContinent().traverse(query._criterion) + if len(ids) == 0: + return ['north_america', 'asia', 'europe', 'south_america'] + else: + return ids + + create_session = 
sessionmaker(class_=ShardedSession, autoflush=True, autocommit=False) + + create_session.configure(shards={ + 'north_america':db1, + 'asia':db2, + 'europe':db3, + 'south_america':db4 + }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) + + + @classmethod + def setup_mappers(cls): + global WeatherLocation, Report + + class WeatherLocation(object): + def __init__(self, continent, city): + self.continent = continent + self.city = city + + class Report(object): + def __init__(self, temperature): + self.temperature = temperature + + mapper(WeatherLocation, weather_locations, properties={ + 'reports':relation(Report, backref='location'), + 'city': deferred(weather_locations.c.city), + }) + + mapper(Report, weather_reports) + + def test_roundtrip(self): + tokyo = WeatherLocation('Asia', 'Tokyo') + newyork = WeatherLocation('North America', 'New York') + toronto = WeatherLocation('North America', 'Toronto') + london = WeatherLocation('Europe', 'London') + dublin = WeatherLocation('Europe', 'Dublin') + brasilia = WeatherLocation('South America', 'Brasila') + quito = WeatherLocation('South America', 'Quito') + + tokyo.reports.append(Report(80.0)) + newyork.reports.append(Report(75)) + quito.reports.append(Report(85)) + + sess = create_session() + for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: + sess.add(c) + sess.commit() + tokyo.city # reload 'city' attribute on tokyo + sess.expunge_all() + + eq_(db2.execute(weather_locations.select()).fetchall(), [(1, 'Asia', 'Tokyo')]) + eq_(db1.execute(weather_locations.select()).fetchall(), [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')]) + eq_(sess.execute(weather_locations.select(), shard_id='asia').fetchall(), [(1, 'Asia', 'Tokyo')]) + + t = sess.query(WeatherLocation).get(tokyo.id) + eq_(t.city, tokyo.city) + eq_(t.reports[0].temperature, 80.0) + + north_american_cities = sess.query(WeatherLocation).filter(WeatherLocation.continent == 'North America') + 
eq_(set([c.city for c in north_american_cities]), set(['New York', 'Toronto'])) + + asia_and_europe = sess.query(WeatherLocation).filter(WeatherLocation.continent.in_(['Europe', 'Asia'])) + eq_(set([c.city for c in asia_and_europe]), set(['Tokyo', 'London', 'Dublin'])) + + + diff --git a/test/orm/test_association.py b/test/orm/test_association.py new file mode 100644 index 000000000..ee7fb7af9 --- /dev/null +++ b/test/orm/test_association.py @@ -0,0 +1,147 @@ + +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base +from sqlalchemy.test.testing import eq_ + + +class AssociationTest(_base.MappedTest): + run_setup_classes = 'once' + run_setup_mappers = 'once' + + @classmethod + def define_tables(cls, metadata): + Table('items', metadata, + Column('item_id', Integer, primary_key=True), + Column('name', String(40))) + Table('item_keywords', metadata, + Column('item_id', Integer, ForeignKey('items.item_id')), + Column('keyword_id', Integer, ForeignKey('keywords.keyword_id')), + Column('data', String(40))) + Table('keywords', metadata, + Column('keyword_id', Integer, primary_key=True), + Column('name', String(40))) + + @classmethod + def setup_classes(cls): + class Item(_base.BasicEntity): + def __init__(self, name): + self.name = name + def __repr__(self): + return "Item id=%d name=%s keywordassoc=%r" % ( + self.item_id, self.name, self.keywords) + + class Keyword(_base.BasicEntity): + def __init__(self, name): + self.name = name + def __repr__(self): + return "Keyword id=%d name=%s" % (self.keyword_id, self.name) + + class KeywordAssociation(_base.BasicEntity): + def __init__(self, keyword, data): + self.keyword = keyword + self.data = data + def __repr__(self): + return "KeywordAssociation itemid=%d keyword=%r data=%s" % ( + self.item_id, self.keyword, self.data) 
+ + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + items, item_keywords, keywords = cls.tables.get_all( + 'items', 'item_keywords', 'keywords') + + mapper(Keyword, keywords) + mapper(KeywordAssociation, item_keywords, properties={ + 'keyword':relation(Keyword, lazy=False)}, + primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], + order_by=[item_keywords.c.data]) + + mapper(Item, items, properties={ + 'keywords' : relation(KeywordAssociation, + cascade="all, delete-orphan") + }) + + @testing.resolve_artifact_names + def test_insert(self): + sess = create_session() + item1 = Item('item1') + item2 = Item('item2') + item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) + item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) + item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) + sess.add_all((item1, item2)) + sess.flush() + saved = repr([item1, item2]) + sess.expunge_all() + l = sess.query(Item).all() + loaded = repr(l) + eq_(saved, loaded) + + @testing.resolve_artifact_names + def test_replace(self): + sess = create_session() + item1 = Item('item1') + item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) + item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) + sess.add(item1) + sess.flush() + + red_keyword = item1.keywords[1].keyword + del item1.keywords[1] + item1.keywords.append(KeywordAssociation(red_keyword, 'new_red_assoc')) + sess.flush() + saved = repr([item1]) + sess.expunge_all() + l = sess.query(Item).all() + loaded = repr(l) + eq_(saved, loaded) + + @testing.resolve_artifact_names + def test_modify(self): + sess = create_session() + item1 = Item('item1') + item2 = Item('item2') + item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) + item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) + item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) + sess.add_all((item1, 
item2)) + sess.flush() + + red_keyword = item1.keywords[1].keyword + del item1.keywords[0] + del item1.keywords[0] + purple_keyword = Keyword('purple') + item1.keywords.append(KeywordAssociation(red_keyword, 'new_red_assoc')) + item2.keywords.append(KeywordAssociation(purple_keyword, 'purple_item2_assoc')) + item1.keywords.append(KeywordAssociation(purple_keyword, 'purple_item1_assoc')) + item1.keywords.append(KeywordAssociation(Keyword('yellow'), 'yellow_assoc')) + + sess.flush() + saved = repr([item1, item2]) + sess.expunge_all() + l = sess.query(Item).all() + loaded = repr(l) + eq_(saved, loaded) + + @testing.resolve_artifact_names + def test_delete(self): + sess = create_session() + item1 = Item('item1') + item2 = Item('item2') + item1.keywords.append(KeywordAssociation(Keyword('blue'), 'blue_assoc')) + item1.keywords.append(KeywordAssociation(Keyword('red'), 'red_assoc')) + item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) + sess.add_all((item1, item2)) + sess.flush() + eq_(item_keywords.count().scalar(), 3) + + sess.delete(item1) + sess.delete(item2) + sess.flush() + eq_(item_keywords.count().scalar(), 0) + + diff --git a/test/orm/test_assorted_eager.py b/test/orm/test_assorted_eager.py new file mode 100644 index 000000000..09f007547 --- /dev/null +++ b/test/orm/test_assorted_eager.py @@ -0,0 +1,902 @@ +"""Exercises for eager loading. + +Derived from mailing list-reported problems and trac tickets. 
+ +""" +import datetime + +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base + + +class EagerTest(_base.MappedTest): + run_deletes = None + run_inserts = "once" + run_setup_mappers = "once" + + @classmethod + def define_tables(cls, metadata): + # determine a literal value for "false" based on the dialect + # FIXME: this DefaultClause setup is bogus. + + dialect = testing.db.dialect + bp = sa.Boolean().dialect_impl(dialect).bind_processor(dialect) + + if bp: + false = str(bp(False)) + elif testing.against('maxdb'): + false = text('FALSE') + else: + false = str(False) + cls.other_artifacts['false'] = false + + Table('owners', metadata , + Column('id', Integer, primary_key=True, nullable=False), + Column('data', String(30))) + + Table('categories', metadata, + Column('id', Integer, primary_key=True, nullable=False), + Column('name', String(20))) + + Table('tests', metadata , + Column('id', Integer, primary_key=True, nullable=False ), + Column('owner_id', Integer, ForeignKey('owners.id'), + nullable=False), + Column('category_id', Integer, ForeignKey('categories.id'), + nullable=False)) + + Table('options', metadata , + Column('test_id', Integer, ForeignKey('tests.id'), + primary_key=True, nullable=False), + Column('owner_id', Integer, ForeignKey('owners.id'), + primary_key=True, nullable=False), + Column('someoption', sa.Boolean, server_default=false, + nullable=False)) + + @classmethod + def setup_classes(cls): + class Owner(_base.BasicEntity): + pass + + class Category(_base.BasicEntity): + pass + + class Thing(_base.BasicEntity): + pass + + class Option(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Owner, owners) + + 
mapper(Category, categories) + + mapper(Option, options, properties=dict( + owner=relation(Owner), + test=relation(Thing))) + + mapper(Thing, tests, properties=dict( + owner=relation(Owner, backref='tests'), + category=relation(Category), + owner_option=relation(Option, + primaryjoin=sa.and_(tests.c.id == options.c.test_id, + tests.c.owner_id == options.c.owner_id), + foreign_keys=[options.c.test_id, options.c.owner_id], + uselist=False))) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + session = create_session() + + o = Owner() + c = Category(name='Some Category') + session.add_all(( + Thing(owner=o, category=c), + Thing(owner=o, category=c, owner_option=Option(someoption=True)), + Thing(owner=o, category=c, owner_option=Option()))) + + session.flush() + + @testing.resolve_artifact_names + def test_noorm(self): + """test the control case""" + # I want to display a list of tests owned by owner 1 + # if someoption is false or he hasn't specified it yet (null) + # but not if he set it to true (example someoption is for hiding) + + # desired output for owner 1 + # test_id, cat_name + # 1 'Some Category' + # 3 " + + # not orm style correct query + print "Obtaining correct results without orm" + result = sa.select( + [tests.c.id,categories.c.name], + sa.and_(tests.c.owner_id == 1, + sa.or_(options.c.someoption==None, + options.c.someoption==False)), + order_by=[tests.c.id], + from_obj=[tests.join(categories).outerjoin(options, sa.and_( + tests.c.id == options.c.test_id, + tests.c.owner_id == options.c.owner_id))] + ).execute().fetchall() + eq_(result, [(1, u'Some Category'), (3, u'Some Category')]) + + @testing.resolve_artifact_names + def test_withouteagerload(self): + s = create_session() + l = (s.query(Thing). + select_from(tests.outerjoin(options, + sa.and_(tests.c.id == options.c.test_id, + tests.c.owner_id == + options.c.owner_id))). 
+ filter(sa.and_(tests.c.owner_id==1, + sa.or_(options.c.someoption==None, + options.c.someoption==False)))) + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + eq_(result, [u'1 Some Category', u'3 Some Category']) + + @testing.resolve_artifact_names + def test_witheagerload(self): + """ + Test that an eagerload locates the correct "from" clause with which to + attach to, when presented with a query that already has a complicated + from clause. + + """ + s = create_session() + q=s.query(Thing).options(sa.orm.eagerload('category')) + + l=(q.select_from(tests.outerjoin(options, + sa.and_(tests.c.id == + options.c.test_id, + tests.c.owner_id == + options.c.owner_id))). + filter(sa.and_(tests.c.owner_id == 1, + sa.or_(options.c.someoption==None, + options.c.someoption==False)))) + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + eq_(result, [u'1 Some Category', u'3 Some Category']) + + @testing.resolve_artifact_names + def test_dslish(self): + """test the same as witheagerload except using generative""" + s = create_session() + q = s.query(Thing).options(sa.orm.eagerload('category')) + l = q.filter ( + sa.and_(tests.c.owner_id == 1, + sa.or_(options.c.someoption == None, + options.c.someoption == False)) + ).outerjoin('owner_option') + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + eq_(result, [u'1 Some Category', u'3 Some Category']) + + @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on') + @testing.resolve_artifact_names + def test_without_outerjoin_literal(self): + s = create_session() + q = s.query(Thing).options(sa.orm.eagerload('category')) + l = (q.filter( + (tests.c.owner_id==1) & + ('options.someoption is null or options.someoption=%s' % false)). 
+ join('owner_option')) + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + eq_(result, [u'3 Some Category']) + + @testing.resolve_artifact_names + def test_withoutouterjoin(self): + s = create_session() + q = s.query(Thing).options(sa.orm.eagerload('category')) + l = q.filter( + (tests.c.owner_id==1) & + ((options.c.someoption==None) | (options.c.someoption==False)) + ).join('owner_option') + + result = ["%d %s" % ( t.id,t.category.name ) for t in l] + eq_(result, [u'3 Some Category']) + + +class EagerTest2(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('left', metadata, + Column('id', Integer, ForeignKey('middle.id'), primary_key=True), + Column('data', String(50), primary_key=True)) + + Table('middle', metadata, + Column('id', Integer, primary_key = True), + Column('data', String(50))) + + Table('right', metadata, + Column('id', Integer, ForeignKey('middle.id'), primary_key=True), + Column('data', String(50), primary_key=True)) + + @classmethod + def setup_classes(cls): + class Left(_base.BasicEntity): + def __init__(self, data): + self.data = data + + class Middle(_base.BasicEntity): + def __init__(self, data): + self.data = data + + class Right(_base.BasicEntity): + def __init__(self, data): + self.data = data + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + # set up bi-directional eager loads + mapper(Left, left) + mapper(Right, right) + mapper(Middle, middle, properties=dict( + left=relation(Left, + lazy=False, + backref=backref('middle',lazy=False)), + right=relation(Right, + lazy=False, + backref=backref('middle', lazy=False)))), + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_eager_terminate(self): + """Eager query generation does not include the same mapper's table twice. + + Or, that bi-directional eager loads dont include each other in eager + query generation. 
+ + """ + p = Middle('m1') + p.left.append(Left('l1')) + p.right.append(Right('r1')) + + session = create_session() + session.add(p) + session.flush() + session.expunge_all() + obj = session.query(Left).filter_by(data='l1').one() + + +class EagerTest3(_base.MappedTest): + """Eager loading combined with nested SELECT statements, functions, and aggregates.""" + + @classmethod + def define_tables(cls, metadata): + Table('datas', metadata, + Column('id', Integer, primary_key=True, nullable=False), + Column('a', Integer, nullable=False)) + + Table('foo', metadata, + Column('data_id', Integer, + ForeignKey('datas.id'), + nullable=False, primary_key=True), + Column('bar', Integer)) + + Table('stats', metadata, + Column('id', Integer, primary_key=True, nullable=False ), + Column('data_id', Integer, ForeignKey('datas.id')), + Column('somedata', Integer, nullable=False )) + + @classmethod + def setup_classes(cls): + class Data(_base.BasicEntity): + pass + + class Foo(_base.BasicEntity): + pass + + class Stat(_base.BasicEntity): + pass + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_nesting_with_functions(self): + mapper(Data, datas) + mapper(Foo, foo, properties={ + 'data': relation(Data,backref=backref('foo',uselist=False))}) + + mapper(Stat, stats, properties={ + 'data':relation(Data)}) + + session = create_session() + + data = [Data(a=x) for x in range(5)] + session.add_all(data) + + session.add_all(( + Stat(data=data[0], somedata=1), + Stat(data=data[1], somedata=2), + Stat(data=data[2], somedata=3), + Stat(data=data[3], somedata=4), + Stat(data=data[4], somedata=5), + Stat(data=data[0], somedata=6), + Stat(data=data[1], somedata=7), + Stat(data=data[2], somedata=8), + Stat(data=data[3], somedata=9), + Stat(data=data[4], somedata=10))) + session.flush() + + arb_data = sa.select( + [stats.c.data_id, sa.func.max(stats.c.somedata).label('max')], + stats.c.data_id <= 5, + group_by=[stats.c.data_id]).alias('arb') + + arb_result = 
arb_data.execute().fetchall() + + # order the result list descending based on 'max' + arb_result.sort(key = lambda a: a['max'], reverse=True) + + # extract just the "data_id" from it + arb_result = [row['data_id'] for row in arb_result] + + # now query for Data objects using that above select, adding the + # "order by max desc" separately + q = (session.query(Data). + options(sa.orm.eagerload('foo')). + select_from(datas.join(arb_data, arb_data.c.data_id == datas.c.id)). + order_by(sa.desc(arb_data.c.max)). + limit(10)) + + # extract "data_id" from the list of result objects + verify_result = [d.id for d in q] + + eq_(verify_result, arb_result) + +class EagerTest4(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('departments', metadata, + Column('department_id', Integer, primary_key=True), + Column('name', String(50))) + + Table('employees', metadata, + Column('person_id', Integer, primary_key=True), + Column('name', String(50)), + Column('department_id', Integer, + ForeignKey('departments.department_id'))) + + @classmethod + def setup_classes(cls): + class Department(_base.BasicEntity): + pass + + class Employee(_base.BasicEntity): + pass + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_basic(self): + mapper(Employee, employees) + mapper(Department, departments, properties=dict( + employees=relation(Employee, + lazy=False, + backref='department'))) + + d1 = Department(name='One') + for e in 'Jim', 'Jack', 'John', 'Susan': + d1.employees.append(Employee(name=e)) + + d2 = Department(name='Two') + for e in 'Joe', 'Bob', 'Mary', 'Wally': + d2.employees.append(Employee(name=e)) + + sess = create_session() + sess.add_all((d1, d2)) + sess.flush() + + q = (sess.query(Department). + join('employees'). + filter(Employee.name.startswith('J')). + distinct(). 
+ order_by([sa.desc(Department.name)])) + + eq_(q.count(), 2) + assert q[0] is d2 + + +class EagerTest5(_base.MappedTest): + """Construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance.""" + + @classmethod + def define_tables(cls, metadata): + Table('base', metadata, + Column('uid', String(30), primary_key=True), + Column('x', String(30))) + + Table('derived', metadata, + Column('uid', String(30), ForeignKey('base.uid'), + primary_key=True), + Column('y', String(30))) + + Table('derivedII', metadata, + Column('uid', String(30), ForeignKey('base.uid'), + primary_key=True), + Column('z', String(30))) + + Table('comments', metadata, + Column('id', Integer, primary_key=True), + Column('uid', String(30), ForeignKey('base.uid')), + Column('comment', String(30))) + + @classmethod + def setup_classes(cls): + class Base(_base.BasicEntity): + def __init__(self, uid, x): + self.uid = uid + self.x = x + + class Derived(Base): + def __init__(self, uid, x, y): + self.uid = uid + self.x = x + self.y = y + + class DerivedII(Base): + def __init__(self, uid, x, z): + self.uid = uid + self.x = x + self.z = z + + class Comment(_base.BasicEntity): + def __init__(self, uid, comment): + self.uid = uid + self.comment = comment + + @testing.resolve_artifact_names + def test_basic(self): + commentMapper = mapper(Comment, comments) + + baseMapper = mapper(Base, base, properties=dict( + comments=relation(Comment, lazy=False, + cascade='all, delete-orphan'))) + + mapper(Derived, derived, inherits=baseMapper) + + mapper(DerivedII, derivedII, inherits=baseMapper) + + sess = create_session() + d = Derived('uid1', 'x', 'y') + d.comments = [Comment('uid1', 'comment')] + d2 = DerivedII('uid2', 'xx', 'z') + d2.comments = [Comment('uid2', 'comment')] + sess.add_all((d, d2)) + sess.flush() + sess.expunge_all() + + # this eager load sets up an AliasedClauses for the "comment" + # relationship, then stores it in clauses_by_lead_mapper[mapper for + # 
Derived] + d = sess.query(Derived).get('uid1') + sess.expunge_all() + assert len([c for c in d.comments]) == 1 + + # this eager load sets up an AliasedClauses for the "comment" + # relationship, and should store it in clauses_by_lead_mapper[mapper + # for DerivedII]. the bug was that the previous AliasedClause create + # prevented this population from occurring. + d2 = sess.query(DerivedII).get('uid2') + sess.expunge_all() + + # object is not in the session; therefore the lazy load cant trigger + # here, eager load had to succeed + assert len([c for c in d2.comments]) == 1 + + +class EagerTest6(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('design_types', metadata, + Column('design_type_id', Integer, primary_key=True)) + + Table('design', metadata, + Column('design_id', Integer, primary_key=True), + Column('design_type_id', Integer, + ForeignKey('design_types.design_type_id'))) + + Table('parts', metadata, + Column('part_id', Integer, primary_key=True), + Column('design_id', Integer, ForeignKey('design.design_id')), + Column('design_type_id', Integer, + ForeignKey('design_types.design_type_id'))) + + Table('inherited_part', metadata, + Column('ip_id', Integer, primary_key=True), + Column('part_id', Integer, ForeignKey('parts.part_id')), + Column('design_id', Integer, ForeignKey('design.design_id'))) + + @classmethod + def setup_classes(cls): + class Part(_base.BasicEntity): + pass + + class Design(_base.BasicEntity): + pass + + class DesignType(_base.BasicEntity): + pass + + class InheritedPart(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_one(self): + p_m = mapper(Part, parts) + + mapper(InheritedPart, inherited_part, properties=dict( + part=relation(Part, lazy=False))) + + d_m = mapper(Design, design, properties=dict( + inheritedParts=relation(InheritedPart, + cascade="all, delete-orphan", + backref="design"))) + + mapper(DesignType, design_types) + + d_m.add_property( + "type", relation(DesignType, 
lazy=False, backref="designs")) + + p_m.add_property( + "design", relation( + Design, lazy=False, + backref=backref("parts", cascade="all, delete-orphan"))) + + + d = Design() + sess = create_session() + sess.add(d) + sess.flush() + sess.expunge_all() + x = sess.query(Design).get(1) + x.inheritedParts + + +class EagerTest7(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('companies', metadata, + Column('company_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('company_name', String(40))) + + Table('addresses', metadata, + Column('address_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('company_id', Integer, ForeignKey("companies.company_id")), + Column('address', String(40))) + + Table('phone_numbers', metadata, + Column('phone_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('address_id', Integer, ForeignKey('addresses.address_id')), + Column('type', String(20)), + Column('number', String(10))) + + Table('invoices', metadata, + Column('invoice_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('company_id', Integer, ForeignKey("companies.company_id")), + Column('date', sa.DateTime)) + + Table('items', metadata, + Column('item_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')), + Column('code', String(20)), + Column('qty', Integer)) + + @classmethod + def setup_classes(cls): + class Company(_base.ComparableEntity): + pass + + class Address(_base.ComparableEntity): + pass + + class Phone(_base.ComparableEntity): + pass + + class Item(_base.ComparableEntity): + pass + + class Invoice(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def testone(self): + """ + Tests eager load of a many-to-one attached to a one-to-many. 
this + testcase illustrated the bug, which is that when the single Company is + loaded, no further processing of the rows occurred in order to load + the Company's second Address object. + + """ + mapper(Address, addresses) + + mapper(Company, companies, properties={ + 'addresses' : relation(Address, lazy=False)}) + + mapper(Invoice, invoices, properties={ + 'company': relation(Company, lazy=False)}) + + a1 = Address(address='a1 address') + a2 = Address(address='a2 address') + c1 = Company(company_name='company 1', addresses=[a1, a2]) + i1 = Invoice(date=datetime.datetime.now(), company=c1) + + + session = create_session() + session.add(i1) + session.flush() + + company_id = c1.company_id + invoice_id = i1.invoice_id + + session.expunge_all() + c = session.query(Company).get(company_id) + + session.expunge_all() + i = session.query(Invoice).get(invoice_id) + + eq_(c, i.company) + + @testing.resolve_artifact_names + def testtwo(self): + """The original testcase that includes various complicating factors""" + + mapper(Phone, phone_numbers) + + mapper(Address, addresses, properties={ + 'phones': relation(Phone, lazy=False, backref='address', + order_by=phone_numbers.c.phone_id)}) + + mapper(Company, companies, properties={ + 'addresses': relation(Address, lazy=False, backref='company', + order_by=addresses.c.address_id)}) + + mapper(Item, items) + + mapper(Invoice, invoices, properties={ + 'items': relation(Item, lazy=False, backref='invoice', + order_by=items.c.item_id), + 'company': relation(Company, lazy=False, backref='invoices')}) + + c1 = Company(company_name='company 1', addresses=[ + Address(address='a1 address', + phones=[Phone(type='home', number='1111'), + Phone(type='work', number='22222')]), + Address(address='a2 address', + phones=[Phone(type='home', number='3333'), + Phone(type='work', number='44444')]) + ]) + + session = create_session() + session.add(c1) + session.flush() + + company_id = c1.company_id + + session.expunge_all() + + a = 
session.query(Company).get(company_id) + + # set up an invoice + i1 = Invoice(date=datetime.datetime.now(), company=a) + + item1 = Item(code='aaaa', qty=1, invoice=i1) + item2 = Item(code='bbbb', qty=2, invoice=i1) + item3 = Item(code='cccc', qty=3, invoice=i1) + + session.flush() + invoice_id = i1.invoice_id + + session.expunge_all() + c = session.query(Company).get(company_id) + + session.expunge_all() + i = session.query(Invoice).get(invoice_id) + + eq_(c, i.company) + + +class EagerTest8(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('prj', metadata, + Column('id', Integer, primary_key=True), + Column('created', sa.DateTime ), + Column('title', sa.Unicode(100))) + + Table('task', metadata, + Column('id', Integer, primary_key=True), + Column('status_id', Integer, + ForeignKey('task_status.id'), nullable=False), + Column('title', sa.Unicode(100)), + Column('task_type_id', Integer , + ForeignKey('task_type.id'), nullable=False), + Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False)) + + Table('task_status', metadata, + Column('id', Integer, primary_key=True)) + + Table('task_type', metadata, + Column('id', Integer, primary_key=True)) + + Table('msg', metadata, + Column('id', Integer, primary_key=True), + Column('posted', sa.DateTime, index=True,), + Column('type_id', Integer, ForeignKey('msg_type.id')), + Column('task_id', Integer, ForeignKey('task.id'))) + + Table('msg_type', metadata, + Column('id', Integer, primary_key=True), + Column('name', sa.Unicode(20)), + Column('display_name', sa.Unicode(20))) + + @classmethod + @testing.resolve_artifact_names + def fixtures(cls): + return dict( + prj=(('id',), + (1,)), + + task_status=(('id',), + (1,)), + + task_type=(('id',), + (1,),), + + task=(('title', 'task_type_id', 'status_id', 'prj_id'), + (u'task 1', 1, 1, 1))) + + @classmethod + def setup_classes(cls): + class Task_Type(_base.BasicEntity): + pass + + class Joined(_base.ComparableEntity): + pass + + 
@testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_nested_joins(self): + # this is testing some subtle column resolution stuff, + # concerning corresponding_column() being extremely accurate + # as well as how mapper sets up its column properties + + mapper(Task_Type, task_type) + + tsk_cnt_join = sa.outerjoin(prj, task, task.c.prj_id==prj.c.id) + + j = sa.outerjoin(task, msg, task.c.id==msg.c.task_id) + jj = sa.select([ task.c.id.label('task_id'), + sa.func.count(msg.c.id).label('props_cnt')], + from_obj=[j], + group_by=[task.c.id]).alias('prop_c_s') + jjj = sa.join(task, jj, task.c.id == jj.c.task_id) + + mapper(Joined, jjj, properties=dict( + type=relation(Task_Type, lazy=False))) + + session = create_session() + + eq_(session.query(Joined).limit(10).offset(0).one(), + Joined(id=1, title=u'task 1', props_cnt=0)) + + +class EagerTest9(_base.MappedTest): + """Test the usage of query options to eagerly load specific paths. + + This relies upon the 'path' construct used by PropertyOption to relate + LoaderStrategies to specific paths, as well as the path state maintained + throughout the query setup/mapper instances process. 
+ + """ + @classmethod + def define_tables(cls, metadata): + Table('accounts', metadata, + Column('account_id', Integer, primary_key=True), + Column('name', String(40))) + + Table('transactions', metadata, + Column('transaction_id', Integer, primary_key=True), + Column('name', String(40))) + + Table('entries', metadata, + Column('entry_id', Integer, primary_key=True), + Column('name', String(40)), + Column('account_id', Integer, + ForeignKey('accounts.account_id')), + Column('transaction_id', Integer, + ForeignKey('transactions.transaction_id'))) + + @classmethod + def setup_classes(cls): + class Account(_base.BasicEntity): + pass + + class Transaction(_base.BasicEntity): + pass + + class Entry(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Account, accounts) + + mapper(Transaction, transactions) + + mapper(Entry, entries, properties=dict( + account=relation(Account, + uselist=False, + backref=backref('entries', lazy=True, + order_by=entries.c.entry_id)), + transaction=relation(Transaction, + uselist=False, + backref=backref('entries', lazy=False, + order_by=entries.c.entry_id)))) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_eagerload_on_path(self): + session = create_session() + + tx1 = Transaction(name='tx1') + tx2 = Transaction(name='tx2') + + acc1 = Account(name='acc1') + ent11 = Entry(name='ent11', account=acc1, transaction=tx1) + ent12 = Entry(name='ent12', account=acc1, transaction=tx2) + + acc2 = Account(name='acc2') + ent21 = Entry(name='ent21', account=acc2, transaction=tx1) + ent22 = Entry(name='ent22', account=acc2, transaction=tx2) + + session.add(acc1) + session.flush() + session.expunge_all() + + def go(): + # load just the first Account. 
eager loading will actually load + # all objects saved thus far, but will not eagerly load the + # "accounts" off the immediate "entries"; only the "accounts" off + # the entries->transaction->entries + acc = (session.query(Account). + options(sa.orm.eagerload_all('entries.transaction.entries.account')). + order_by(Account.account_id)).first() + + # no sql occurs + eq_(acc.name, 'acc1') + eq_(acc.entries[0].transaction.entries[0].account.name, 'acc1') + eq_(acc.entries[0].transaction.entries[1].account.name, 'acc2') + + # lazyload triggers but no sql occurs because many-to-one uses + # cached query.get() + for e in acc.entries: + assert e.account is acc + + self.assert_sql_count(testing.db, go, 1) + + + diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py new file mode 100644 index 000000000..3b1b42dad --- /dev/null +++ b/test/orm/test_attributes.py @@ -0,0 +1,1328 @@ +import pickle +import sqlalchemy.orm.attributes as attributes +from sqlalchemy.orm.collections import collection +from sqlalchemy.orm.interfaces import AttributeExtension +from sqlalchemy import exc as sa_exc +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ +from test.orm import _base +import gc + +# global for pickling tests +MyTest = None +MyTest2 = None + + +class AttributesTest(_base.ORMTest): + def setup(self): + global MyTest, MyTest2 + class MyTest(object): pass + class MyTest2(object): pass + + def teardown(self): + global MyTest, MyTest2 + MyTest, MyTest2 = None, None + + def test_basic(self): + class User(object):pass + + attributes.register_class(User) + attributes.register_attribute(User, 'user_id', uselist=False, useobject=False) + attributes.register_attribute(User, 'user_name', uselist=False, useobject=False) + attributes.register_attribute(User, 'email_address', uselist=False, useobject=False) + + u = User() + u.user_id = 7 + u.user_name = 'john' + u.email_address = 'lala@123.com' + + self.assert_(u.user_id == 7 and u.user_name == 'john' and 
u.email_address == 'lala@123.com') + attributes.instance_state(u).commit_all(attributes.instance_dict(u)) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') + + u.user_name = 'heythere' + u.email_address = 'foo@bar.com' + self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com') + + def test_pickleness(self): + attributes.register_class(MyTest) + attributes.register_class(MyTest2) + attributes.register_attribute(MyTest, 'user_id', uselist=False, useobject=False) + attributes.register_attribute(MyTest, 'user_name', uselist=False, useobject=False) + attributes.register_attribute(MyTest, 'email_address', uselist=False, useobject=False) + attributes.register_attribute(MyTest, 'some_mutable_data', mutable_scalars=True, copy_function=list, compare_function=cmp, uselist=False, useobject=False) + attributes.register_attribute(MyTest2, 'a', uselist=False, useobject=False) + attributes.register_attribute(MyTest2, 'b', uselist=False, useobject=False) + # shouldnt be pickling callables at the class level + def somecallable(*args): + return None + attributes.register_attribute(MyTest, "mt2", uselist = True, trackparent=True, callable_=somecallable, useobject=True) + + o = MyTest() + o.mt2.append(MyTest2()) + o.user_id=7 + o.some_mutable_data = [1,2,3] + o.mt2[0].a = 'abcde' + pk_o = pickle.dumps(o) + + o2 = pickle.loads(pk_o) + pk_o2 = pickle.dumps(o2) + + # so... 
pickle is creating a new 'mt2' string after a roundtrip here, + # so we'll brute-force set it to be id-equal to the original string + if False: + o_mt2_str = [ k for k in o.__dict__ if k == 'mt2'][0] + o2_mt2_str = [ k for k in o2.__dict__ if k == 'mt2'][0] + self.assert_(o_mt2_str == o2_mt2_str) + self.assert_(o_mt2_str is not o2_mt2_str) + # change the id of o2.__dict__['mt2'] + former = o2.__dict__['mt2'] + del o2.__dict__['mt2'] + o2.__dict__[o_mt2_str] = former + + self.assert_(pk_o == pk_o2) + + # the above is kind of distrurbing, so let's do it again a little + # differently. the string-id in serialization thing is just an + # artifact of pickling that comes up in the first round-trip. + # a -> b differs in pickle memoization of 'mt2', but b -> c will + # serialize identically. + + o3 = pickle.loads(pk_o2) + pk_o3 = pickle.dumps(o3) + o4 = pickle.loads(pk_o3) + pk_o4 = pickle.dumps(o4) + + self.assert_(pk_o3 == pk_o4) + + # and lastly make sure we still have our data after all that. 
+ # identical serialzation is great, *if* it's complete :) + self.assert_(o4.user_id == 7) + self.assert_(o4.user_name is None) + self.assert_(o4.email_address is None) + self.assert_(o4.some_mutable_data == [1,2,3]) + self.assert_(len(o4.mt2) == 1) + self.assert_(o4.mt2[0].a == 'abcde') + self.assert_(o4.mt2[0].b is None) + + def test_state_gc(self): + """test that InstanceState always has a dict, even after host object gc'ed.""" + + class Foo(object): + pass + + attributes.register_class(Foo) + f = Foo() + state = attributes.instance_state(f) + f.bar = "foo" + assert state.dict == {'bar':'foo', state.manager.STATE_ATTR:state} + del f + gc.collect() + assert state.obj() is None + assert state.dict == {} + + def test_deferred(self): + class Foo(object):pass + + data = {'a':'this is a', 'b':12} + def loader(state, keys): + for k in keys: + state.dict[k] = data[k] + return attributes.ATTR_WAS_SET + + attributes.register_class(Foo) + manager = attributes.manager_of_class(Foo) + manager.deferred_scalar_loader = loader + attributes.register_attribute(Foo, 'a', uselist=False, useobject=False) + attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) + + f = Foo() + attributes.instance_state(f).expire_attributes(None) + eq_(f.a, "this is a") + eq_(f.b, 12) + + f.a = "this is some new a" + attributes.instance_state(f).expire_attributes(None) + eq_(f.a, "this is a") + eq_(f.b, 12) + + attributes.instance_state(f).expire_attributes(None) + f.a = "this is another new a" + eq_(f.a, "this is another new a") + eq_(f.b, 12) + + attributes.instance_state(f).expire_attributes(None) + eq_(f.a, "this is a") + eq_(f.b, 12) + + del f.a + eq_(f.a, None) + eq_(f.b, 12) + + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) + eq_(f.a, None) + eq_(f.b, 12) + + def test_deferred_pickleable(self): + data = {'a':'this is a', 'b':12} + def loader(state, keys): + for k in keys: + state.dict[k] = data[k] + return attributes.ATTR_WAS_SET + + 
attributes.register_class(MyTest) + manager = attributes.manager_of_class(MyTest) + manager.deferred_scalar_loader=loader + attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False) + attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False) + + m = MyTest() + attributes.instance_state(m).expire_attributes(None) + assert 'a' not in m.__dict__ + m2 = pickle.loads(pickle.dumps(m)) + assert 'a' not in m2.__dict__ + eq_(m2.a, "this is a") + eq_(m2.b, 12) + + def test_list(self): + class User(object):pass + class Address(object):pass + + attributes.register_class(User) + attributes.register_class(Address) + attributes.register_attribute(User, 'user_id', uselist=False, useobject=False) + attributes.register_attribute(User, 'user_name', uselist=False, useobject=False) + attributes.register_attribute(User, 'addresses', uselist = True, useobject=True) + attributes.register_attribute(Address, 'address_id', uselist=False, useobject=False) + attributes.register_attribute(Address, 'email_address', uselist=False, useobject=False) + + u = User() + u.user_id = 7 + u.user_name = 'john' + u.addresses = [] + a = Address() + a.address_id = 10 + a.email_address = 'lala@123.com' + u.addresses.append(a) + + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') + u, attributes.instance_state(a).commit_all(attributes.instance_dict(a)) + self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') + + u.user_name = 'heythere' + a = Address() + a.address_id = 11 + a.email_address = 'foo@bar.com' + u.addresses.append(a) + self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') + + def test_scalar_listener(self): + # listeners on ScalarAttributeImpl and MutableScalarAttributeImpl aren't used normally. 
+ # test that they work for the benefit of user extensions + class Foo(object): + pass + + results = [] + class ReceiveEvents(AttributeExtension): + def append(self, state, child, initiator): + assert False + + def remove(self, state, child, initiator): + results.append(("remove", state.obj(), child)) + + def set(self, state, child, oldchild, initiator): + results.append(("set", state.obj(), child, oldchild)) + return child + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents()) + attributes.register_attribute(Foo, 'y', uselist=False, mutable_scalars=True, useobject=False, copy_function=lambda x:x, extension=ReceiveEvents()) + + f = Foo() + f.x = 5 + f.x = 17 + del f.x + f.y = [1,2,3] + f.y = [4,5,6] + del f.y + + eq_(results, [ + ('set', f, 5, None), + ('set', f, 17, 5), + ('remove', f, 17), + ('set', f, [1,2,3], None), + ('set', f, [4,5,6], [1,2,3]), + ('remove', f, [4,5,6]) + ]) + + + def test_lazytrackparent(self): + """test that the "hasparent" flag works properly when lazy loaders and backrefs are used""" + + class Post(object):pass + class Blog(object):pass + attributes.register_class(Post) + attributes.register_class(Blog) + + # set up instrumented attributes with backrefs + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) + + # create objects as if they'd been freshly loaded from the database (without history) + b = Blog() + p1 = Post() + attributes.instance_state(b).set_callable('posts', lambda:[p1]) + attributes.instance_state(p1).set_callable('blog', lambda:b) + p1, attributes.instance_state(b).commit_all(attributes.instance_dict(b)) + + # no orphans (called before the lazy loaders fire off) + assert 
attributes.has_parent(Blog, p1, 'posts', optimistic=True) + assert attributes.has_parent(Post, b, 'blog', optimistic=True) + + # assert connections + assert p1.blog is b + assert p1 in b.posts + + # manual connections + b2 = Blog() + p2 = Post() + b2.posts.append(p2) + assert attributes.has_parent(Blog, p2, 'posts') + assert attributes.has_parent(Post, b2, 'blog') + + def test_inheritance(self): + """tests that attributes are polymorphic""" + class Foo(object):pass + class Bar(Foo):pass + + + attributes.register_class(Foo) + attributes.register_class(Bar) + + def func1(): + print "func1" + return "this is the foo attr" + def func2(): + print "func2" + return "this is the bar attr" + def func3(): + print "func3" + return "this is the shared attr" + attributes.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1, useobject=True) + attributes.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3, useobject=True) + attributes.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2, useobject=True) + + x = Foo() + y = Bar() + assert x.element == 'this is the foo attr' + assert y.element == 'this is the bar attr' + assert x.element2 == 'this is the shared attr' + assert y.element2 == 'this is the shared attr' + + def test_no_double_state(self): + states = set() + class Foo(object): + def __init__(self): + states.add(attributes.instance_state(self)) + class Bar(Foo): + def __init__(self): + states.add(attributes.instance_state(self)) + Foo.__init__(self) + + + attributes.register_class(Foo) + attributes.register_class(Bar) + + b = Bar() + eq_(len(states), 1) + eq_(list(states)[0].obj(), b) + + + def test_inheritance2(self): + """test that the attribute manager can properly traverse the managed attributes of an object, + if the object is of a descendant class with managed attributes in the parent class""" + class Foo(object):pass + class Bar(Foo):pass + + class Element(object): + _state = True + + 
attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'element', uselist=False, useobject=True) + el = Element() + x = Bar() + x.element = el + eq_(attributes.get_history(attributes.instance_state(x), 'element'), ([el], (), ())) + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) + + (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element') + assert added == () + assert unchanged == [el] + + def test_lazyhistory(self): + """tests that history functions work with lazy-loading attributes""" + + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + attributes.register_class(Foo) + attributes.register_class(Bar) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] + def func1(): + return "this is func 1" + def func2(): + return [bar1, bar2, bar3] + + attributes.register_attribute(Foo, 'col1', uselist=False, callable_=lambda o:func1, useobject=True) + attributes.register_attribute(Foo, 'col2', uselist=True, callable_=lambda o:func2, useobject=True) + attributes.register_attribute(Bar, 'id', uselist=False, useobject=True) + + x = Foo() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) + x.col2.append(bar4) + eq_(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], [])) + + def test_parenttrack(self): + class Foo(object):pass + class Bar(object):pass + + attributes.register_class(Foo) + attributes.register_class(Bar) + + attributes.register_attribute(Foo, 'element', uselist=False, trackparent=True, useobject=True) + attributes.register_attribute(Bar, 'element', uselist=False, trackparent=True, useobject=True) + + f1 = Foo() + f2 = Foo() + b1 = Bar() + b2 = Bar() + + f1.element = b1 + b2.element = f2 + + assert attributes.has_parent(Foo, b1, 'element') + assert not attributes.has_parent(Foo, b2, 'element') + assert not attributes.has_parent(Foo, f2, 'element') + assert 
attributes.has_parent(Bar, f2, 'element') + + b2.element = None + assert not attributes.has_parent(Bar, f2, 'element') + + # test that double assignment doesn't accidentally reset the 'parent' flag. + b3 = Bar() + f4 = Foo() + b3.element = f4 + assert attributes.has_parent(Bar, f4, 'element') + b3.element = f4 + assert attributes.has_parent(Bar, f4, 'element') + + def test_mutablescalars(self): + """test detection of changes on mutable scalar items""" + class Foo(object):pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False) + x = Foo() + x.element = ['one', 'two', 'three'] + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) + x.element[1] = 'five' + assert attributes.instance_state(x).check_modified() + + attributes.unregister_class(Foo) + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'element', uselist=False, useobject=False) + x = Foo() + x.element = ['one', 'two', 'three'] + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) + x.element[1] = 'five' + assert not attributes.instance_state(x).check_modified() + + def test_descriptorattributes(self): + """changeset: 1633 broke ability to use ORM to map classes with unusual + descriptor attributes (for example, classes that inherit from ones + implementing zope.interface.Interface). + This is a simple regression test to prevent that defect. 
+ """ + class des(object): + def __get__(self, instance, owner): + raise AttributeError('fake attribute') + + class Foo(object): + A = des() + + attributes.register_class(Foo) + attributes.unregister_class(Foo) + + def test_collectionclasses(self): + + class Foo(object):pass + attributes.register_class(Foo) + + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=set, useobject=True) + assert attributes.manager_of_class(Foo).is_instrumented("collection") + assert isinstance(Foo().collection, set) + + attributes.unregister_attribute(Foo, "collection") + assert not attributes.manager_of_class(Foo).is_instrumented("collection") + + try: + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=dict, useobject=True) + assert False + except sa_exc.ArgumentError, e: + assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class" + + class MyDict(dict): + @collection.appender + def append(self, item): + self[item.foo] = item + @collection.remover + def remove(self, item): + del self[item.foo] + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict, useobject=True) + assert isinstance(Foo().collection, MyDict) + + attributes.unregister_attribute(Foo, "collection") + + class MyColl(object):pass + try: + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True) + assert False + except sa_exc.ArgumentError, e: + assert str(e) == "Type MyColl must elect an appender method to be a collection class" + + class MyColl(object): + @collection.iterator + def __iter__(self): + return iter([]) + @collection.appender + def append(self, item): + pass + @collection.remover + def remove(self, item): + pass + attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True) + try: + Foo().collection + assert True + except sa_exc.ArgumentError, e: + assert False + + +class BackrefTest(_base.ORMTest): + + def 
test_manytomany(self): + class Student(object):pass + class Course(object):pass + + attributes.register_class(Student) + attributes.register_class(Course) + attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True) + attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True) + + s = Student() + c = Course() + s.courses.append(c) + self.assert_(c.students == [s]) + s.courses.remove(c) + self.assert_(c.students == []) + + (s1, s2, s3) = (Student(), Student(), Student()) + + c.students = [s1, s2, s3] + self.assert_(s2.courses == [c]) + self.assert_(s1.courses == [c]) + s1.courses.remove(c) + self.assert_(c.students == [s2,s3]) + + def test_onetomany(self): + class Post(object):pass + class Blog(object):pass + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) + b = Blog() + (p1, p2, p3) = (Post(), Post(), Post()) + b.posts.append(p1) + b.posts.append(p2) + b.posts.append(p3) + self.assert_(b.posts == [p1, p2, p3]) + self.assert_(p2.blog is b) + + p3.blog = None + self.assert_(b.posts == [p1, p2]) + p4 = Post() + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + p4.blog = b + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + # assert no failure removing None + p5 = Post() + p5.blog = None + del p5.blog + + def test_onetoone(self): + class Port(object):pass + class Jack(object):pass + attributes.register_class(Port) + attributes.register_class(Jack) + attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True) + 
attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True) + p = Port() + j = Jack() + p.jack = j + self.assert_(j.port is p) + self.assert_(p.jack is not None) + + j.port = None + self.assert_(p.jack is None) + +class PendingBackrefTest(_base.ORMTest): + def setup(self): + global Post, Blog, called, lazy_load + + class Post(object): + def __init__(self, name): + self.name = name + __hash__ = None + def __eq__(self, other): + return other.name == self.name + + class Blog(object): + def __init__(self, name): + self.name = name + __hash__ = None + def __eq__(self, other): + return other.name == self.name + + called = [0] + + lazy_load = [] + def lazy_posts(instance): + def load(): + called[0] += 1 + return lazy_load + return load + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True) + + def test_lazy_add(self): + global lazy_load + + p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3") + lazy_load = [p1, p2, p3] + + b = Blog("blog 1") + p = Post("post 4") + + p.blog = b + p = Post("post 5") + p.blog = b + # setting blog doesnt call 'posts' callable + assert called[0] == 0 + + # calling backref calls the callable, populates extra posts + assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")] + assert called[0] == 1 + + def test_lazy_history(self): + global lazy_load + + p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3") + lazy_load = [p1, p2, p3] + + b = Blog("blog 1") + p = Post("post 4") + p.blog = b + + p4 = Post("post 5") + p4.blog = b + assert called[0] == 0 + eq_(attributes.instance_state(b).get_history('posts'), ([p, p4], [p1, p2, p3], 
[])) + assert called[0] == 1 + + def test_lazy_remove(self): + global lazy_load + called[0] = 0 + lazy_load = [] + + b = Blog("blog 1") + p = Post("post 1") + p.blog = b + assert called[0] == 0 + + lazy_load = [p] + + p.blog = None + p2 = Post("post 2") + p2.blog = b + assert called[0] == 0 + assert b.posts == [p2] + assert called[0] == 1 + + def test_normal_load(self): + global lazy_load + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + called[0] = 0 + + b = Blog("blog 1") + + # assign without using backref system + p2.__dict__['blog'] = b + + assert b.posts == [Post("post 1"), Post("post 2"), Post("post 3")] + assert called[0] == 1 + p2.blog = None + p4 = Post("post 4") + p4.blog = b + assert b.posts == [Post("post 1"), Post("post 3"), Post("post 4")] + assert called[0] == 1 + + called[0] = 0 + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + + def test_commit_removes_pending(self): + global lazy_load + lazy_load = (p1, ) = [Post("post 1"), ] + called[0] = 0 + + b = Blog("blog 1") + p1.blog = b + attributes.instance_state(b).commit_all(attributes.instance_dict(b)) + attributes.instance_state(p1).commit_all(attributes.instance_dict(p1)) + assert b.posts == [Post("post 1")] + +class HistoryTest(_base.ORMTest): + + def test_get_committed_value(self): + class Foo(_base.BasicEntity): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) + + f = Foo() + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) + + f.someattr = 3 + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) + + f = Foo() + f.someattr = 3 + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + 
eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), 3) + + def test_scalar(self): + class Foo(_base.BasicEntity): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) + + # case 1. new object + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ())) + + f.someattr = "hi" + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], (), ())) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['hi'], ())) + + f.someattr = 'there' + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], (), ['hi'])) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['there'], ())) + + del f.someattr + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ['there'])) + + # case 2. 
object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = 'new' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) + + f.someattr = 'old' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], (), ['new'])) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['old'], ())) + + # setting None on uninitialized is currently a change for a scalar attribute + # no lazyload occurs so this allows overwrite operation to proceed + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ())) + f.someattr = None + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), ())) + + f = Foo() + f.__dict__['someattr'] = 'new' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) + f.someattr = None + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), ['new'])) + + # set same value twice + f = Foo() + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + f.someattr = 'one' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) + f.someattr = 'two' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], (), ())) + + + def test_mutable_scalar(self): + class Foo(_base.BasicEntity): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False, mutable_scalars=True, copy_function=dict) + + # case 1. 
new object + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), (), ())) + + f.someattr = {'foo':'hi'} + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], (), ())) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'hi'}], ())) + eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) + + f.someattr['foo'] = 'there' + eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], (), [{'foo':'hi'}])) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'there'}], ())) + + # case 2. object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = {'foo':'new'} + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'new'}], ())) + + f.someattr = {'foo':'old'} + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], (), [{'foo':'new'}])) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'old'}], ())) + + + def test_use_object(self): + class Foo(_base.BasicEntity): + pass + + class Bar(_base.BasicEntity): + _state = None + def __nonzero__(self): + assert False + + hi = Bar(name='hi') + there = Bar(name='there') + new = Bar(name='new') + old = Bar(name='old') + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=True) + + # case 1. 
new object + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [None], ())) + + f.someattr = hi + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], (), ())) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) + + f.someattr = there + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], (), [hi])) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) + + del f.someattr + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), [there])) + + # case 2. object with direct dictionary settings (similar to a load operation) + f = Foo() + f.__dict__['someattr'] = 'new' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) + + f.someattr = old + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], (), ['new'])) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) + + # setting None on uninitialized is currently not a change for an object attribute + # (this is different than scalar attribute). 
a lazyload has occured so if its + # None, its really None + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [None], ())) + f.someattr = None + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [None], ())) + + f = Foo() + f.__dict__['someattr'] = 'new' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['new'], ())) + f.someattr = None + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], (), ['new'])) + + # set same value twice + f = Foo() + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + f.someattr = 'one' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) + f.someattr = 'two' + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], (), ())) + + def test_object_collections_set(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + def __nonzero__(self): + assert False + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) + + hi = Bar(name='hi') + there = Bar(name='there') + old = Bar(name='old') + new = Bar(name='new') + + # case 1. 
new object + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [], ())) + + f.someattr = [hi] + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) + + f.someattr = [there] + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi])) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) + + f.someattr = [hi] + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [there])) + + f.someattr = [old, new] + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [], [there])) + + # case 2. object with direct settings (similar to a load operation) + f = Foo() + collection = attributes.init_collection(attributes.instance_state(f), 'someattr') + collection.append_without_event(new) + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) + + f.someattr = [old] + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new])) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) + + def test_dict_collections(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + from sqlalchemy.orm.collections import attribute_mapped_collection + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True, typecallable=attribute_mapped_collection('name')) + + hi = Bar(name='hi') + there = Bar(name='there') + old = 
Bar(name='old') + new = Bar(name='new') + + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [], ())) + + f.someattr['hi'] = hi + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) + + f.someattr['there'] = there + eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set(), set())) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set(), set([hi, there]), set())) + + def test_object_collections_mutate(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True) + attributes.register_attribute(Foo, 'id', uselist=False, useobject=False) + + hi = Bar(name='hi') + there = Bar(name='there') + old = Bar(name='old') + new = Bar(name='new') + + # case 1. 
new object + f = Foo(id=1) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [], ())) + + f.someattr.append(hi) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) + + f.someattr.append(there) + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], [])) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there], ())) + + f.someattr.remove(there) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], [there])) + + f.someattr.append(old) + f.someattr.append(new) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there])) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, old, new], ())) + + f.someattr.pop(0) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old, new], [hi])) + + # case 2. 
object with direct settings (similar to a load operation) + f = Foo() + f.__dict__['id'] = 1 + collection = attributes.init_collection(attributes.instance_state(f), 'someattr') + collection.append_without_event(new) + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) + + f.someattr.append(old) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], [])) + + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new, old], ())) + + f = Foo() + collection = attributes.init_collection(attributes.instance_state(f), 'someattr') + collection.append_without_event(new) + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) + + f.id = 1 + f.someattr.remove(new) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [new])) + + # case 3. mixing appends with sets + f = Foo() + f.someattr.append(hi) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) + f.someattr.append(there) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there], [], [])) + f.someattr = [there] + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [])) + + # case 4. 
ensure duplicates show up, order is maintained + f = Foo() + f.someattr.append(hi) + f.someattr.append(there) + f.someattr.append(hi) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], [])) + + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ())) + + f.someattr = [] + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [hi, there, hi])) + + def test_collections_via_backref(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + + f1 = Foo() + b1 = Bar() + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [], ())) + eq_(attributes.get_history(attributes.instance_state(b1), 'foo'), ((), [None], ())) + + #b1.foo = f1 + f1.bars.append(b1) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) + eq_(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], (), ())) + + b2 = Bar() + f1.bars.append(b2) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1, b2], [], [])) + eq_(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], (), ())) + eq_(attributes.get_history(attributes.instance_state(b2), 'foo'), ([f1], (), ())) + + def test_lazy_backref_collections(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + lazy_load = [] + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + 
attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, callable_=lazyload, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] + lazy_load = [bar1, bar2, bar3] + + f = Foo() + bar4 = Bar() + bar4.foo = f + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], [])) + + lazy_load = None + f = Foo() + bar4 = Bar() + bar4.foo = f + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [], [])) + + lazy_load = [bar1, bar2, bar3] + attributes.instance_state(f).expire_attributes(['bars']) + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ((), [bar1, bar2, bar3], ())) + + def test_collections_via_lazyload(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + lazy_load = [] + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lazyload, trackparent=True, useobject=True) + + bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] + lazy_load = [bar1, bar2, bar3] + + f = Foo() + f.bars = [] + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [], [bar1, bar2, bar3])) + + f = Foo() + f.bars.append(bar4) + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []) ) + + f = Foo() + f.bars.remove(bar2) + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2])) + f.bars.append(bar4) + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar3], [bar2])) + + f = Foo() + del f.bars[1] + 
eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2])) + + lazy_load = None + f = Foo() + f.bars.append(bar2) + eq_(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar2], [], [])) + + def test_scalar_via_lazyload(self): + class Foo(_base.BasicEntity): + pass + + lazy_load = None + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, useobject=False) + lazy_load = "hi" + + # with scalar non-object and active_history=False, the lazy callable is only executed on gets, not history + # operations + + f = Foo() + eq_(f.bar, "hi") + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), ["hi"], ())) + + f = Foo() + f.bar = None + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ())) + + f = Foo() + f.bar = "there" + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), (["there"], (), ())) + f.bar = "hi" + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), (["hi"], (), ())) + + f = Foo() + eq_(f.bar, "hi") + del f.bar + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), (), ["hi"])) + assert f.bar is None + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ["hi"])) + + def test_scalar_via_lazyload_with_active(self): + class Foo(_base.BasicEntity): + pass + + lazy_load = None + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, useobject=False, active_history=True) + lazy_load = "hi" + + # active_history=True means the lazy callable is executed on set as well as get, + # causing the old value to appear in the history + + f = Foo() + eq_(f.bar, "hi") + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), ["hi"], ())) + + f = 
Foo() + f.bar = None + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ['hi'])) + + f = Foo() + f.bar = "there" + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), (["there"], (), ['hi'])) + f.bar = "hi" + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), ["hi"], ())) + + f = Foo() + eq_(f.bar, "hi") + del f.bar + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), (), ["hi"])) + assert f.bar is None + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), ["hi"])) + + def test_scalar_object_via_lazyload(self): + class Foo(_base.BasicEntity): + pass + class Bar(_base.BasicEntity): + pass + + lazy_load = None + def lazyload(instance): + def load(): + return lazy_load + return load + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, trackparent=True, useobject=True) + bar1, bar2 = [Bar(id=1), Bar(id=2)] + lazy_load = bar1 + + # with scalar object, the lazy callable is only executed on gets and history + # operations + + f = Foo() + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), [bar1], ())) + + f = Foo() + f.bar = None + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1])) + + f = Foo() + f.bar = bar2 + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([bar2], (), [bar1])) + f.bar = bar1 + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ((), [bar1], ())) + + f = Foo() + eq_(f.bar, bar1) + del f.bar + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1])) + assert f.bar is None + eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1])) + +class ListenerTest(_base.ORMTest): + def test_receive_changes(self): + """test that Listeners can mutate the given value. 
+ + This is a rudimentary test which would be better suited by a full-blown inclusion + into collection.py. + + """ + class Foo(object): + pass + class Bar(object): + pass + + class AlteringListener(AttributeExtension): + def append(self, state, child, initiator): + b2 = Bar() + b2.data = b1.data + " appended" + return b2 + + def set(self, state, value, oldvalue, initiator): + return value + " modified" + + attributes.register_class(Foo) + attributes.register_class(Bar) + attributes.register_attribute(Foo, 'data', uselist=False, useobject=False, extension=AlteringListener()) + attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True, extension=AlteringListener()) + attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True, extension=AlteringListener()) + attributes.register_attribute(Bar, 'data', uselist=False, useobject=False) + + f1 = Foo() + f1.data = "some data" + eq_(f1.data, "some data modified") + b1 = Bar() + b1.data = "some bar" + f1.barlist.append(b1) + assert b1.data == "some bar" + assert f1.barlist[0].data == "some bar appended" + + f1.barset.add(b1) + assert f1.barset.pop().data == "some bar appended" + + diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py new file mode 100644 index 000000000..9b1c20b60 --- /dev/null +++ b/test/orm/test_bind.py @@ -0,0 +1,59 @@ +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy import MetaData, Integer +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +import sqlalchemy as sa +from sqlalchemy.test import testing +from test.orm import _base + + +class BindTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('test_table', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', Integer)) + + @classmethod + def setup_classes(cls): + class Foo(_base.BasicEntity): + 
pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + meta = MetaData() + test_table.tometadata(meta) + + assert meta.tables['test_table'].bind is None + mapper(Foo, meta.tables['test_table']) + + @testing.resolve_artifact_names + def test_session_bind(self): + engine = self.metadata.bind + + for bind in (engine, engine.connect()): + try: + sess = create_session(bind=bind) + assert sess.bind is bind + f = Foo() + sess.add(f) + sess.flush() + assert sess.query(Foo).get(f.id) is f + finally: + if hasattr(bind, 'close'): + bind.close() + + @testing.resolve_artifact_names + def test_session_unbound(self): + sess = create_session() + sess.add(Foo()) + assert_raises_message( + sa.exc.UnboundExecutionError, + ('Could not locate a bind configured on Mapper|Foo|test_table ' + 'or this Session'), + sess.flush) + + diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py new file mode 100644 index 000000000..d0a7b9ded --- /dev/null +++ b/test/orm/test_cascade.py @@ -0,0 +1,1315 @@ + +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref +from sqlalchemy.orm import attributes, exc as orm_exc +from sqlalchemy.test import testing +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures + + +class O2MCascadeTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, cascade="all, delete-orphan", backref="user"), + orders = relation( + mapper(Order, orders), cascade="all, delete-orphan") + )) + mapper(Dingaling,dingalings, properties={ + 'address':relation(Address) + }) + + 
@testing.resolve_artifact_names + def test_list_assignment(self): + sess = create_session() + u = User(name='jack', orders=[ + Order(description='someorder'), + Order(description='someotherorder')]) + sess.add(u) + sess.flush() + sess.expunge_all() + + u = sess.query(User).get(u.id) + eq_(u, User(name='jack', + orders=[Order(description='someorder'), + Order(description='someotherorder')])) + + u.orders=[Order(description="order 3"), Order(description="order 4")] + sess.flush() + sess.expunge_all() + + u = sess.query(User).get(u.id) + eq_(u, User(name='jack', + orders=[Order(description="order 3"), + Order(description="order 4")])) + + eq_(sess.query(Order).all(), + [Order(description="order 3"), Order(description="order 4")]) + + o5 = Order(description="order 5") + sess.add(o5) + try: + sess.flush() + assert False + except orm_exc.FlushError, e: + assert "is an orphan" in str(e) + + + @testing.resolve_artifact_names + def test_delete(self): + sess = create_session() + u = User(name='jack', + orders=[Order(description='someorder'), + Order(description='someotherorder')]) + sess.add(u) + sess.flush() + + sess.delete(u) + sess.flush() + assert users.count().scalar() == 0 + assert orders.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_delete_unloaded_collections(self): + """Unloaded collections are still included in a delete-cascade by default.""" + sess = create_session() + u = User(name='jack', + addresses=[Address(email_address="address1"), + Address(email_address="address2")]) + sess.add(u) + sess.flush() + sess.expunge_all() + assert addresses.count().scalar() == 2 + assert users.count().scalar() == 1 + + u = sess.query(User).get(u.id) + + assert 'addresses' not in u.__dict__ + sess.delete(u) + sess.flush() + assert addresses.count().scalar() == 0 + assert users.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_cascades_onlycollection(self): + """Cascade only reaches instances that are still part of the collection, + not 
those that have been removed""" + + sess = create_session() + u = User(name='jack', + orders=[Order(description='someorder'), + Order(description='someotherorder')]) + sess.add(u) + sess.flush() + + o = u.orders[0] + del u.orders[0] + sess.delete(u) + assert u in sess.deleted + assert o not in sess.deleted + assert o in sess + + u2 = User(name='newuser', orders=[o]) + sess.add(u2) + sess.flush() + sess.expunge_all() + assert users.count().scalar() == 1 + assert orders.count().scalar() == 1 + eq_(sess.query(User).all(), + [User(name='newuser', + orders=[Order(description='someorder')])]) + + @testing.resolve_artifact_names + def test_cascade_nosideeffects(self): + """test that cascade leaves the state of unloaded scalars/collections unchanged.""" + + sess = create_session() + u = User(name='jack') + sess.add(u) + assert 'orders' not in u.__dict__ + + sess.flush() + + assert 'orders' not in u.__dict__ + + a = Address(email_address='foo@bar.com') + sess.add(a) + assert 'user' not in a.__dict__ + a.user = u + sess.flush() + + d = Dingaling(data='d1') + d.address_id = a.id + sess.add(d) + assert 'address' not in d.__dict__ + sess.flush() + assert d.address is a + + @testing.resolve_artifact_names + def test_cascade_delete_plusorphans(self): + sess = create_session() + u = User(name='jack', + orders=[Order(description='someorder'), + Order(description='someotherorder')]) + sess.add(u) + sess.flush() + assert users.count().scalar() == 1 + assert orders.count().scalar() == 2 + + del u.orders[0] + sess.delete(u) + sess.flush() + assert users.count().scalar() == 0 + assert orders.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_collection_orphans(self): + sess = create_session() + u = User(name='jack', + orders=[Order(description='someorder'), + Order(description='someotherorder')]) + sess.add(u) + sess.flush() + + assert users.count().scalar() == 1 + assert orders.count().scalar() == 2 + + u.orders[:] = [] + + sess.flush() + + assert 
users.count().scalar() == 1 + assert orders.count().scalar() == 0 + +class O2OCascadeTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Address, addresses) + mapper(User, users, properties = { + 'address':relation(Address, backref=backref("user", single_parent=True), uselist=False) + }) + + @testing.resolve_artifact_names + def test_single_parent_raise(self): + a1 = Address(email_address='some address') + u1 = User(name='u1', address=a1) + + assert_raises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1) + + a2 = Address(email_address='asd') + u1.address = a2 + assert u1.address is not a1 + assert a1.user is None + + + +class O2MBackrefTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users, properties = dict( + orders = relation( + mapper(Order, orders), cascade="all, delete-orphan", backref="user") + )) + + @testing.resolve_artifact_names + def test_lazyload_bug(self): + sess = create_session() + + u = User(name="jack") + sess.add(u) + sess.expunge(u) + + o1 = Order(description='someorder') + o1.user = u + sess.add(u) + assert u in sess + assert o1 in sess + + +class NoSaveCascadeTest(_fixtures.FixtureTest): + """test that backrefs don't force save-update cascades to occur + when the cascade initiated from the forwards side.""" + + @testing.resolve_artifact_names + def test_unidirectional_cascade_o2m(self): + mapper(Order, orders) + mapper(User, users, properties = dict( + orders = relation( + Order, backref=backref("user", cascade=None)) + )) + + sess = create_session() + + o1 = Order() + sess.add(o1) + u1 = User(orders=[o1]) + assert u1 not in sess + assert o1 in sess + + sess.expunge_all() + + o1 = Order() + u1 = User(orders=[o1]) + sess.add(o1) + assert u1 not in sess + assert o1 in sess + + @testing.resolve_artifact_names + def test_unidirectional_cascade_m2o(self): + 
mapper(Order, orders, properties={ + 'user':relation(User, backref=backref("orders", cascade=None)) + }) + mapper(User, users) + + sess = create_session() + + u1 = User() + sess.add(u1) + o1 = Order() + o1.user = u1 + assert o1 not in sess + assert u1 in sess + + sess.expunge_all() + + u1 = User() + o1 = Order() + o1.user = u1 + sess.add(u1) + assert o1 not in sess + assert u1 in sess + + @testing.resolve_artifact_names + def test_unidirectional_cascade_m2m(self): + mapper(Item, items, properties={ + 'keywords':relation(Keyword, secondary=item_keywords, cascade="none", backref="items") + }) + mapper(Keyword, keywords) + + sess = create_session() + + i1 = Item() + k1 = Keyword() + sess.add(i1) + i1.keywords.append(k1) + assert i1 in sess + assert k1 not in sess + + sess.expunge_all() + + i1 = Item() + k1 = Keyword() + sess.add(i1) + k1.items.append(i1) + assert i1 in sess + assert k1 not in sess + + +class O2MCascadeNoOrphanTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users, properties = dict( + orders = relation( + mapper(Order, orders), cascade="all") + )) + + @testing.resolve_artifact_names + def test_cascade_delete_noorphans(self): + sess = create_session() + u = User(name='jack', + orders=[Order(description='someorder'), + Order(description='someotherorder')]) + sess.add(u) + sess.flush() + assert users.count().scalar() == 1 + assert orders.count().scalar() == 2 + + del u.orders[0] + sess.delete(u) + sess.flush() + assert users.count().scalar() == 0 + assert orders.count().scalar() == 1 + + +class M2OCascadeTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table("extra", metadata, + Column("id", Integer, Sequence("extra_id_seq", optional=True), + primary_key=True), + Column("prefs_id", Integer, ForeignKey("prefs.id"))) + + Table('prefs', metadata, + Column('id', Integer, Sequence('prefs_id_seq', optional=True), + primary_key=True), + 
Column('data', String(40))) + + Table('users', metadata, + Column('id', Integer, Sequence('user_id_seq', optional=True), + primary_key=True), + Column('name', String(40)), + Column('pref_id', Integer, ForeignKey('prefs.id'))) + + @classmethod + def setup_classes(cls): + class User(_fixtures.Base): + pass + class Pref(_fixtures.Base): + pass + class Extra(_fixtures.Base): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Extra, extra) + mapper(Pref, prefs, properties=dict( + extra = relation(Extra, cascade="all, delete") + )) + mapper(User, users, properties = dict( + pref = relation(Pref, lazy=False, cascade="all, delete-orphan", single_parent=True ) + )) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()])) + u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()])) + u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()])) + sess = create_session() + sess.add_all((u1, u2, u3)) + sess.flush() + sess.close() + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_orphan(self): + sess = create_session() + assert prefs.count().scalar() == 3 + assert extra.count().scalar() == 3 + jack = sess.query(User).filter_by(name="jack").one() + jack.pref = None + sess.flush() + assert prefs.count().scalar() == 2 + assert extra.count().scalar() == 2 + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_orphan_on_update(self): + sess = create_session() + jack = sess.query(User).filter_by(name="jack").one() + p = jack.pref + e = jack.pref.extra[0] + sess.expunge_all() + + jack.pref = None + sess.add(jack) + sess.add(p) + sess.add(e) + assert p in sess + assert e in sess + sess.flush() + assert prefs.count().scalar() == 2 + assert extra.count().scalar() == 2 + + @testing.resolve_artifact_names + def test_pending_expunge(self): + sess = create_session() + someuser 
= User(name='someuser') + sess.add(someuser) + sess.flush() + someuser.pref = p1 = Pref(data='somepref') + assert p1 in sess + someuser.pref = Pref(data='someotherpref') + assert p1 not in sess + sess.flush() + eq_(sess.query(Pref).with_parent(someuser).all(), + [Pref(data="someotherpref")]) + + @testing.resolve_artifact_names + def test_double_assignment(self): + """Double assignment will not accidentally reset the 'parent' flag.""" + + sess = create_session() + jack = sess.query(User).filter_by(name="jack").one() + + newpref = Pref(data="newpref") + jack.pref = newpref + jack.pref = newpref + sess.flush() + eq_(sess.query(Pref).all(), + [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) + +class M2OCascadeDeleteTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t2id', Integer, ForeignKey('t2.id'))) + Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t3id', Integer, ForeignKey('t3.id'))) + Table('t3', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + + @classmethod + def setup_classes(cls): + class T1(_fixtures.Base): + pass + class T2(_fixtures.Base): + pass + class T3(_fixtures.Base): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(T1, t1, properties={'t2': relation(T2, cascade="all")}) + mapper(T2, t2, properties={'t3': relation(T3, cascade="all")}) + mapper(T3, t3) + + @testing.resolve_artifact_names + def test_cascade_delete(self): + sess = create_session() + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.add(x) + sess.flush() + + sess.delete(x) + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), []) + eq_(sess.query(T3).all(), []) + + @testing.resolve_artifact_names + def test_cascade_delete_postappend_onelevel(self): + sess = create_session() 
+ x1 = T1(data='t1', ) + x2 = T2(data='t2') + x3 = T3(data='t3') + sess.add_all((x1, x2, x3)) + sess.flush() + + sess.delete(x1) + x1.t2 = x2 + x2.t3 = x3 + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), []) + eq_(sess.query(T3).all(), []) + + @testing.resolve_artifact_names + def test_cascade_delete_postappend_twolevel(self): + sess = create_session() + x1 = T1(data='t1', t2=T2(data='t2')) + x3 = T3(data='t3') + sess.add_all((x1, x3)) + sess.flush() + + sess.delete(x1) + x1.t2.t3 = x3 + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), []) + eq_(sess.query(T3).all(), []) + + @testing.resolve_artifact_names + def test_preserves_orphans_onelevel(self): + sess = create_session() + x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + sess.add(x2) + sess.flush() + x2.t2 = None + + sess.delete(x2) + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), [T2()]) + eq_(sess.query(T3).all(), [T3()]) + + @testing.future + @testing.resolve_artifact_names + def test_preserves_orphans_onelevel_postremove(self): + sess = create_session() + x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + sess.add(x2) + sess.flush() + + sess.delete(x2) + x2.t2 = None + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), [T2()]) + eq_(sess.query(T3).all(), [T3()]) + + @testing.resolve_artifact_names + def test_preserves_orphans_twolevel(self): + sess = create_session() + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.add(x) + sess.flush() + + x.t2.t3 = None + sess.delete(x) + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), []) + eq_(sess.query(T3).all(), [T3()]) + + +class M2OCascadeDeleteOrphanTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t2id', Integer, ForeignKey('t2.id'))) + Table('t2', metadata, + Column('id', 
Integer, primary_key=True), + Column('data', String(50)), + Column('t3id', Integer, ForeignKey('t3.id'))) + Table('t3', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + + @classmethod + def setup_classes(cls): + class T1(_fixtures.Base): + pass + class T2(_fixtures.Base): + pass + class T3(_fixtures.Base): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(T1, t1, properties=dict( + t2=relation(T2, cascade="all, delete-orphan", single_parent=True))) + mapper(T2, t2, properties=dict( + t3=relation(T3, cascade="all, delete-orphan", single_parent=True, backref=backref('t2', uselist=False)))) + mapper(T3, t3) + + @testing.resolve_artifact_names + def test_cascade_delete(self): + sess = create_session() + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.add(x) + sess.flush() + + sess.delete(x) + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), []) + eq_(sess.query(T3).all(), []) + + @testing.resolve_artifact_names + def test_deletes_orphans_onelevel(self): + sess = create_session() + x2 = T1(data='t1b', t2=T2(data='t2b', t3=T3(data='t3b'))) + sess.add(x2) + sess.flush() + x2.t2 = None + + sess.delete(x2) + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), []) + eq_(sess.query(T3).all(), []) + + @testing.resolve_artifact_names + def test_deletes_orphans_twolevel(self): + sess = create_session() + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.add(x) + sess.flush() + + x.t2.t3 = None + sess.delete(x) + sess.flush() + eq_(sess.query(T1).all(), []) + eq_(sess.query(T2).all(), []) + eq_(sess.query(T3).all(), []) + + @testing.resolve_artifact_names + def test_finds_orphans_twolevel(self): + sess = create_session() + x = T1(data='t1a', t2=T2(data='t2a', t3=T3(data='t3a'))) + sess.add(x) + sess.flush() + + x.t2.t3 = None + sess.flush() + eq_(sess.query(T1).all(), [T1()]) + eq_(sess.query(T2).all(), [T2()]) + 
eq_(sess.query(T3).all(), []) + + @testing.resolve_artifact_names + def test_single_parent_raise(self): + + sess = create_session() + + y = T2(data='T2a') + x = T1(data='T1a', t2=y) + assert_raises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y) + + @testing.resolve_artifact_names + def test_single_parent_backref(self): + + sess = create_session() + + y = T3(data='T3a') + x = T2(data='T2a', t3=y) + + # cant attach the T3 to another T2 + assert_raises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y) + + # set via backref tho is OK, unsets from previous parent + # first + z = T2(data='T2b') + y.t2 = z + + assert z.t3 is y + assert x.t3 is None + +class M2MCascadeTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + test_needs_fk=True + ) + Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + test_needs_fk=True + + ) + Table('atob', metadata, + Column('aid', Integer, ForeignKey('a.id')), + Column('bid', Integer, ForeignKey('b.id')), + test_needs_fk=True + + ) + Table('c', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('bid', Integer, ForeignKey('b.id')), + test_needs_fk=True + + ) + + @classmethod + def setup_classes(cls): + class A(_fixtures.Base): + pass + class B(_fixtures.Base): + pass + class C(_fixtures.Base): + pass + + @testing.resolve_artifact_names + def test_delete_orphan(self): + mapper(A, a, properties={ + # if no backref here, delete-orphan failed until [ticket:427] was + # fixed + 'bs': relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) + }) + mapper(B, b) + + sess = create_session() + b1 = B(data='b1') + a1 = A(data='a1', bs=[b1]) + sess.add(a1) + sess.flush() + + a1.bs.remove(b1) + sess.flush() + assert atob.count().scalar() ==0 + assert b.count().scalar() == 0 + assert a.count().scalar() == 1 + + 
@testing.resolve_artifact_names + def test_delete_orphan_cascades(self): + mapper(A, a, properties={ + # if no backref here, delete-orphan failed until [ticket:427] was + # fixed + 'bs':relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) + }) + mapper(B, b, properties={'cs':relation(C, cascade="all, delete-orphan")}) + mapper(C, c) + + sess = create_session() + b1 = B(data='b1', cs=[C(data='c1')]) + a1 = A(data='a1', bs=[b1]) + sess.add(a1) + sess.flush() + + a1.bs.remove(b1) + sess.flush() + assert atob.count().scalar() ==0 + assert b.count().scalar() == 0 + assert a.count().scalar() == 1 + assert c.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_cascade_delete(self): + mapper(A, a, properties={ + 'bs':relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) + }) + mapper(B, b) + + sess = create_session() + a1 = A(data='a1', bs=[B(data='b1')]) + sess.add(a1) + sess.flush() + + sess.delete(a1) + sess.flush() + assert atob.count().scalar() ==0 + assert b.count().scalar() == 0 + assert a.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_single_parent_raise(self): + mapper(A, a, properties={ + 'bs':relation(B, secondary=atob, cascade="all, delete-orphan", single_parent=True) + }) + mapper(B, b) + + sess = create_session() + b1 =B(data='b1') + a1 = A(data='a1', bs=[b1]) + + assert_raises(sa_exc.InvalidRequestError, + A, data='a2', bs=[b1] + ) + + @testing.resolve_artifact_names + def test_single_parent_backref(self): + """test that setting m2m via a uselist=False backref bypasses the single_parent raise""" + + mapper(A, a, properties={ + 'bs':relation(B, + secondary=atob, + cascade="all, delete-orphan", single_parent=True, + backref=backref('a', uselist=False)) + }) + mapper(B, b) + + sess = create_session() + b1 =B(data='b1') + a1 = A(data='a1', bs=[b1]) + + assert_raises( + sa_exc.InvalidRequestError, + A, data='a2', bs=[b1] + ) + + a2 = A(data='a2') + b1.a = a2 + assert b1 not in 
a1.bs + assert b1 in a2.bs + +class UnsavedOrphansTest(_base.MappedTest): + """Pending entities that are orphans""" + + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('user_id', Integer, + Sequence('user_id_seq', optional=True), + primary_key=True), + Column('name', String(40))) + + Table('addresses', metadata, + Column('address_id', Integer, + Sequence('address_id_seq', optional=True), + primary_key=True), + Column('user_id', Integer, ForeignKey('users.user_id')), + Column('email_address', String(40))) + + @classmethod + def setup_classes(cls): + class User(_fixtures.Base): + pass + class Address(_fixtures.Base): + pass + + @testing.resolve_artifact_names + def test_pending_standalone_orphan(self): + """An entity that never had a parent on a delete-orphan cascade can't be saved.""" + + mapper(Address, addresses) + mapper(User, users, properties=dict( + addresses=relation(Address, cascade="all,delete-orphan", backref="user") + )) + s = create_session() + a = Address() + s.add(a) + try: + s.flush() + except orm_exc.FlushError, e: + pass + assert a.address_id is None, "Error: address should not be persistent" + + @testing.resolve_artifact_names + def test_pending_collection_expunge(self): + """Removing a pending item from a collection expunges it from the session.""" + + mapper(Address, addresses) + mapper(User, users, properties=dict( + addresses=relation(Address, cascade="all,delete-orphan", backref="user") + )) + s = create_session() + + u = User() + s.add(u) + s.flush() + a = Address() + + u.addresses.append(a) + assert a in s + + u.addresses.remove(a) + assert a not in s + + s.delete(u) + s.flush() + + assert a.address_id is None, "Error: address should not be persistent" + + @testing.resolve_artifact_names + def test_nonorphans_ok(self): + mapper(Address, addresses) + mapper(User, users, properties=dict( + addresses=relation(Address, cascade="all,delete", backref="user") + )) + s = create_session() + u = User(name='u1', 
addresses=[Address(email_address='ad1')]) + s.add(u) + a1 = u.addresses[0] + u.addresses.remove(a1) + assert a1 in s + s.flush() + s.expunge_all() + eq_(s.query(Address).all(), [Address(email_address='ad1')]) + + +class UnsavedOrphansTest2(_base.MappedTest): + """same test as UnsavedOrphans only three levels deep""" + + @classmethod + def define_tables(cls, meta): + Table('orders', meta, + Column('id', Integer, Sequence('order_id_seq'), + primary_key=True), + Column('name', String(50))) + + Table('items', meta, + Column('id', Integer, Sequence('item_id_seq'), + primary_key=True), + Column('order_id', Integer, ForeignKey('orders.id'), + nullable=False), + Column('name', String(50))) + + Table('attributes', meta, + Column('id', Integer, Sequence('attribute_id_seq'), + primary_key=True), + Column('item_id', Integer, ForeignKey('items.id'), + nullable=False), + Column('name', String(50))) + + @testing.resolve_artifact_names + def test_pending_expunge(self): + class Order(_fixtures.Base): + pass + class Item(_fixtures.Base): + pass + class Attribute(_fixtures.Base): + pass + + mapper(Attribute, attributes) + mapper(Item, items, properties=dict( + attributes=relation(Attribute, cascade="all,delete-orphan", backref="item") + )) + mapper(Order, orders, properties=dict( + items=relation(Item, cascade="all,delete-orphan", backref="order") + )) + + s = create_session() + order = Order(name="order1") + s.add(order) + + attr = Attribute(name="attr1") + item = Item(name="item1", attributes=[attr]) + + order.items.append(item) + order.items.remove(item) + + assert item not in s + assert attr not in s + + s.flush() + assert orders.count().scalar() == 1 + assert items.count().scalar() == 0 + assert attributes.count().scalar() == 0 + +class UnsavedOrphansTest3(_base.MappedTest): + """test not expunging double parents""" + + @classmethod + def define_tables(cls, meta): + Table('sales_reps', meta, + Column('sales_rep_id', Integer, + Sequence('sales_rep_id_seq'), + primary_key=True), + 
Column('name', String(50))) + Table('accounts', meta, + Column('account_id', Integer, + Sequence('account_id_seq'), + primary_key=True), + Column('balance', Integer)) + Table('customers', meta, + Column('customer_id', Integer, + Sequence('customer_id_seq'), + primary_key=True), + Column('name', String(50)), + Column('sales_rep_id', Integer, + ForeignKey('sales_reps.sales_rep_id')), + Column('account_id', Integer, + ForeignKey('accounts.account_id'))) + + @testing.resolve_artifact_names + def test_double_parent_expunge_o2m(self): + """test the delete-orphan uow event for multiple delete-orphan parent relations.""" + + class Customer(_fixtures.Base): + pass + class Account(_fixtures.Base): + pass + class SalesRep(_fixtures.Base): + pass + + mapper(Customer, customers) + mapper(Account, accounts, properties=dict( + customers=relation(Customer, + cascade="all,delete-orphan", + backref="account"))) + mapper(SalesRep, sales_reps, properties=dict( + customers=relation(Customer, + cascade="all,delete-orphan", + backref="sales_rep"))) + s = create_session() + + a = Account(balance=0) + sr = SalesRep(name="John") + s.add_all((a, sr)) + s.flush() + + c = Customer(name="Jane") + + a.customers.append(c) + sr.customers.append(c) + assert c in s + + a.customers.remove(c) + assert c in s, "Should not expunge customer yet, still has one parent" + + sr.customers.remove(c) + assert c not in s, "Should expunge customer when both parents are gone" + + @testing.resolve_artifact_names + def test_double_parent_expunge_o2o(self): + """test the delete-orphan uow event for multiple delete-orphan parent relations.""" + + class Customer(_fixtures.Base): + pass + class Account(_fixtures.Base): + pass + class SalesRep(_fixtures.Base): + pass + + mapper(Customer, customers) + mapper(Account, accounts, properties=dict( + customer=relation(Customer, + cascade="all,delete-orphan", + backref="account", uselist=False))) + mapper(SalesRep, sales_reps, properties=dict( + customer=relation(Customer, + 
cascade="all,delete-orphan", + backref="sales_rep", uselist=False))) + s = create_session() + + a = Account(balance=0) + sr = SalesRep(name="John") + s.add_all((a, sr)) + s.flush() + + c = Customer(name="Jane") + + a.customer = c + sr.customer = c + assert c in s + + a.customer = None + assert c in s, "Should not expunge customer yet, still has one parent" + + sr.customer = None + assert c not in s, "Should expunge customer when both parents are gone" + + + +class DoubleParentOrphanTest(_base.MappedTest): + """test orphan detection for an entity with two parent relations""" + + @classmethod + def define_tables(cls, metadata): + Table('addresses', metadata, + Column('address_id', Integer, primary_key=True), + Column('street', String(30)), + ) + + Table('homes', metadata, + Column('home_id', Integer, primary_key=True, key="id"), + Column('description', String(30)), + Column('address_id', Integer, ForeignKey('addresses.address_id'), + nullable=False), + ) + + Table('businesses', metadata, + Column('business_id', Integer, primary_key=True, key="id"), + Column('description', String(30), key="description"), + Column('address_id', Integer, ForeignKey('addresses.address_id'), + nullable=False), + ) + + @testing.resolve_artifact_names + def test_non_orphan(self): + """test that an entity can have two parent delete-orphan cascades, and persists normally.""" + + class Address(_fixtures.Base): + pass + class Home(_fixtures.Base): + pass + class Business(_fixtures.Base): + pass + + mapper(Address, addresses) + mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) + mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) + + session = create_session() + h1 = Home(description='home1', address=Address(street='address1')) + b1 = Business(description='business1', address=Address(street='address2')) + session.add_all((h1,b1)) + session.flush() + 
session.expunge_all() + + eq_(session.query(Home).get(h1.id), Home(description='home1', address=Address(street='address1'))) + eq_(session.query(Business).get(b1.id), Business(description='business1', address=Address(street='address2'))) + + @testing.resolve_artifact_names + def test_orphan(self): + """test that an entity can have two parent delete-orphan cascades, and is detected as an orphan + when saved without a parent.""" + + class Address(_fixtures.Base): + pass + class Home(_fixtures.Base): + pass + class Business(_fixtures.Base): + pass + + mapper(Address, addresses) + mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) + mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan", single_parent=True)}) + + session = create_session() + a1 = Address() + session.add(a1) + try: + session.flush() + assert False + except orm_exc.FlushError, e: + assert True + +class CollectionAssignmentOrphanTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('table_a', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30))) + Table('table_b', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30)), + Column('a_id', Integer, ForeignKey('table_a.id'))) + + @testing.resolve_artifact_names + def test_basic(self): + class A(_fixtures.Base): + pass + class B(_fixtures.Base): + pass + + mapper(A, table_a, properties={ + 'bs':relation(B, cascade="all, delete-orphan") + }) + mapper(B, table_b) + + a1 = A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]) + + sess = create_session() + sess.add(a1) + sess.flush() + + sess.expunge_all() + + eq_(sess.query(A).get(a1.id), + A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) + + a1 = sess.query(A).get(a1.id) + assert not class_mapper(B)._is_orphan( + attributes.instance_state(a1.bs[0])) + a1.bs[0].foo='b2modified' + a1.bs[1].foo='b3modified' + 
sess.flush() + + sess.expunge_all() + eq_(sess.query(A).get(a1.id), + A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) + + +class PartialFlushTest(_base.MappedTest): + """test cascade behavior as it relates to object lists passed to flush(). + + """ + @classmethod + def define_tables(cls, metadata): + Table("base", metadata, + Column("id", Integer, primary_key=True), + Column("descr", String(50)) + ) + + Table("noninh_child", metadata, + Column('id', Integer, primary_key=True), + Column('base_id', Integer, ForeignKey('base.id')) + ) + + Table("parent", metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True) + ) + Table("inh_child", metadata, + Column("id", Integer, ForeignKey("base.id"), primary_key=True), + Column("parent_id", Integer, ForeignKey("parent.id")) + ) + + @testing.uses_deprecated() + @testing.resolve_artifact_names + def test_o2m_m2o(self): + class Base(_base.ComparableEntity): + pass + class Child(_base.ComparableEntity): + pass + + mapper(Base, base, properties={ + 'children':relation(Child, backref='parent') + }) + mapper(Child, noninh_child) + + sess = create_session() + + c1, c2 = Child(), Child() + b1 = Base(descr='b1', children=[c1, c2]) + sess.add(b1) + + assert c1 in sess.new + assert c2 in sess.new + sess.flush([b1]) + + # c1, c2 get cascaded into the session on o2m. + # not sure if this is how I like this + # to work but that's how it works for now. + assert c1 in sess and c1 not in sess.new + assert c2 in sess and c2 not in sess.new + assert b1 in sess and b1 not in sess.new + + sess = create_session() + c1, c2 = Child(), Child() + b1 = Base(descr='b1', children=[c1, c2]) + sess.add(b1) + sess.flush([c1]) + # m2o, otoh, doesn't cascade up the other way. 
+ assert c1 in sess and c1 not in sess.new + assert c2 in sess and c2 in sess.new + assert b1 in sess and b1 in sess.new + + sess = create_session() + c1, c2 = Child(), Child() + b1 = Base(descr='b1', children=[c1, c2]) + sess.add(b1) + sess.flush([c1, c2]) + # m2o, otoh, doesn't cascade up the other way. + assert c1 in sess and c1 not in sess.new + assert c2 in sess and c2 not in sess.new + assert b1 in sess and b1 in sess.new + + @testing.uses_deprecated() + @testing.resolve_artifact_names + def test_circular_sort(self): + """test ticket 1306""" + + class Base(_base.ComparableEntity): + pass + class Parent(Base): + pass + class Child(Base): + pass + + mapper(Base,base) + + mapper(Child, inh_child, + inherits=Base, + properties={'parent': relation( + Parent, + backref='children', + primaryjoin=inh_child.c.parent_id == parent.c.id + )} + ) + + + mapper(Parent,parent, inherits=Base) + + sess = create_session() + p1 = Parent() + + c1, c2, c3 = Child(), Child(), Child() + p1.children = [c1, c2, c3] + sess.add(p1) + + sess.flush([c1]) + assert p1 in sess.new + assert c1 not in sess.new + assert c2 in sess.new + diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py new file mode 100644 index 000000000..12ff25c46 --- /dev/null +++ b/test/orm/test_collection.py @@ -0,0 +1,1839 @@ +from sqlalchemy.test.testing import eq_ +import sys +from operator import and_ + +import sqlalchemy.orm.collections as collections +from sqlalchemy.orm.collections import collection + +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy import util, exc as sa_exc +from sqlalchemy.orm import create_session, mapper, relation, attributes +from test.orm import _base +from sqlalchemy.test.testing import eq_ + +class Canary(sa.orm.interfaces.AttributeExtension): + def __init__(self): + self.data = set() + self.added = set() 
+ self.removed = set() + def append(self, obj, value, initiator): + assert value not in self.added + self.data.add(value) + self.added.add(value) + return value + def remove(self, obj, value, initiator): + assert value not in self.removed + self.data.remove(value) + self.removed.add(value) + def set(self, obj, value, oldvalue, initiator): + if isinstance(value, str): + value = CollectionsTest.entity_maker() + + if oldvalue is not None: + self.remove(obj, oldvalue, None) + self.append(obj, value, None) + return value + +class CollectionsTest(_base.ORMTest): + class Entity(object): + def __init__(self, a=None, b=None, c=None): + self.a = a + self.b = b + self.c = c + def __repr__(self): + return str((id(self), self.a, self.b, self.c)) + + @classmethod + def setup_class(cls): + attributes.register_class(cls.Entity) + + @classmethod + def teardown_class(cls): + attributes.unregister_class(cls.Entity) + super(CollectionsTest, cls).teardown_class() + + _entity_id = 1 + + @classmethod + def entity_maker(cls): + cls._entity_id += 1 + return cls.Entity(cls._entity_id) + + @classmethod + def dictable_entity(cls, a=None, b=None, c=None): + id = cls._entity_id = (cls._entity_id + 1) + return cls.Entity(a or str(id), b or 'value %s' % id, c) + + def _test_adapter(self, typecallable, creator=None, to_set=None): + if creator is None: + creator = self.entity_maker + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + if to_set is None: + to_set = lambda col: set(col) + + def assert_eq(): + self.assert_(to_set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + assert_ne = lambda: self.assert_(to_set(direct) != canary.data) + + e1, e2 = creator(), creator() + + adapter.append_with_event(e1) + assert_eq() + + 
adapter.append_without_event(e2) + assert_ne() + canary.data.add(e2) + assert_eq() + + adapter.remove_without_event(e2) + assert_ne() + canary.data.remove(e2) + assert_eq() + + adapter.remove_with_event(e1) + assert_eq() + + def _test_list(self, typecallable, creator=None): + if creator is None: + creator = self.entity_maker + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = list() + + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(direct == control) + + # assume append() is available for list tests + e = creator() + direct.append(e) + control.append(e) + assert_eq() + + if hasattr(direct, 'pop'): + direct.pop() + control.pop() + assert_eq() + + if hasattr(direct, '__setitem__'): + e = creator() + direct.append(e) + control.append(e) + + e = creator() + direct[0] = e + control[0] = e + assert_eq() + + if util.reduce(and_, [hasattr(direct, a) for a in + ('__delitem__', 'insert', '__len__')], True): + values = [creator(), creator(), creator(), creator()] + direct[slice(0,1)] = values + control[slice(0,1)] = values + assert_eq() + + values = [creator(), creator()] + direct[slice(0,-1,2)] = values + control[slice(0,-1,2)] = values + assert_eq() + + values = [creator()] + direct[slice(0,-1)] = values + control[slice(0,-1)] = values + assert_eq() + + if hasattr(direct, '__delitem__'): + e = creator() + direct.append(e) + control.append(e) + del direct[-1] + del control[-1] + assert_eq() + + if hasattr(direct, '__getslice__'): + for e in [creator(), creator(), creator(), creator()]: + direct.append(e) + control.append(e) + + del direct[:-3] + del control[:-3] + assert_eq() + + del direct[0:1] + del control[0:1] + assert_eq() + + del direct[::2] + 
del control[::2] + assert_eq() + + if hasattr(direct, 'remove'): + e = creator() + direct.append(e) + control.append(e) + + direct.remove(e) + control.remove(e) + assert_eq() + + if hasattr(direct, '__setslice__'): + values = [creator(), creator()] + direct[0:1] = values + control[0:1] = values + assert_eq() + + values = [creator()] + direct[0:] = values + control[0:] = values + assert_eq() + + values = [creator()] + direct[:1] = values + control[:1] = values + assert_eq() + + values = [creator()] + direct[-1::2] = values + control[-1::2] = values + assert_eq() + + values = [creator()] * len(direct[1::2]) + direct[1::2] = values + control[1::2] = values + assert_eq() + + if hasattr(direct, '__delslice__'): + for i in range(1, 4): + e = creator() + direct.append(e) + control.append(e) + + del direct[-1:] + del control[-1:] + assert_eq() + + del direct[1:2] + del control[1:2] + assert_eq() + + del direct[:] + del control[:] + assert_eq() + + if hasattr(direct, 'extend'): + values = [creator(), creator(), creator()] + + direct.extend(values) + control.extend(values) + assert_eq() + + if hasattr(direct, '__iadd__'): + values = [creator(), creator(), creator()] + + direct += values + control += values + assert_eq() + + direct += [] + control += [] + assert_eq() + + values = [creator(), creator()] + obj.attr += values + control += values + assert_eq() + + if hasattr(direct, '__imul__'): + direct *= 2 + control *= 2 + assert_eq() + + obj.attr *= 2 + control *= 2 + assert_eq() + + def _test_list_bulk(self, typecallable, creator=None): + if creator is None: + creator = self.entity_maker + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + direct = obj.attr + + e1 = creator() + obj.attr.append(e1) + + like_me = typecallable() + e2 = creator() + like_me.append(e2) + + self.assert_(obj.attr is direct) + 
obj.attr = like_me + self.assert_(obj.attr is not direct) + self.assert_(obj.attr is not like_me) + self.assert_(set(obj.attr) == set([e2])) + self.assert_(e1 in canary.removed) + self.assert_(e2 in canary.added) + + e3 = creator() + real_list = [e3] + obj.attr = real_list + self.assert_(obj.attr is not real_list) + self.assert_(set(obj.attr) == set([e3])) + self.assert_(e2 in canary.removed) + self.assert_(e3 in canary.added) + + e4 = creator() + try: + obj.attr = set([e4]) + self.assert_(False) + except TypeError: + self.assert_(e4 not in canary.data) + self.assert_(e3 in canary.data) + + e5 = creator() + e6 = creator() + e7 = creator() + obj.attr = [e5, e6, e7] + self.assert_(e5 in canary.added) + self.assert_(e6 in canary.added) + self.assert_(e7 in canary.added) + + obj.attr = [e6, e7] + self.assert_(e5 in canary.removed) + self.assert_(e6 in canary.added) + self.assert_(e7 in canary.added) + self.assert_(e6 not in canary.removed) + self.assert_(e7 not in canary.removed) + + def test_list(self): + self._test_adapter(list) + self._test_list(list) + self._test_list_bulk(list) + + def test_list_subclass(self): + class MyList(list): + pass + self._test_adapter(MyList) + self._test_list(MyList) + self._test_list_bulk(MyList) + self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList)) + + def test_list_duck(self): + class ListLike(object): + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + return self.data.pop(index) + def extend(self): + assert False + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListLike(%s)' % repr(self.data) + + self._test_adapter(ListLike) + self._test_list(ListLike) + self._test_list_bulk(ListLike) + self.assert_(getattr(ListLike, 
'_sa_instrumented') == id(ListLike)) + + def test_list_emulates(self): + class ListIsh(object): + __emulates__ = list + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + return self.data.pop(index) + def extend(self): + assert False + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListIsh(%s)' % repr(self.data) + + self._test_adapter(ListIsh) + self._test_list(ListIsh) + self._test_list_bulk(ListIsh) + self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh)) + + def _test_set(self, typecallable, creator=None): + if creator is None: + creator = self.entity_maker + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = set() + + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(direct == control) + + def addall(*values): + for item in values: + direct.add(item) + control.add(item) + assert_eq() + def zap(): + for item in list(direct): + direct.remove(item) + control.clear() + + addall(creator()) + + e = creator() + addall(e) + addall(e) + + if hasattr(direct, 'pop'): + direct.pop() + control.pop() + assert_eq() + + if hasattr(direct, 'remove'): + e = creator() + addall(e) + + direct.remove(e) + control.remove(e) + assert_eq() + + e = creator() + try: + direct.remove(e) + except KeyError: + assert_eq() + self.assert_(e not in canary.removed) + else: + self.assert_(False) + + if hasattr(direct, 'discard'): + e = creator() + addall(e) + + 
direct.discard(e) + control.discard(e) + assert_eq() + + e = creator() + direct.discard(e) + self.assert_(e not in canary.removed) + assert_eq() + + if hasattr(direct, 'update'): + zap() + e = creator() + addall(e) + + values = set([e, creator(), creator()]) + + direct.update(values) + control.update(values) + assert_eq() + + if hasattr(direct, '__ior__'): + zap() + e = creator() + addall(e) + + values = set([e, creator(), creator()]) + + direct |= values + control |= values + assert_eq() + + # cover self-assignment short-circuit + values = set([e, creator(), creator()]) + obj.attr |= values + control |= values + assert_eq() + + values = frozenset([e, creator()]) + obj.attr |= values + control |= values + assert_eq() + + try: + direct |= [e, creator()] + assert False + except TypeError: + assert True + + if hasattr(direct, 'clear'): + addall(creator(), creator()) + direct.clear() + control.clear() + assert_eq() + + if hasattr(direct, 'difference_update'): + zap() + e = creator() + addall(creator(), creator()) + values = set([creator()]) + + direct.difference_update(values) + control.difference_update(values) + assert_eq() + values.update(set([e, creator()])) + direct.difference_update(values) + control.difference_update(values) + assert_eq() + + if hasattr(direct, '__isub__'): + zap() + e = creator() + addall(creator(), creator()) + values = set([creator()]) + + direct -= values + control -= values + assert_eq() + values.update(set([e, creator()])) + direct -= values + control -= values + assert_eq() + + values = set([creator()]) + obj.attr -= values + control -= values + assert_eq() + + values = frozenset([creator()]) + obj.attr -= values + control -= values + assert_eq() + + try: + direct -= [e, creator()] + assert False + except TypeError: + assert True + + if hasattr(direct, 'intersection_update'): + zap() + e = creator() + addall(e, creator(), creator()) + values = set(control) + + direct.intersection_update(values) + control.intersection_update(values) + 
assert_eq() + + values.update(set([e, creator()])) + direct.intersection_update(values) + control.intersection_update(values) + assert_eq() + + if hasattr(direct, '__iand__'): + zap() + e = creator() + addall(e, creator(), creator()) + values = set(control) + + direct &= values + control &= values + assert_eq() + + values.update(set([e, creator()])) + direct &= values + control &= values + assert_eq() + + values.update(set([creator()])) + obj.attr &= values + control &= values + assert_eq() + + try: + direct &= [e, creator()] + assert False + except TypeError: + assert True + + if hasattr(direct, 'symmetric_difference_update'): + zap() + e = creator() + addall(e, creator(), creator()) + + values = set([e, creator()]) + direct.symmetric_difference_update(values) + control.symmetric_difference_update(values) + assert_eq() + + e = creator() + addall(e) + values = set([e]) + direct.symmetric_difference_update(values) + control.symmetric_difference_update(values) + assert_eq() + + values = set() + direct.symmetric_difference_update(values) + control.symmetric_difference_update(values) + assert_eq() + + if hasattr(direct, '__ixor__'): + zap() + e = creator() + addall(e, creator(), creator()) + + values = set([e, creator()]) + direct ^= values + control ^= values + assert_eq() + + e = creator() + addall(e) + values = set([e]) + direct ^= values + control ^= values + assert_eq() + + values = set() + direct ^= values + control ^= values + assert_eq() + + values = set([creator()]) + obj.attr ^= values + control ^= values + assert_eq() + + try: + direct ^= [e, creator()] + assert False + except TypeError: + assert True + + def _test_set_bulk(self, typecallable, creator=None): + if creator is None: + creator = self.entity_maker + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + direct = obj.attr + + e1 = 
creator() + obj.attr.add(e1) + + like_me = typecallable() + e2 = creator() + like_me.add(e2) + + self.assert_(obj.attr is direct) + obj.attr = like_me + self.assert_(obj.attr is not direct) + self.assert_(obj.attr is not like_me) + self.assert_(obj.attr == set([e2])) + self.assert_(e1 in canary.removed) + self.assert_(e2 in canary.added) + + e3 = creator() + real_set = set([e3]) + obj.attr = real_set + self.assert_(obj.attr is not real_set) + self.assert_(obj.attr == set([e3])) + self.assert_(e2 in canary.removed) + self.assert_(e3 in canary.added) + + e4 = creator() + try: + obj.attr = [e4] + self.assert_(False) + except TypeError: + self.assert_(e4 not in canary.data) + self.assert_(e3 in canary.data) + + def test_set(self): + self._test_adapter(set) + self._test_set(set) + self._test_set_bulk(set) + + def test_set_subclass(self): + class MySet(set): + pass + self._test_adapter(MySet) + self._test_set(MySet) + self._test_set_bulk(MySet) + self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet)) + + def test_set_duck(self): + class SetLike(object): + def __init__(self): + self.data = set() + def add(self, item): + self.data.add(item) + def remove(self, item): + self.data.remove(item) + def discard(self, item): + self.data.discard(item) + def pop(self): + return self.data.pop() + def update(self, other): + self.data.update(other) + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + + self._test_adapter(SetLike) + self._test_set(SetLike) + self._test_set_bulk(SetLike) + self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike)) + + def test_set_emulates(self): + class SetIsh(object): + __emulates__ = set + def __init__(self): + self.data = set() + def add(self, item): + self.data.add(item) + def remove(self, item): + self.data.remove(item) + def discard(self, item): + self.data.discard(item) + def pop(self): + return self.data.pop() + def update(self, other): + 
self.data.update(other) + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + + self._test_adapter(SetIsh) + self._test_set(SetIsh) + self._test_set_bulk(SetIsh) + self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh)) + + def _test_dict(self, typecallable, creator=None): + if creator is None: + creator = self.dictable_entity + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = dict() + + def assert_eq(): + self.assert_(set(direct.values()) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(direct == control) + + def addall(*values): + for item in values: + direct.set(item) + control[item.a] = item + assert_eq() + def zap(): + for item in list(adapter): + direct.remove(item) + control.clear() + + # assume an 'set' method is available for tests + addall(creator()) + + if hasattr(direct, '__setitem__'): + e = creator() + direct[e.a] = e + control[e.a] = e + assert_eq() + + e = creator(e.a, e.b) + direct[e.a] = e + control[e.a] = e + assert_eq() + + if hasattr(direct, '__delitem__'): + e = creator() + addall(e) + + del direct[e.a] + del control[e.a] + assert_eq() + + e = creator() + try: + del direct[e.a] + except KeyError: + self.assert_(e not in canary.removed) + + if hasattr(direct, 'clear'): + addall(creator(), creator(), creator()) + + direct.clear() + control.clear() + assert_eq() + + direct.clear() + control.clear() + assert_eq() + + if hasattr(direct, 'pop'): + e = creator() + addall(e) + + direct.pop(e.a) + control.pop(e.a) + assert_eq() + + e = creator() + try: + direct.pop(e.a) + except KeyError: + self.assert_(e not in canary.removed) + + if hasattr(direct, 'popitem'): + zap() + e = creator() + 
addall(e) + + direct.popitem() + control.popitem() + assert_eq() + + if hasattr(direct, 'setdefault'): + e = creator() + + val_a = direct.setdefault(e.a, e) + val_b = control.setdefault(e.a, e) + assert_eq() + self.assert_(val_a is val_b) + + val_a = direct.setdefault(e.a, e) + val_b = control.setdefault(e.a, e) + assert_eq() + self.assert_(val_a is val_b) + + if hasattr(direct, 'update'): + e = creator() + d = dict([(ee.a, ee) for ee in [e, creator(), creator()]]) + addall(e, creator()) + + direct.update(d) + control.update(d) + assert_eq() + + if sys.version_info >= (2, 4): + kw = dict([(ee.a, ee) for ee in [e, creator()]]) + direct.update(**kw) + control.update(**kw) + assert_eq() + + def _test_dict_bulk(self, typecallable, creator=None): + if creator is None: + creator = self.dictable_entity + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + direct = obj.attr + + e1 = creator() + collections.collection_adapter(direct).append_with_event(e1) + + like_me = typecallable() + e2 = creator() + like_me.set(e2) + + self.assert_(obj.attr is direct) + obj.attr = like_me + self.assert_(obj.attr is not direct) + self.assert_(obj.attr is not like_me) + self.assert_(set(collections.collection_adapter(obj.attr)) == set([e2])) + self.assert_(e1 in canary.removed) + self.assert_(e2 in canary.added) + + + # key validity on bulk assignment is a basic feature of MappedCollection + # but is not present in basic, @converter-less dict collections. 
+ e3 = creator() + if isinstance(obj.attr, collections.MappedCollection): + real_dict = dict(badkey=e3) + try: + obj.attr = real_dict + self.assert_(False) + except TypeError: + pass + self.assert_(obj.attr is not real_dict) + self.assert_('badkey' not in obj.attr) + eq_(set(collections.collection_adapter(obj.attr)), + set([e2])) + self.assert_(e3 not in canary.added) + else: + real_dict = dict(keyignored1=e3) + obj.attr = real_dict + self.assert_(obj.attr is not real_dict) + self.assert_('keyignored1' not in obj.attr) + eq_(set(collections.collection_adapter(obj.attr)), + set([e3])) + self.assert_(e2 in canary.removed) + self.assert_(e3 in canary.added) + + obj.attr = typecallable() + eq_(list(collections.collection_adapter(obj.attr)), []) + + e4 = creator() + try: + obj.attr = [e4] + self.assert_(False) + except TypeError: + self.assert_(e4 not in canary.data) + + def test_dict(self): + try: + self._test_adapter(dict, self.dictable_entity, + to_set=lambda c: set(c.values())) + self.assert_(False) + except sa_exc.ArgumentError, e: + self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class') + + try: + self._test_dict(dict) + self.assert_(False) + except sa_exc.ArgumentError, e: + self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class') + + def test_dict_subclass(self): + class MyDict(dict): + @collection.appender + @collection.internally_instrumented + def set(self, item, _sa_initiator=None): + self.__setitem__(item.a, item, _sa_initiator=_sa_initiator) + @collection.remover + @collection.internally_instrumented + def _remove(self, item, _sa_initiator=None): + self.__delitem__(item.a, _sa_initiator=_sa_initiator) + + self._test_adapter(MyDict, self.dictable_entity, + to_set=lambda c: set(c.values())) + self._test_dict(MyDict) + self._test_dict_bulk(MyDict) + self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict)) + + def test_dict_subclass2(self): + class 
MyEasyDict(collections.MappedCollection): + def __init__(self): + super(MyEasyDict, self).__init__(lambda e: e.a) + + self._test_adapter(MyEasyDict, self.dictable_entity, + to_set=lambda c: set(c.values())) + self._test_dict(MyEasyDict) + self._test_dict_bulk(MyEasyDict) + self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict)) + + def test_dict_subclass3(self): + class MyOrdered(util.OrderedDict, collections.MappedCollection): + def __init__(self): + collections.MappedCollection.__init__(self, lambda e: e.a) + util.OrderedDict.__init__(self) + + self._test_adapter(MyOrdered, self.dictable_entity, + to_set=lambda c: set(c.values())) + self._test_dict(MyOrdered) + self._test_dict_bulk(MyOrdered) + self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered)) + + def test_dict_duck(self): + class DictLike(object): + def __init__(self): + self.data = dict() + + @collection.appender + @collection.replaces(1) + def set(self, item): + current = self.data.get(item.a, None) + self.data[item.a] = item + return current + @collection.remover + def _remove(self, item): + del self.data[item.a] + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def values(self): + return self.data.values() + def __contains__(self, key): + return key in self.data + @collection.iterator + def itervalues(self): + return self.data.itervalues() + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'DictLike(%s)' % repr(self.data) + + self._test_adapter(DictLike, self.dictable_entity, + to_set=lambda c: set(c.itervalues())) + self._test_dict(DictLike) + self._test_dict_bulk(DictLike) + self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike)) + + def test_dict_emulates(self): + class DictIsh(object): + __emulates__ = dict + def __init__(self): + self.data = dict() + + @collection.appender + 
@collection.replaces(1) + def set(self, item): + current = self.data.get(item.a, None) + self.data[item.a] = item + return current + @collection.remover + def _remove(self, item): + del self.data[item.a] + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def values(self): + return self.data.values() + def __contains__(self, key): + return key in self.data + @collection.iterator + def itervalues(self): + return self.data.itervalues() + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'DictIsh(%s)' % repr(self.data) + + self._test_adapter(DictIsh, self.dictable_entity, + to_set=lambda c: set(c.itervalues())) + self._test_dict(DictIsh) + self._test_dict_bulk(DictIsh) + self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh)) + + def _test_object(self, typecallable, creator=None): + if creator is None: + creator = self.entity_maker + + class Foo(object): + pass + + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=typecallable, useobject=True) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = set() + + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(direct == control) + + # There is no API for object collections. We'll make one up + # for the purposes of the test. 
+ e = creator() + direct.push(e) + control.add(e) + assert_eq() + + direct.zark(e) + control.remove(e) + assert_eq() + + e = creator() + direct.maybe_zark(e) + control.discard(e) + assert_eq() + + e = creator() + direct.push(e) + control.add(e) + assert_eq() + + e = creator() + direct.maybe_zark(e) + control.discard(e) + assert_eq() + + def test_object_duck(self): + class MyCollection(object): + def __init__(self): + self.data = set() + @collection.appender + def push(self, item): + self.data.add(item) + @collection.remover + def zark(self, item): + self.data.remove(item) + @collection.removes_return() + def maybe_zark(self, item): + if item in self.data: + self.data.remove(item) + return item + @collection.iterator + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + + self._test_adapter(MyCollection) + self._test_object(MyCollection) + self.assert_(getattr(MyCollection, '_sa_instrumented') == + id(MyCollection)) + + def test_object_emulates(self): + class MyCollection2(object): + __emulates__ = None + def __init__(self): + self.data = set() + # looks like a list + def append(self, item): + assert False + @collection.appender + def push(self, item): + self.data.add(item) + @collection.remover + def zark(self, item): + self.data.remove(item) + @collection.removes_return() + def maybe_zark(self, item): + if item in self.data: + self.data.remove(item) + return item + @collection.iterator + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + + self._test_adapter(MyCollection2) + self._test_object(MyCollection2) + self.assert_(getattr(MyCollection2, '_sa_instrumented') == + id(MyCollection2)) + + def test_recipes(self): + class Custom(object): + def __init__(self): + self.data = [] + @collection.appender + @collection.adds('entity') + def put(self, entity): + self.data.append(entity) + + @collection.remover + 
@collection.removes(1) + def remove(self, entity): + self.data.remove(entity) + + @collection.adds(1) + def push(self, *args): + self.data.append(args[0]) + + @collection.removes('entity') + def yank(self, entity, arg): + self.data.remove(entity) + + @collection.replaces(2) + def replace(self, arg, entity, **kw): + self.data.insert(0, entity) + return self.data.pop() + + @collection.removes_return() + def pop(self, key): + return self.data.pop() + + @collection.iterator + def __iter__(self): + return iter(self.data) + + class Foo(object): + pass + canary = Canary() + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, + typecallable=Custom, useobject=True) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = list() + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(list(direct) == control) + creator = self.entity_maker + + e1 = creator() + direct.put(e1) + control.append(e1) + assert_eq() + + e2 = creator() + direct.put(entity=e2) + control.append(e2) + assert_eq() + + direct.remove(e2) + control.remove(e2) + assert_eq() + + direct.remove(entity=e1) + control.remove(e1) + assert_eq() + + e3 = creator() + direct.push(e3) + control.append(e3) + assert_eq() + + direct.yank(e3, 'blah') + control.remove(e3) + assert_eq() + + e4, e5, e6, e7 = creator(), creator(), creator(), creator() + direct.put(e4) + direct.put(e5) + control.append(e4) + control.append(e5) + + dr1 = direct.replace('foo', e6, bar='baz') + control.insert(0, e6) + cr1 = control.pop() + assert_eq() + self.assert_(dr1 is cr1) + + dr2 = direct.replace(arg=1, entity=e7) + control.insert(0, e7) + cr2 = control.pop() + assert_eq() + self.assert_(dr2 is cr2) + + dr3 = direct.pop('blah') + cr3 = control.pop() + assert_eq() + self.assert_(dr3 is cr3) + + def test_lifecycle(self): + class Foo(object): + pass + + canary = Canary() + creator = 
self.entity_maker + attributes.register_class(Foo) + attributes.register_attribute(Foo, 'attr', uselist=True, extension=canary, useobject=True) + + obj = Foo() + col1 = obj.attr + + e1 = creator() + obj.attr.append(e1) + + e2 = creator() + bulk1 = [e2] + # empty & sever col1 from obj + obj.attr = bulk1 + self.assert_(len(col1) == 0) + self.assert_(len(canary.data) == 1) + self.assert_(obj.attr is not col1) + self.assert_(obj.attr is not bulk1) + self.assert_(obj.attr == bulk1) + + e3 = creator() + col1.append(e3) + self.assert_(e3 not in canary.data) + self.assert_(collections.collection_adapter(col1) is None) + + obj.attr[0] = e3 + self.assert_(e3 in canary.data) + +class DictHelpersTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('parents', metadata, + Column('id', Integer, primary_key=True), + Column('label', String(128))) + Table('children', metadata, + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('parents.id'), + nullable=False), + Column('a', String(128)), + Column('b', String(128)), + Column('c', String(128))) + + @classmethod + def setup_classes(cls): + class Parent(_base.BasicEntity): + def __init__(self, label=None): + self.label = label + + class Child(_base.BasicEntity): + def __init__(self, a=None, b=None, c=None): + self.a = a + self.b = b + self.c = c + + @testing.resolve_artifact_names + def _test_scalar_mapped(self, collection_class): + mapper(Child, children) + mapper(Parent, parents, properties={ + 'children': relation(Child, collection_class=collection_class, + cascade="all, delete-orphan")}) + + p = Parent() + p.children['foo'] = Child('foo', 'value') + p.children['bar'] = Child('bar', 'value') + session = create_session() + session.add(p) + session.flush() + pid = p.id + session.expunge_all() + + p = session.query(Parent).get(pid) + + + eq_(set(p.children.keys()), set(['foo', 'bar'])) + cid = p.children['foo'].id + + 
collections.collection_adapter(p.children).append_with_event( + Child('foo', 'newvalue')) + + session.flush() + session.expunge_all() + + p = session.query(Parent).get(pid) + + self.assert_(set(p.children.keys()) == set(['foo', 'bar'])) + self.assert_(p.children['foo'].id != cid) + + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) + session.flush() + session.expunge_all() + + p = session.query(Parent).get(pid) + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) + + collections.collection_adapter(p.children).remove_with_event( + p.children['foo']) + + self.assert_(len(list(collections.collection_adapter(p.children))) == 1) + session.flush() + session.expunge_all() + + p = session.query(Parent).get(pid) + self.assert_(len(list(collections.collection_adapter(p.children))) == 1) + + del p.children['bar'] + self.assert_(len(list(collections.collection_adapter(p.children))) == 0) + session.flush() + session.expunge_all() + + p = session.query(Parent).get(pid) + self.assert_(len(list(collections.collection_adapter(p.children))) == 0) + + + @testing.resolve_artifact_names + def _test_composite_mapped(self, collection_class): + mapper(Child, children) + mapper(Parent, parents, properties={ + 'children': relation(Child, collection_class=collection_class, + cascade="all, delete-orphan") + }) + + p = Parent() + p.children[('foo', '1')] = Child('foo', '1', 'value 1') + p.children[('foo', '2')] = Child('foo', '2', 'value 2') + + session = create_session() + session.add(p) + session.flush() + pid = p.id + session.expunge_all() + + p = session.query(Parent).get(pid) + + self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) + cid = p.children[('foo', '1')].id + + collections.collection_adapter(p.children).append_with_event( + Child('foo', '1', 'newvalue')) + + session.flush() + session.expunge_all() + + p = session.query(Parent).get(pid) + + self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) 
+ self.assert_(p.children[('foo', '1')].id != cid) + + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) + + def test_mapped_collection(self): + collection_class = collections.mapped_collection(lambda c: c.a) + self._test_scalar_mapped(collection_class) + + def test_mapped_collection2(self): + collection_class = collections.mapped_collection(lambda c: (c.a, c.b)) + self._test_composite_mapped(collection_class) + + def test_attr_mapped_collection(self): + collection_class = collections.attribute_mapped_collection('a') + self._test_scalar_mapped(collection_class) + + def test_declarative_column_mapped(self): + """test that uncompiled attribute usage works with column_mapped_collection""" + + from sqlalchemy.ext.declarative import declarative_base + + BaseObject = declarative_base() + + class Foo(BaseObject): + __tablename__ = "foo" + id = Column(Integer(), primary_key=True) + bar_id = Column(Integer, ForeignKey('bar.id')) + + class Bar(BaseObject): + __tablename__ = "bar" + id = Column(Integer(), primary_key=True) + foos = relation(Foo, collection_class=collections.column_mapped_collection(Foo.id)) + foos2 = relation(Foo, collection_class=collections.column_mapped_collection((Foo.id, Foo.bar_id))) + + eq_(Bar.foos.property.collection_class().keyfunc(Foo(id=3)), 3) + eq_(Bar.foos2.property.collection_class().keyfunc(Foo(id=3, bar_id=12)), (3, 12)) + + @testing.resolve_artifact_names + def test_column_mapped_collection(self): + collection_class = collections.column_mapped_collection( + children.c.a) + self._test_scalar_mapped(collection_class) + + @testing.resolve_artifact_names + def test_column_mapped_collection2(self): + collection_class = collections.column_mapped_collection( + (children.c.a, children.c.b)) + self._test_composite_mapped(collection_class) + + def test_mixin(self): + class Ordered(util.OrderedDict, collections.MappedCollection): + def __init__(self): + collections.MappedCollection.__init__(self, lambda v: v.a) + 
util.OrderedDict.__init__(self) + collection_class = Ordered + self._test_scalar_mapped(collection_class) + + def test_mixin2(self): + class Ordered2(util.OrderedDict, collections.MappedCollection): + def __init__(self, keyfunc): + collections.MappedCollection.__init__(self, keyfunc) + util.OrderedDict.__init__(self) + collection_class = lambda: Ordered2(lambda v: (v.a, v.b)) + self._test_composite_mapped(collection_class) + +# TODO: are these tests redundant vs. the above tests ? +# remove if so +class CustomCollectionsTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('sometable', metadata, + Column('col1',Integer, primary_key=True), + Column('data', String(30))) + Table('someothertable', metadata, + Column('col1', Integer, primary_key=True), + Column('scol1', Integer, + ForeignKey('sometable.col1')), + Column('data', String(20))) + + @testing.resolve_artifact_names + def test_basic(self): + class MyList(list): + pass + class Foo(object): + pass + class Bar(object): + pass + + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=MyList) + }) + mapper(Bar, someothertable) + f = Foo() + assert isinstance(f.bars, MyList) + + @testing.resolve_artifact_names + def test_lazyload(self): + """test that a 'set' can be used as a collection and can lazyload.""" + class Foo(object): + pass + class Bar(object): + pass + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=set) + }) + mapper(Bar, someothertable) + f = Foo() + f.bars.add(Bar()) + f.bars.add(Bar()) + sess = create_session() + sess.add(f) + sess.flush() + sess.expunge_all() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + f.bars.clear() + + @testing.resolve_artifact_names + def test_dict(self): + """test that a 'dict' can be used as a collection and can lazyload.""" + + class Foo(object): + pass + class Bar(object): + pass + class AppenderDict(dict): + @collection.appender + def set(self, item): + self[id(item)] = 
item + @collection.remover + def remove(self, item): + if id(item) in self: + del self[id(item)] + + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=AppenderDict) + }) + mapper(Bar, someothertable) + f = Foo() + f.bars.set(Bar()) + f.bars.set(Bar()) + sess = create_session() + sess.add(f) + sess.flush() + sess.expunge_all() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + f.bars.clear() + + @testing.resolve_artifact_names + def test_dict_wrapper(self): + """test that the supplied 'dict' wrapper can be used as a collection and can lazyload.""" + + class Foo(object): + pass + class Bar(object): + def __init__(self, data): self.data = data + + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, + collection_class=collections.column_mapped_collection( + someothertable.c.data)) + }) + mapper(Bar, someothertable) + + f = Foo() + col = collections.collection_adapter(f.bars) + col.append_with_event(Bar('a')) + col.append_with_event(Bar('b')) + sess = create_session() + sess.add(f) + sess.flush() + sess.expunge_all() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + + existing = set([id(b) for b in f.bars.values()]) + + col = collections.collection_adapter(f.bars) + col.append_with_event(Bar('b')) + f.bars['a'] = Bar('a') + sess.flush() + sess.expunge_all() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + + replaced = set([id(b) for b in f.bars.values()]) + self.assert_(existing != replaced) + + @testing.resolve_artifact_names + def test_list(self): + class Parent(object): + pass + class Child(object): + pass + + mapper(Parent, sometable, properties={ + 'children':relation(Child, collection_class=list) + }) + mapper(Child, someothertable) + + control = list() + p = Parent() + + o = Child() + control.append(o) + p.children.append(o) + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control.extend(o) + p.children.extend(o) 
+ assert control == p.children + assert control == list(p.children) + + assert control[0] == p.children[0] + assert control[-1] == p.children[-1] + assert control[1:3] == p.children[1:3] + + del control[1] + del p.children[1] + assert control == p.children + assert control == list(p.children) + + o = [Child()] + control[1:3] = o + + p.children[1:3] = o + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control[1:3] = o + p.children[1:3] = o + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control[-1:-2] = o + p.children[-1:-2] = o + assert control == p.children + assert control == list(p.children) + + o = [Child(), Child(), Child(), Child()] + control[4:] = o + p.children[4:] = o + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(0, o) + p.children.insert(0, o) + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(3, o) + p.children.insert(3, o) + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(999, o) + p.children.insert(999, o) + assert control == p.children + assert control == list(p.children) + + del control[0:1] + del p.children[0:1] + assert control == p.children + assert control == list(p.children) + + del control[1:1] + del p.children[1:1] + assert control == p.children + assert control == list(p.children) + + del control[1:3] + del p.children[1:3] + assert control == p.children + assert control == list(p.children) + + del control[7:] + del p.children[7:] + assert control == p.children + assert control == list(p.children) + + assert control.pop() == p.children.pop() + assert control == p.children + assert control == list(p.children) + + assert control.pop(0) == p.children.pop(0) + assert control == p.children + assert control == list(p.children) + + assert control.pop(2) == 
p.children.pop(2) + assert control == p.children + assert control == list(p.children) + + o = Child() + control.insert(2, o) + p.children.insert(2, o) + assert control == p.children + assert control == list(p.children) + + control.remove(o) + p.children.remove(o) + assert control == p.children + assert control == list(p.children) + + @testing.resolve_artifact_names + def test_custom(self): + class Parent(object): + pass + class Child(object): + pass + + class MyCollection(object): + def __init__(self): + self.data = [] + @collection.appender + def append(self, value): + self.data.append(value) + @collection.remover + def remove(self, value): + self.data.remove(value) + @collection.iterator + def __iter__(self): + return iter(self.data) + + mapper(Parent, sometable, properties={ + 'children':relation(Child, collection_class=MyCollection) + }) + mapper(Child, someothertable) + + control = list() + p1 = Parent() + + o = Child() + control.append(o) + p1.children.append(o) + assert control == list(p1.children) + + o = Child() + control.append(o) + p1.children.append(o) + assert control == list(p1.children) + + o = Child() + control.append(o) + p1.children.append(o) + assert control == list(p1.children) + + sess = create_session() + sess.add(p1) + sess.flush() + sess.expunge_all() + + p2 = sess.query(Parent).get(p1.col1) + o = list(p2.children) + assert len(o) == 3 + + +class InstrumentationTest(_base.ORMTest): + def test_uncooperative_descriptor_in_sweep(self): + class DoNotTouch(object): + def __get__(self, obj, owner): + raise AttributeError + + class Touchy(list): + no_touch = DoNotTouch() + + assert 'no_touch' in Touchy.__dict__ + assert not hasattr(Touchy, 'no_touch') + assert 'no_touch' in dir(Touchy) + + instrumented = collections._instrument_class(Touchy) + assert True + diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py new file mode 100644 index 000000000..7a5b63615 --- /dev/null +++ b/test/orm/test_compile.py @@ -0,0 +1,183 @@ +from sqlalchemy 
from sqlalchemy import *
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
from sqlalchemy.test import *
from test.orm import _base


class CompileTest(_base.ORMTest):
    """test various mapper compilation scenarios"""

    def teardown(self):
        # each test builds mappers against throwaway metadata; clear the
        # registry so tests remain independent of each other
        clear_mappers()

    def testone(self):
        """Compilation of a polymorphic 'surrogate' mapper succeeds when it
        occurs after MapperProperty setup on the primary mapper."""
        metadata = MetaData(testing.db)

        order = Table('orders', metadata,
            Column('id', Integer, primary_key=True),
            Column('employee_id', Integer, ForeignKey('employees.id'), nullable=False),
            Column('type', Unicode(16)))

        employee = Table('employees', metadata,
            Column('id', Integer, primary_key=True),
            Column('name', Unicode(16), unique=True, nullable=False))

        product = Table('products', metadata,
            Column('id', Integer, primary_key=True),
        )

        orderproduct = Table('orderproducts', metadata,
            Column('id', Integer, primary_key=True),
            Column('order_id', Integer, ForeignKey("orders.id"), nullable=False),
            Column('product_id', Integer, ForeignKey("products.id"), nullable=False),
        )

        class Order(object):
            pass

        class Employee(object):
            pass

        class Product(object):
            pass

        class OrderProduct(object):
            pass

        order_join = order.select().alias('pjoin')

        mapper(Order, order,
            with_polymorphic=('*', order_join),
            polymorphic_on=order_join.c.type,
            polymorphic_identity='order',
            properties={
                'orderproducts': relation(OrderProduct, lazy=True, backref='order')}
            )

        mapper(Product, product,
            properties={
                'orderproducts': relation(OrderProduct, lazy=True, backref='product')}
            )

        mapper(Employee, employee,
            properties={
                'orders': relation(Order, lazy=True, backref='employee')})

        mapper(OrderProduct, orderproduct)

        # this requires that the compilation of order_mapper's "surrogate
        # mapper" occur after the initial setup of MapperProperty objects on
        # the mapper.
        class_mapper(Product).compile()

    def testtwo(self):
        """test that conflicting backrefs raises an exception"""
        metadata = MetaData(testing.db)

        order = Table('orders', metadata,
            Column('id', Integer, primary_key=True),
            Column('type', Unicode(16)))

        product = Table('products', metadata,
            Column('id', Integer, primary_key=True),
        )

        orderproduct = Table('orderproducts', metadata,
            Column('id', Integer, primary_key=True),
            Column('order_id', Integer, ForeignKey("orders.id"), nullable=False),
            Column('product_id', Integer, ForeignKey("products.id"), nullable=False),
        )

        class Order(object):
            pass

        class Product(object):
            pass

        class OrderProduct(object):
            pass

        order_join = order.select().alias('pjoin')

        # both relations deliberately claim the backref name 'product'
        mapper(Order, order,
            with_polymorphic=('*', order_join),
            polymorphic_on=order_join.c.type,
            polymorphic_identity='order',
            properties={
                'orderproducts': relation(OrderProduct, lazy=True, backref='product')}
            )

        mapper(Product, product,
            properties={
                'orderproducts': relation(OrderProduct, lazy=True, backref='product')}
            )

        mapper(OrderProduct, orderproduct)

        try:
            class_mapper(Product).compile()
            assert False
        except sa_exc.ArgumentError as e:
            # BUGFIX: was `str(e).index(...) > -1`; str.index raises
            # ValueError on a miss rather than returning -1, so the
            # comparison could never actually fail.  Also modernized the
            # Python-2-only `except X, e` syntax.
            assert "Error creating backref" in str(e)

    def testthree(self):
        metadata = MetaData(testing.db)
        node_table = Table("node", metadata,
            Column('node_id', Integer, primary_key=True),
            Column('name_index', Integer, nullable=True),
            )
        node_name_table = Table("node_name", metadata,
            Column('node_name_id', Integer, primary_key=True),
            Column('node_id', Integer, ForeignKey('node.node_id')),
            Column('host_id', Integer, ForeignKey('host.host_id')),
            Column('name', String(64), nullable=False),
            )
        host_table = Table("host", metadata,
            Column('host_id', Integer, primary_key=True),
            Column('hostname', String(64), nullable=False,
                   unique=True),
            )
        metadata.create_all()
        try:
            # BUGFIX: was `node_index=5`, but the column defined above is
            # named `name_index`
            node_table.insert().execute(node_id=1, name_index=5)

            class Node(object):
                pass

            class NodeName(object):
                pass

            class Host(object):
                pass

            mapper(Node, node_table)
            mapper(Host, host_table)
            mapper(NodeName, node_name_table,
                properties={
                    'node': relation(Node, backref=backref('names')),
                    'host': relation(Host),
                })
            sess = create_session()
            assert sess.query(Node).get(1).names == []
        finally:
            metadata.drop_all()

    def testfour(self):
        meta = MetaData()

        a = Table('a', meta, Column('id', Integer, primary_key=True))
        b = Table('b', meta, Column('id', Integer, primary_key=True),
                  Column('a_id', Integer, ForeignKey('a.id')))

        class A(object):
            pass

        class B(object):
            pass

        # conflicting backrefs: A.b/backref 'a' collides with A created
        # explicitly below, and vice versa
        mapper(A, a, properties={
            'b': relation(B, backref='a')
        })
        mapper(B, b, properties={
            'a': relation(A, backref='b')
        })

        try:
            compile_mappers()
            assert False
        except sa_exc.ArgumentError as e:
            # see testtwo: substring membership instead of .index() > -1
            assert "Error creating backref" in str(e)
+ +""" +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf +from test.orm import _base + + +class SelfReferentialTest(_base.MappedTest): + """A self-referential mapper with an additional list of child objects.""" + + @classmethod + def define_tables(cls, metadata): + Table('t1', metadata, + Column('c1', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('parent_c1', Integer, ForeignKey('t1.c1')), + Column('data', String(20))) + Table('t2', metadata, + Column('c1', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('c1id', Integer, ForeignKey('t1.c1')), + Column('data', String(20))) + + @classmethod + def setup_classes(cls): + class C1(_base.BasicEntity): + def __init__(self, data=None): + self.data = data + + class C2(_base.BasicEntity): + def __init__(self, data=None): + self.data = data + + @testing.resolve_artifact_names + def testsingle(self): + mapper(C1, t1, properties = { + 'c1s':relation(C1, cascade="all"), + 'parent':relation(C1, + primaryjoin=t1.c.parent_c1 == t1.c.c1, + remote_side=t1.c.c1, + lazy=True, + uselist=False)}) + a = C1('head c1') + a.c1s.append(C1('another c1')) + + sess = create_session( ) + sess.add(a) + sess.flush() + sess.delete(a) + sess.flush() + + @testing.resolve_artifact_names + def testmanytooneonly(self): + """ + + test that the circular dependency sort can assemble a many-to-one + dependency processor when only the object on the "many" side is + actually in the list of modified objects. this requires that the + circular sort add the other side of the relation into the + UOWTransaction so that the dependency operation can be tacked onto it. 
+ + This also affects inheritance relationships since they rely upon + circular sort as well. + + """ + mapper(C1, t1, properties={ + 'parent':relation(C1, + primaryjoin=t1.c.parent_c1 == t1.c.c1, + remote_side=t1.c.c1)}) + + c1 = C1() + + sess = create_session() + sess.add(c1) + sess.flush() + sess.expunge_all() + c1 = sess.query(C1).get(c1.c1) + c2 = C1() + c2.parent = c1 + sess.add(c2) + sess.flush() + assert c2.parent_c1==c1.c1 + + @testing.resolve_artifact_names + def testcycle(self): + mapper(C1, t1, properties = { + 'c1s' : relation(C1, cascade="all"), + 'c2s' : relation(mapper(C2, t2), cascade="all, delete-orphan")}) + + a = C1('head c1') + a.c1s.append(C1('child1')) + a.c1s.append(C1('child2')) + a.c1s[0].c1s.append(C1('subchild1')) + a.c1s[0].c1s.append(C1('subchild2')) + a.c1s[1].c2s.append(C2('child2 data1')) + a.c1s[1].c2s.append(C2('child2 data2')) + sess = create_session( ) + sess.add(a) + sess.flush() + + sess.delete(a) + sess.flush() + + @testing.resolve_artifact_names + def test_setnull_ondelete(self): + mapper(C1, t1, properties={ + 'children':relation(C1) + }) + + sess = create_session() + c1 = C1() + c2 = C1() + c1.children.append(c2) + sess.add(c1) + sess.flush() + assert c2.parent_c1 == c1.c1 + + sess.delete(c1) + sess.flush() + assert c2.parent_c1 is None + + sess.expire_all() + assert c2.parent_c1 is None + +class SelfReferentialNoPKTest(_base.MappedTest): + """A self-referential relationship that joins on a column other than the primary key column""" + + @classmethod + def define_tables(cls, metadata): + Table('item', metadata, + Column('id', Integer, primary_key=True), + Column('uuid', String(32), unique=True, nullable=False), + Column('parent_uuid', String(32), ForeignKey('item.uuid'), + nullable=True)) + + @classmethod + def setup_classes(cls): + class TT(_base.BasicEntity): + def __init__(self): + self.uuid = hex(id(self)) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(TT, item, properties={ + 
'children': relation( + TT, + remote_side=[item.c.parent_uuid], + backref=backref('parent', remote_side=[item.c.uuid]))}) + + @testing.resolve_artifact_names + def testbasic(self): + t1 = TT() + t1.children.append(TT()) + t1.children.append(TT()) + + s = create_session() + s.add(t1) + s.flush() + s.expunge_all() + t = s.query(TT).filter_by(id=t1.id).one() + eq_(t.children[0].parent_uuid, t1.uuid) + + @testing.resolve_artifact_names + def testlazyclause(self): + s = create_session() + t1 = TT() + t2 = TT() + t1.children.append(t2) + s.add(t1) + s.flush() + s.expunge_all() + + t = s.query(TT).filter_by(id=t2.id).one() + eq_(t.uuid, t2.uuid) + eq_(t.parent.uuid, t1.uuid) + + +class InheritTestOne(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table("parent", metadata, + Column("id", Integer, primary_key=True), + Column("parent_data", String(50)), + Column("type", String(10))) + + Table("child1", metadata, + Column("id", Integer, ForeignKey("parent.id"), + primary_key=True), + Column("child1_data", String(50))) + + Table("child2", metadata, + Column("id", Integer, ForeignKey("parent.id"), + primary_key=True), + Column("child1_id", Integer, ForeignKey("child1.id"), + nullable=False), + Column("child2_data", String(50))) + + @classmethod + def setup_classes(cls): + class Parent(_base.BasicEntity): + pass + + class Child1(Parent): + pass + + class Child2(Parent): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Parent, parent) + mapper(Child1, child1, inherits=Parent) + mapper(Child2, child2, inherits=Parent, properties=dict( + child1=relation(Child1, + primaryjoin=child2.c.child1_id == child1.c.id))) + + @testing.resolve_artifact_names + def testmanytooneonly(self): + """test similar to SelfReferentialTest.testmanytooneonly""" + + session = create_session() + + c1 = Child1() + c1.child1_data = "qwerty" + session.add(c1) + session.flush() + session.expunge_all() + + c1 = 
session.query(Child1).filter_by(child1_data="qwerty").one() + c2 = Child2() + c2.child1 = c1 + c2.child2_data = "asdfgh" + session.add(c2) + + # the flush will fail if the UOW does not set up a many-to-one DP + # attached to a task corresponding to c1, since "child1_id" is not + # nullable + session.flush() + + +class InheritTestTwo(_base.MappedTest): + """ + + The fix in BiDirectionalManyToOneTest raised this issue, regarding the + 'circular sort' containing UOWTasks that were still polymorphic, which + could create duplicate entries in the final sort + + """ + + @classmethod + def define_tables(cls, metadata): + Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('cid', Integer, ForeignKey('c.id'))) + + Table('b', metadata, + Column('id', Integer, ForeignKey("a.id"), primary_key=True), + Column('data', String(30))) + + Table('c', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('aid', Integer, + ForeignKey('a.id', use_alter=True, name="foo"))) + + @classmethod + def setup_classes(cls): + class A(_base.BasicEntity): + pass + + class B(A): + pass + + class C(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_flush(self): + mapper(A, a, properties={ + 'cs':relation(C, primaryjoin=a.c.cid==c.c.id)}) + + mapper(B, b, inherits=A, inherit_condition=b.c.id == a.c.id) + + mapper(C, c, properties={ + 'arel':relation(A, primaryjoin=a.c.id == c.c.aid)}) + + sess = create_session() + bobj = B() + sess.add(bobj) + cobj = C() + sess.add(cobj) + sess.flush() + + +class BiDirectionalManyToOneTest(_base.MappedTest): + run_define_tables = 'each' + + @classmethod + def define_tables(cls, metadata): + Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('t2id', Integer, ForeignKey('t2.id'))) + Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('t1id', Integer, + 
ForeignKey('t1.id', use_alter=True, name="foo_fk"))) + Table('t3', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), + Column('t2id', Integer, ForeignKey('t2.id'), nullable=False)) + + @classmethod + def setup_classes(cls): + class T1(_base.BasicEntity): + pass + class T2(_base.BasicEntity): + pass + class T3(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(T1, t1, properties={ + 't2':relation(T2, primaryjoin=t1.c.t2id == t2.c.id)}) + mapper(T2, t2, properties={ + 't1':relation(T1, primaryjoin=t2.c.t1id == t1.c.id)}) + mapper(T3, t3, properties={ + 't1':relation(T1), + 't2':relation(T2)}) + + @testing.resolve_artifact_names + def test_reflush(self): + o1 = T1() + o1.t2 = T2() + sess = create_session() + sess.add(o1) + sess.flush() + + # the bug here is that the dependency sort comes up with T1/T2 in a + # cycle, but there are no T1/T2 objects to be saved. therefore no + # "cyclical subtree" gets generated, and one or the other of T1/T2 + # gets lost, and processors on T3 dont fire off. the test will then + # fail because the FK's on T3 are not nullable. + o3 = T3() + o3.t1 = o1 + o3.t2 = o1.t2 + sess.add(o3) + sess.flush() + + + @testing.resolve_artifact_names + def test_reflush_2(self): + """A variant on test_reflush()""" + o1 = T1() + o1.t2 = T2() + sess = create_session() + sess.add(o1) + sess.flush() + + # in this case, T1, T2, and T3 tasks will all be in the cyclical + # tree normally. the dependency processors for T3 are part of the + # 'extradeps' collection so they all get assembled into the tree + # as well. 
class BiDirectionalOneToManyTest(_base.MappedTest):
    """tests two mappers with a one-to-many relation to each other."""

    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        Table('t1', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('c2', Integer, ForeignKey('t2.c1')))

        Table('t2', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('c2', Integer,
                     ForeignKey('t1.c1', use_alter=True, name='t1c1_fk')))

    @classmethod
    def setup_classes(cls):
        class C1(_base.BasicEntity):
            pass

        class C2(_base.BasicEntity):
            pass

    @testing.resolve_artifact_names
    def testcycle(self):
        mapper(C2, t2, properties={
            'c1s': relation(C1,
                            primaryjoin=t2.c.c1 == t1.c.c2,
                            uselist=True)})
        mapper(C1, t1, properties={
            'c2s': relation(C2,
                            primaryjoin=t1.c.c1 == t2.c.c2,
                            uselist=True)})

        a = C1()
        b = C2()
        c = C1()
        d = C2()
        e = C2()
        f = C2()
        a.c2s.append(b)
        d.c1s.append(c)
        b.c1s.append(c)
        sess = create_session()
        sess.add_all((a, b, c, d, e, f))
        sess.flush()


class BiDirectionalOneToManyTest2(_base.MappedTest):
    """Two mappers with a one-to-many relation to each other, with a second one-to-many on one of the mappers"""

    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        # BUGFIX: test_needs_autoincrement is a Column-level flag (it
        # drives sequence generation on backends such as Firebird/Oracle);
        # it was previously passed to Table(), where the test schema
        # wrapper silently discards it.  It now sits on each primary key
        # column, matching BiDirectionalOneToManyTest above.
        Table('t1', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('c2', Integer, ForeignKey('t2.c1')))

        Table('t2', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('c2', Integer,
                     ForeignKey('t1.c1', use_alter=True, name='t1c1_fq')))

        Table('t1_data', metadata,
              Column('c1', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('t1id', Integer, ForeignKey('t1.c1')),
              Column('data', String(20)))

    @classmethod
    def setup_classes(cls):
        class C1(_base.BasicEntity):
            pass

        class C2(_base.BasicEntity):
            pass

        class C1Data(_base.BasicEntity):
            pass

    @classmethod
    @testing.resolve_artifact_names
    def setup_mappers(cls):
        mapper(C2, t2, properties={
            'c1s': relation(C1,
                            primaryjoin=t2.c.c1 == t1.c.c2,
                            uselist=True)})
        mapper(C1, t1, properties={
            'c2s': relation(C2,
                            primaryjoin=t1.c.c1 == t2.c.c2,
                            uselist=True),
            'data': relation(mapper(C1Data, t1_data))})

    @testing.resolve_artifact_names
    def testcycle(self):
        a = C1()
        b = C2()
        c = C1()
        d = C2()
        e = C2()
        f = C2()
        a.c2s.append(b)
        d.c1s.append(c)
        b.c1s.append(c)
        a.data.append(C1Data(data='c1data1'))
        a.data.append(C1Data(data='c1data2'))
        c.data.append(C1Data(data='c1data3'))
        sess = create_session()
        sess.add_all((a, b, c, d, e, f))
        sess.flush()

        sess.delete(d)
        sess.delete(c)
        sess.flush()
+ + """ + run_define_tables = 'each' + + @classmethod + def define_tables(cls, metadata): + Table('ball', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('person_id', Integer, + ForeignKey('person.id', use_alter=True, name='fk_person_id')), + Column('data', String(30))) + + Table('person', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('favorite_ball_id', Integer, ForeignKey('ball.id')), + Column('data', String(30))) + + @classmethod + def setup_classes(cls): + class Person(_base.BasicEntity): + pass + + class Ball(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def testcycle(self): + """ + This test has a peculiar aspect in that it doesnt create as many + dependent relationships as the other tests, and revealed a small + glitch in the circular dependency sorting. + + """ + mapper(Ball, ball) + mapper(Person, person, properties=dict( + balls=relation(Ball, + primaryjoin=ball.c.person_id == person.c.id, + remote_side=ball.c.person_id), + favorite=relation(Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + remote_side=ball.c.id))) + + b = Ball() + p = Person() + p.balls.append(b) + sess = create_session() + sess.add(p) + sess.flush() + + @testing.resolve_artifact_names + def testpostupdate_m2o(self): + """A cycle between two rows, with a post_update on the many-to-one""" + mapper(Ball, ball) + mapper(Person, person, properties=dict( + balls=relation(Ball, + primaryjoin=ball.c.person_id == person.c.id, + remote_side=ball.c.person_id, + post_update=False, + cascade="all, delete-orphan"), + favorite=relation(Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + remote_side=person.c.favorite_ball_id, + post_update=True))) + + b = Ball(data='some data') + p = Person(data='some data') + p.balls.append(b) + p.balls.append(Ball(data='some data')) + p.balls.append(Ball(data='some data')) + p.balls.append(Ball(data='some data')) + p.favorite = b + 
sess = create_session() + sess.add(b) + sess.add(p) + + self.assert_sql_execution( + testing.db, + sess.flush, + RegexSQL("^INSERT INTO person", {'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " + "WHERE person.id = :person_id", + lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id} + ), + ) + + sess.delete(p) + + self.assert_sql_execution( + testing.db, + sess.flush, + ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " + "WHERE person.id = :person_id", + lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}), + ExactSQL("DELETE FROM ball WHERE ball.id = :id", None), # lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]) + ExactSQL("DELETE FROM person WHERE person.id = :id", lambda ctx:[{'id': p.id}]) + ) + + @testing.resolve_artifact_names + def testpostupdate_o2m(self): + """A cycle between two rows, with a post_update on the one-to-many""" + + mapper(Ball, ball) + mapper(Person, person, properties=dict( + balls=relation(Ball, + primaryjoin=ball.c.person_id == person.c.id, + remote_side=ball.c.person_id, + cascade="all, delete-orphan", + post_update=True, + backref='person'), + favorite=relation(Ball, + primaryjoin=person.c.favorite_ball_id == ball.c.id, + remote_side=person.c.favorite_ball_id))) + + b = Ball(data='some data') + p = Person(data='some data') + p.balls.append(b) + b2 = Ball(data='some data') + p.balls.append(b2) + b3 = Ball(data='some data') + p.balls.append(b3) + b4 = Ball(data='some data') + p.balls.append(b4) + p.favorite = b + sess = create_session() + sess.add_all((b,p,b2,b3,b4)) + + self.assert_sql_execution( + testing.db, + sess.flush, + 
CompiledSQL("INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {'person_id':None, 'data':'some data'}), + + CompiledSQL("INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {'person_id':None, 'data':'some data'}), + + CompiledSQL("INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {'person_id':None, 'data':'some data'}), + + CompiledSQL("INSERT INTO ball (person_id, data) " + "VALUES (:person_id, :data)", + {'person_id':None, 'data':'some data'}), + + CompiledSQL("INSERT INTO person (favorite_ball_id, data) " + "VALUES (:favorite_ball_id, :data)", + lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}), + + AllOf( + CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id':p.id,'ball_id':b.id}), + + CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id':p.id,'ball_id':b2.id}), + + CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id':p.id,'ball_id':b3.id}), + + CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id':p.id,'ball_id':b4.id}) + ), + ) + + sess.delete(p) + + self.assert_sql_execution(testing.db, sess.flush, + AllOf(CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id': None, 'ball_id': b.id}), + + CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id': None, 'ball_id': b2.id}), + + CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id': None, 'ball_id': b3.id}), + + CompiledSQL("UPDATE ball SET person_id=:person_id " + "WHERE ball.id = :ball_id", + lambda ctx:{'person_id': None, 'ball_id': b4.id})), + + CompiledSQL("DELETE FROM person WHERE person.id = :id", + lambda ctx:[{'id':p.id}]), + + CompiledSQL("DELETE FROM ball WHERE 
ball.id = :id", + lambda ctx:[{'id': b.id}, + {'id': b2.id}, + {'id': b3.id}, + {'id': b4.id}]) + ) + + +class SelfReferentialPostUpdateTest(_base.MappedTest): + """Post_update on a single self-referential mapper""" + + @classmethod + def define_tables(cls, metadata): + Table('node', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('path', String(50), nullable=False), + Column('parent_id', Integer, + ForeignKey('node.id'), nullable=True), + Column('prev_sibling_id', Integer, + ForeignKey('node.id'), nullable=True), + Column('next_sibling_id', Integer, + ForeignKey('node.id'), nullable=True)) + + @classmethod + def setup_classes(cls): + class Node(_base.BasicEntity): + def __init__(self, path=''): + self.path = path + + @testing.resolve_artifact_names + def testbasic(self): + """Post_update only fires off when needed. + + This test case used to produce many superfluous update statements, + particularly upon delete + + """ + + mapper(Node, node, properties={ + 'children': relation( + Node, + primaryjoin=node.c.id==node.c.parent_id, + lazy=True, + cascade="all", + backref=backref("parent", remote_side=node.c.id) + ), + 'prev_sibling': relation( + Node, + primaryjoin=node.c.prev_sibling_id==node.c.id, + remote_side=node.c.id, + lazy=True, + uselist=False), + 'next_sibling': relation( + Node, + primaryjoin=node.c.next_sibling_id==node.c.id, + remote_side=node.c.id, + lazy=True, + uselist=False, + post_update=True)}) + + session = create_session() + + def append_child(parent, child): + if parent.children: + parent.children[-1].next_sibling = child + child.prev_sibling = parent.children[-1] + parent.children.append(child) + + def remove_child(parent, child): + child.parent = None + node = child.next_sibling + node.prev_sibling = child.prev_sibling + child.prev_sibling.next_sibling = node + session.delete(child) + root = Node('root') + + about = Node('about') + cats = Node('cats') + stories = Node('stories') + bruce = 
Node('bruce') + + append_child(root, about) + assert(about.prev_sibling is None) + append_child(root, cats) + assert(cats.prev_sibling is about) + assert(cats.next_sibling is None) + assert(about.next_sibling is cats) + assert(about.prev_sibling is None) + append_child(root, stories) + append_child(root, bruce) + session.add(root) + session.flush() + + remove_child(root, cats) + # pre-trigger lazy loader on 'cats' to make the test easier + cats.children + self.assert_sql_execution( + testing.db, + session.flush, + CompiledSQL("UPDATE node SET prev_sibling_id=:prev_sibling_id " + "WHERE node.id = :node_id", + lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}), + + CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " + "WHERE node.id = :node_id", + lambda ctx:{'next_sibling_id':stories.id, 'node_id':about.id}), + + CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " + "WHERE node.id = :node_id", + lambda ctx:{'next_sibling_id':None, 'node_id':cats.id}), + + CompiledSQL("DELETE FROM node WHERE node.id = :id", + lambda ctx:[{'id':cats.id}]) + ) + + +class SelfReferentialPostUpdateTest2(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table("a_table", metadata, + Column("id", Integer(), primary_key=True), + Column("fui", String(128)), + Column("b", Integer(), ForeignKey("a_table.id"))) + + @classmethod + def setup_classes(cls): + class A(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def testbasic(self): + """ + Test that post_update remembers to be involved in update operations as + well, since it replaces the normal dependency processing completely + [ticket:413] + + """ + + mapper(A, a_table, properties={ + 'foo': relation(A, + remote_side=[a_table.c.id], + post_update=True)}) + + session = create_session() + + f1 = A(fui="f1") + session.add(f1) + session.flush() + + f2 = A(fui="f2", foo=f1) + + # at this point f1 is already inserted. 
but we need post_update + # to fire off anyway + session.add(f2) + session.flush() + session.expunge_all() + + f1 = session.query(A).get(f1.id) + f2 = session.query(A).get(f2.id) + assert f2.foo is f1 + + diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py new file mode 100644 index 000000000..b063780ac --- /dev/null +++ b/test/orm/test_defaults.py @@ -0,0 +1,133 @@ + +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base +from sqlalchemy.test.testing import eq_ + + +class TriggerDefaultsTest(_base.MappedTest): + __requires__ = ('row_triggers',) + + @classmethod + def define_tables(cls, metadata): + dt = Table('dt', metadata, + Column('id', Integer, primary_key=True), + Column('col1', String(20)), + Column('col2', String(20), + server_default=sa.schema.FetchedValue()), + Column('col3', String(20), + sa.schema.FetchedValue(for_update=True)), + Column('col4', String(20), + sa.schema.FetchedValue(), + sa.schema.FetchedValue(for_update=True))) + for ins in ( + sa.DDL("CREATE TRIGGER dt_ins AFTER INSERT ON dt " + "FOR EACH ROW BEGIN " + "UPDATE dt SET col2='ins', col4='ins' " + "WHERE dt.id = NEW.id; END", + on='sqlite'), + sa.DDL("CREATE TRIGGER dt_ins ON dt AFTER INSERT AS " + "UPDATE dt SET col2='ins', col4='ins' " + "WHERE dt.id IN (SELECT id FROM inserted);", + on='mssql'), + ): + if testing.against(ins.on): + break + else: + ins = sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt " + "FOR EACH ROW BEGIN " + "SET NEW.col2='ins'; SET NEW.col4='ins'; END") + ins.execute_at('after-create', dt) + sa.DDL("DROP TRIGGER dt_ins").execute_at('before-drop', dt) + + + for up in ( + sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt " + "FOR EACH ROW BEGIN " + "UPDATE dt SET col3='up', col4='up' " + "WHERE dt.id = OLD.id; END", + 
on='sqlite'), + sa.DDL("CREATE TRIGGER dt_up ON dt AFTER UPDATE AS " + "UPDATE dt SET col3='up', col4='up' " + "WHERE dt.id IN (SELECT id FROM deleted);", + on='mssql'), + ): + if testing.against(up.on): + break + else: + up = sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " + "FOR EACH ROW BEGIN " + "SET NEW.col3='up'; SET NEW.col4='up'; END") + up.execute_at('after-create', dt) + sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt) + + + @classmethod + def setup_classes(cls): + class Default(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Default, dt) + + @testing.resolve_artifact_names + def test_insert(self): + + d1 = Default(id=1) + + eq_(d1.col1, None) + eq_(d1.col2, None) + eq_(d1.col3, None) + eq_(d1.col4, None) + + session = create_session() + session.add(d1) + session.flush() + + eq_(d1.col1, None) + eq_(d1.col2, 'ins') + eq_(d1.col3, None) + # don't care which trigger fired + assert d1.col4 in ('ins', 'up') + + @testing.resolve_artifact_names + def test_update(self): + d1 = Default(id=1) + + session = create_session() + session.add(d1) + session.flush() + d1.col1 = 'set' + session.flush() + + eq_(d1.col1, 'set') + eq_(d1.col2, 'ins') + eq_(d1.col3, 'up') + eq_(d1.col4, 'up') + +class ExcludedDefaultsTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + dt = Table('dt', metadata, + Column('id', Integer, primary_key=True), + Column('col1', String(20), default="hello"), + ) + + @testing.resolve_artifact_names + def test_exclude(self): + class Foo(_base.ComparableEntity): + pass + mapper(Foo, dt, exclude_properties=('col1',)) + + f1 = Foo() + sess = create_session() + sess.add(f1) + sess.flush() + eq_(dt.select().execute().fetchall(), [(1, "hello")]) + diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py new file mode 100644 index 000000000..00d64119e --- /dev/null +++ b/test/orm/test_deprecations.py @@ -0,0 +1,486 @@ +"""The collection of 
modern alternatives to deprecated & removed functionality. + +Collects specimens of old ORM code and explicitly covers the recommended +modern (i.e. not deprecated) alternative to them. The tests snippets here can +be migrated directly to the wiki, docs, etc. + +""" +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, sessionmaker +from test.orm import _base + + +class QueryAlternativesTest(_base.MappedTest): + '''Collects modern idioms for Queries + + The docstring for each test case serves as miniature documentation about + the deprecated use case, and the test body illustrates (and covers) the + intended replacement code to accomplish the same task. + + Documenting the "old way" including the argument signature helps these + cases remain useful to readers even after the deprecated method has been + removed from the modern codebase. 
+ + Format: + + def test_deprecated_thing(self): + """Query.methodname(old, arg, **signature) + + output = session.query(User).deprecatedmethod(inputs) + + """ + # 0.4+ + output = session.query(User).newway(inputs) + assert output is correct + + # 0.5+ + output = session.query(User).evennewerway(inputs) + assert output is correct + + ''' + + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table('users_table', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(64))) + + Table('addresses_table', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('users_table.id')), + Column('email_address', String(128)), + Column('purpose', String(16)), + Column('bounces', Integer, default=0)) + + @classmethod + def setup_classes(cls): + class User(_base.BasicEntity): + pass + + class Address(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users_table, properties=dict( + addresses=relation(Address, backref='user'), + )) + mapper(Address, addresses_table) + + @classmethod + def fixtures(cls): + return dict( + users_table=( + ('id', 'name'), + (1, 'jack'), + (2, 'ed'), + (3, 'fred'), + (4, 'chuck')), + + addresses_table=( + ('id', 'user_id', 'email_address', 'purpose', 'bounces'), + (1, 1, 'jack@jack.home', 'Personal', 0), + (2, 1, 'jack@jack.bizz', 'Work', 1), + (3, 2, 'ed@foo.bar', 'Personal', 0), + (4, 3, 'fred@the.fred', 'Personal', 10))) + + + ###################################################################### + + @testing.resolve_artifact_names + def test_override_get(self): + """MapperExtension.get() + + x = session.query.get(5) + + """ + from sqlalchemy.orm.query import Query + cache = {} + class MyQuery(Query): + def get(self, ident, **kwargs): + if ident in cache: + return cache[ident] + else: + x = super(MyQuery, self).get(ident) + cache[ident] = x + return x + + session = 
sessionmaker(query_cls=MyQuery)() + + ad1 = session.query(Address).get(1) + assert ad1 in cache.values() + + @testing.resolve_artifact_names + def test_load(self): + """x = session.query(Address).load(1) + + x = session.load(Address, 1) + + """ + + session = create_session() + ad1 = session.query(Address).populate_existing().get(1) + assert bool(ad1) + + + @testing.resolve_artifact_names + def test_apply_max(self): + """Query.apply_max(col) + + max = session.query(Address).apply_max(Address.bounces) + + """ + session = create_session() + + # 0.5.0 + maxes = list(session.query(Address).values(func.max(Address.bounces))) + max = maxes[0][0] + assert max == 10 + + max = session.query(func.max(Address.bounces)).one()[0] + assert max == 10 + + @testing.resolve_artifact_names + def test_apply_min(self): + """Query.apply_min(col) + + min = session.query(Address).apply_min(Address.bounces) + + """ + session = create_session() + + # 0.5.0 + mins = list(session.query(Address).values(func.min(Address.bounces))) + min = mins[0][0] + assert min == 0 + + min = session.query(func.min(Address.bounces)).one()[0] + assert min == 0 + + @testing.resolve_artifact_names + def test_apply_avg(self): + """Query.apply_avg(col) + + avg = session.query(Address).apply_avg(Address.bounces) + + """ + session = create_session() + + avgs = list(session.query(Address).values(func.avg(Address.bounces))) + avg = avgs[0][0] + assert avg > 0 and avg < 10 + + avg = session.query(func.avg(Address.bounces)).one()[0] + assert avg > 0 and avg < 10 + + @testing.resolve_artifact_names + def test_apply_sum(self): + """Query.apply_sum(col) + + avg = session.query(Address).apply_avg(Address.bounces) + + """ + session = create_session() + + avgs = list(session.query(Address).values(func.sum(Address.bounces))) + avg = avgs[0][0] + assert avg == 11 + + avg = session.query(func.sum(Address.bounces)).one()[0] + assert avg == 11 + + @testing.resolve_artifact_names + def test_count_by(self): + """Query.count_by(*args, 
**params) + + num = session.query(Address).count_by(purpose='Personal') + + # old-style implicit *_by join + num = session.query(User).count_by(purpose='Personal') + + """ + session = create_session() + + num = session.query(Address).filter_by(purpose='Personal').count() + assert num == 3, num + + num = (session.query(User).join('addresses'). + filter(Address.purpose=='Personal')).count() + assert num == 3, num + + @testing.resolve_artifact_names + def test_count_whereclause(self): + """Query.count(whereclause=None, params=None, **kwargs) + + num = session.query(Address).count(address_table.c.bounces > 1) + + """ + session = create_session() + + num = session.query(Address).filter(Address.bounces > 1).count() + assert num == 1, num + + @testing.resolve_artifact_names + def test_execute(self): + """Query.execute(clauseelement, params=None, *args, **kwargs) + + users = session.query(User).execute(users_table.select()) + + """ + session = create_session() + + users = session.query(User).from_statement(users_table.select()).all() + assert len(users) == 4 + + @testing.resolve_artifact_names + def test_get_by(self): + """Query.get_by(*args, **params) + + user = session.query(User).get_by(name='ed') + + # 0.3-style implicit *_by join + user = session.query(User).get_by(email_addresss='fred@the.fred') + + """ + session = create_session() + + user = session.query(User).filter_by(name='ed').first() + assert user.name == 'ed' + + user = (session.query(User).join('addresses'). 
+ filter(Address.email_address=='fred@the.fred')).first() + assert user.name == 'fred' + + user = session.query(User).filter( + User.addresses.any(Address.email_address=='fred@the.fred')).first() + assert user.name == 'fred' + + @testing.resolve_artifact_names + def test_instances_entities(self): + """Query.instances(cursor, *mappers_or_columns, **kwargs) + + sel = users_table.join(addresses_table).select(use_labels=True) + res = session.query(User).instances(sel.execute(), Address) + + """ + session = create_session() + + sel = users_table.join(addresses_table).select(use_labels=True) + res = list(session.query(User, Address).instances(sel.execute())) + + assert len(res) == 4 + cola, colb = res[0] + assert isinstance(cola, User) and isinstance(colb, Address) + + @testing.resolve_artifact_names + def test_join_by(self): + """Query.join_by(*args, **params) + + TODO + """ + session = create_session() + + + @testing.resolve_artifact_names + def test_join_to(self): + """Query.join_to(key) + + TODO + """ + session = create_session() + + + @testing.resolve_artifact_names + def test_join_via(self): + """Query.join_via(keys) + + TODO + """ + session = create_session() + + + @testing.resolve_artifact_names + def test_list(self): + """Query.list() + + users = session.query(User).list() + + """ + session = create_session() + + users = session.query(User).all() + assert len(users) == 4 + + @testing.resolve_artifact_names + def test_scalar(self): + """Query.scalar() + + user = session.query(User).filter(User.id==1).scalar() + + """ + session = create_session() + + user = session.query(User).filter(User.id==1).first() + assert user.id==1 + + @testing.resolve_artifact_names + def test_select(self): + """Query.select(arg=None, **kwargs) + + users = session.query(User).select(users_table.c.name != None) + + """ + session = create_session() + + users = session.query(User).filter(User.name != None).all() + assert len(users) == 4 + + @testing.resolve_artifact_names + def 
test_select_by(self): + """Query.select_by(*args, **params) + + users = session.query(User).select_by(name='fred') + + # 0.3 magic join on *_by methods + users = session.query(User).select_by(email_address='fred@the.fred') + + """ + session = create_session() + + users = session.query(User).filter_by(name='fred').all() + assert len(users) == 1 + + users = session.query(User).filter(User.name=='fred').all() + assert len(users) == 1 + + users = (session.query(User).join('addresses'). + filter_by(email_address='fred@the.fred')).all() + assert len(users) == 1 + + users = session.query(User).filter(User.addresses.any( + Address.email_address == 'fred@the.fred')).all() + assert len(users) == 1 + + @testing.resolve_artifact_names + def test_selectfirst(self): + """Query.selectfirst(arg=None, **kwargs) + + bounced = session.query(Address).selectfirst( + addresses_table.c.bounces > 0) + + """ + session = create_session() + + bounced = session.query(Address).filter(Address.bounces > 0).first() + assert bounced.bounces > 0 + + @testing.resolve_artifact_names + def test_selectfirst_by(self): + """Query.selectfirst_by(*args, **params) + + onebounce = session.query(Address).selectfirst_by(bounces=1) + + # 0.3 magic join on *_by methods + onebounce_user = session.query(User).selectfirst_by(bounces=1) + + """ + session = create_session() + + onebounce = session.query(Address).filter_by(bounces=1).first() + assert onebounce.bounces == 1 + + onebounce_user = (session.query(User).join('addresses'). + filter_by(bounces=1)).first() + assert onebounce_user.name == 'jack' + + onebounce_user = (session.query(User).join('addresses'). 
+ filter(Address.bounces == 1)).first() + assert onebounce_user.name == 'jack' + + onebounce_user = session.query(User).filter(User.addresses.any( + Address.bounces == 1)).first() + assert onebounce_user.name == 'jack' + + @testing.resolve_artifact_names + def test_selectone(self): + """Query.selectone(arg=None, **kwargs) + + ed = session.query(User).selectone(users_table.c.name == 'ed') + + """ + session = create_session() + + ed = session.query(User).filter(User.name == 'jack').one() + + @testing.resolve_artifact_names + def test_selectone_by(self): + """Query.selectone_by + + ed = session.query(User).selectone_by(name='ed') + + # 0.3 magic join on *_by methods + ed = session.query(User).selectone_by(email_address='ed@foo.bar') + + """ + session = create_session() + + ed = session.query(User).filter_by(name='jack').one() + + ed = session.query(User).filter(User.name == 'jack').one() + + ed = session.query(User).join('addresses').filter( + Address.email_address == 'ed@foo.bar').one() + + ed = session.query(User).filter(User.addresses.any( + Address.email_address == 'ed@foo.bar')).one() + + @testing.resolve_artifact_names + def test_select_statement(self): + """Query.select_statement(statement, **params) + + users = session.query(User).select_statement(users_table.select()) + + """ + session = create_session() + + users = session.query(User).from_statement(users_table.select()).all() + assert len(users) == 4 + + @testing.resolve_artifact_names + def test_select_text(self): + """Query.select_text(text, **params) + + users = session.query(User).select_text('SELECT * FROM users_table') + + """ + session = create_session() + + users = (session.query(User). 
+ from_statement('SELECT * FROM users_table')).all() + assert len(users) == 4 + + @testing.resolve_artifact_names + def test_select_whereclause(self): + """Query.select_whereclause(whereclause=None, params=None, **kwargs) + + + users = session,query(User).select_whereclause(users.c.name=='ed') + users = session.query(User).select_whereclause("name='ed'") + + """ + session = create_session() + + users = session.query(User).filter(User.name=='ed').all() + assert len(users) == 1 and users[0].name == 'ed' + + users = session.query(User).filter("name='ed'").all() + assert len(users) == 1 and users[0].name == 'ed' + + diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py new file mode 100644 index 000000000..f2089a435 --- /dev/null +++ b/test/orm/test_dynamic.py @@ -0,0 +1,561 @@ +from sqlalchemy.test.testing import eq_ +import operator +from sqlalchemy.orm import dynamic_loader, backref +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, desc, select, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, Query, attributes +from sqlalchemy.orm.dynamic import AppenderMixin +from sqlalchemy.test.testing import eq_ +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures + + +class DynamicTest(_fixtures.FixtureTest): + @testing.resolve_artifact_names + def test_basic(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + q = sess.query(User) + + u = q.filter(User.id==7).first() + eq_([User(id=7, + addresses=[Address(id=1, email_address='jack@bean.com')])], + q.filter(User.id==7).all()) + eq_(self.static.user_address_result, q.all()) + + @testing.resolve_artifact_names + def test_order_by(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u = 
sess.query(User).get(8) + eq_(list(u.addresses.order_by(desc(Address.email_address))), [Address(email_address=u'ed@wood.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@bettyboop.com')]) + + @testing.resolve_artifact_names + def test_configured_order_by(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), order_by=desc(Address.email_address)) + }) + sess = create_session() + u = sess.query(User).get(8) + eq_(list(u.addresses), [Address(email_address=u'ed@wood.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@bettyboop.com')]) + + # test cancellation of None, replacement with something else + eq_( + list(u.addresses.order_by(None).order_by(Address.email_address)), + [Address(email_address=u'ed@bettyboop.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@wood.com')] + ) + + # test cancellation of None, replacement with nothing + eq_( + set(u.addresses.order_by(None)), + set([Address(email_address=u'ed@bettyboop.com'), Address(email_address=u'ed@lala.com'), Address(email_address=u'ed@wood.com')]) + ) + + @testing.resolve_artifact_names + def test_count(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u = sess.query(User).first() + eq_(u.addresses.count(), 1) + + @testing.resolve_artifact_names + def test_backref(self): + mapper(Address, addresses, properties={ + 'user':relation(User, backref=backref('addresses', lazy='dynamic')) + }) + mapper(User, users) + + sess = create_session() + ad = sess.query(Address).get(1) + def go(): + ad.user = None + self.assert_sql_count(testing.db, go, 1) + sess.flush() + u = sess.query(User).get(7) + assert ad not in u.addresses + + @testing.resolve_artifact_names + def test_no_count(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + q = sess.query(User) + + # 
dynamic collection cannot implement __len__() (at least one that + # returns a live database result), else additional count() queries are + # issued when evaluating in a list context + def go(): + eq_([User(id=7, + addresses=[Address(id=1, + email_address='jack@bean.com')])], + q.filter(User.id==7).all()) + self.assert_sql_count(testing.db, go, 2) + + @testing.resolve_artifact_names + def test_m2m(self): + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy="dynamic", + backref=backref('orders', lazy="dynamic")) + }) + mapper(Item, items) + + sess = create_session() + o1 = Order(id=15, description="order 10") + i1 = Item(id=10, description="item 8") + o1.items.append(i1) + sess.add(o1) + sess.flush() + + assert o1 in i1.orders.all() + assert i1 in o1.items.all() + + @testing.resolve_artifact_names + def test_transient_detached(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u1 = User() + u1.addresses.append(Address()) + assert u1.addresses.count() == 1 + assert u1.addresses[0] == Address() + + @testing.resolve_artifact_names + def test_custom_query(self): + class MyQuery(Query): + pass + + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), + query_class=MyQuery) + }) + sess = create_session() + u = User() + sess.add(u) + + col = u.addresses + assert isinstance(col, Query) + assert isinstance(col, MyQuery) + assert hasattr(col, 'append') + assert type(col).__name__ == 'AppenderMyQuery' + + q = col.limit(1) + assert isinstance(q, Query) + assert isinstance(q, MyQuery) + assert not hasattr(q, 'append') + assert type(q).__name__ == 'MyQuery' + + @testing.resolve_artifact_names + def test_custom_query_with_custom_mixin(self): + class MyAppenderMixin(AppenderMixin): + def add(self, items): + if isinstance(items, list): + for item in items: + self.append(item) + else: + self.append(items) + + class MyQuery(Query): + 
pass + + class MyAppenderQuery(MyAppenderMixin, MyQuery): + query_class = MyQuery + + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), + query_class=MyAppenderQuery) + }) + sess = create_session() + u = User() + sess.add(u) + + col = u.addresses + assert isinstance(col, Query) + assert isinstance(col, MyQuery) + assert hasattr(col, 'append') + assert hasattr(col, 'add') + assert type(col).__name__ == 'MyAppenderQuery' + + q = col.limit(1) + assert isinstance(q, Query) + assert isinstance(q, MyQuery) + assert not hasattr(q, 'append') + assert not hasattr(q, 'add') + assert type(q).__name__ == 'MyQuery' + + +class SessionTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_events(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u1 = User(name='jack') + a1 = Address(email_address='foo') + sess.add_all([u1, a1]) + sess.flush() + + assert testing.db.scalar(select([func.count(1)]).where(addresses.c.user_id!=None)) == 0 + u1 = sess.query(User).get(u1.id) + u1.addresses.append(a1) + sess.flush() + + assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [ + (a1.id, u1.id, 'foo') + ] + + u1.addresses.remove(a1) + sess.flush() + assert testing.db.scalar(select([func.count(1)]).where(addresses.c.user_id!=None)) == 0 + + u1.addresses.append(a1) + sess.flush() + assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [ + (a1.id, u1.id, 'foo') + ] + + a2= Address(email_address='bar') + u1.addresses.remove(a1) + u1.addresses.append(a2) + sess.flush() + assert testing.db.execute(select([addresses]).where(addresses.c.user_id!=None)).fetchall() == [ + (a2.id, u1.id, 'bar') + ] + + + @testing.resolve_artifact_names + def test_merge(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), 
order_by=addresses.c.email_address) + }) + sess = create_session() + u1 = User(name='jack') + a1 = Address(email_address='a1') + a2 = Address(email_address='a2') + a3 = Address(email_address='a3') + + u1.addresses.append(a2) + u1.addresses.append(a3) + + sess.add_all([u1, a1]) + sess.flush() + + u1 = User(id=u1.id, name='jack') + u1.addresses.append(a1) + u1.addresses.append(a3) + u1 = sess.merge(u1) + assert attributes.get_history(u1, 'addresses') == ( + [a1], + [a3], + [a2] + ) + + sess.flush() + + eq_( + list(u1.addresses), + [a1, a3] + ) + + @testing.resolve_artifact_names + def test_flush(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session() + u1 = User(name='jack') + u2 = User(name='ed') + u2.addresses.append(Address(email_address='foo@bar.com')) + u1.addresses.append(Address(email_address='lala@hoho.com')) + sess.add_all((u1, u2)) + sess.flush() + + from sqlalchemy.orm import attributes + eq_(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], [])) + + sess.expunge_all() + + # test the test fixture a little bit + assert User(name='jack', addresses=[Address(email_address='wrong')]) != sess.query(User).first() + assert User(name='jack', addresses=[Address(email_address='lala@hoho.com')]) == sess.query(User).first() + + assert [ + User(name='jack', addresses=[Address(email_address='lala@hoho.com')]), + User(name='ed', addresses=[Address(email_address='foo@bar.com')]) + ] == sess.query(User).all() + + @testing.resolve_artifact_names + def test_hasattr(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + u1 = User(name='jack') + + assert 'addresses' not in u1.__dict__.keys() + u1.addresses = [Address(email_address='test')] + assert 'addresses' in dir(u1) + + @testing.resolve_artifact_names + def test_collection_set(self): + mapper(User, users, properties={ + 
'addresses':dynamic_loader(mapper(Address, addresses), order_by=addresses.c.email_address) + }) + sess = create_session(autoflush=True, autocommit=False) + u1 = User(name='jack') + a1 = Address(email_address='a1') + a2 = Address(email_address='a2') + a3 = Address(email_address='a3') + a4 = Address(email_address='a4') + + sess.add(u1) + u1.addresses = [a1, a3] + assert list(u1.addresses) == [a1, a3] + u1.addresses = [a1, a2, a4] + assert list(u1.addresses) == [a1, a2, a4] + u1.addresses = [a2, a3] + assert list(u1.addresses) == [a2, a3] + u1.addresses = [] + assert list(u1.addresses) == [] + + + + + @testing.resolve_artifact_names + def test_rollback(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses)) + }) + sess = create_session(expire_on_commit=False, autocommit=False, autoflush=True) + u1 = User(name='jack') + u1.addresses.append(Address(email_address='lala@hoho.com')) + sess.add(u1) + sess.flush() + sess.commit() + u1.addresses.append(Address(email_address='foo@bar.com')) + eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')]) + sess.rollback() + eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com')]) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_delete_nocascade(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), order_by=Address.id, backref='user') + }) + sess = create_session(autoflush=True) + u = User(name='ed') + u.addresses.append(Address(email_address='a')) + u.addresses.append(Address(email_address='b')) + u.addresses.append(Address(email_address='c')) + u.addresses.append(Address(email_address='d')) + u.addresses.append(Address(email_address='e')) + u.addresses.append(Address(email_address='f')) + sess.add(u) + + assert Address(email_address='c') == u.addresses[2] + sess.delete(u.addresses[2]) + sess.delete(u.addresses[4]) + sess.delete(u.addresses[3]) + 
assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) + + sess.expunge_all() + u = sess.query(User).get(u.id) + + sess.delete(u) + + # u.addresses relation will have to force the load + # of all addresses so that they can be updated + sess.flush() + sess.close() + + assert testing.db.scalar(addresses.count(addresses.c.user_id != None)) ==0 + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_delete_cascade(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), order_by=Address.id, backref='user', cascade="all, delete-orphan") + }) + sess = create_session(autoflush=True) + u = User(name='ed') + u.addresses.append(Address(email_address='a')) + u.addresses.append(Address(email_address='b')) + u.addresses.append(Address(email_address='c')) + u.addresses.append(Address(email_address='d')) + u.addresses.append(Address(email_address='e')) + u.addresses.append(Address(email_address='f')) + sess.add(u) + + assert Address(email_address='c') == u.addresses[2] + sess.delete(u.addresses[2]) + sess.delete(u.addresses[4]) + sess.delete(u.addresses[3]) + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) + + sess.expunge_all() + u = sess.query(User).get(u.id) + + sess.delete(u) + + # u.addresses relation will have to force the load + # of all addresses so that they can be updated + sess.flush() + sess.close() + + assert testing.db.scalar(addresses.count()) ==0 + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_remove_orphans(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), order_by=Address.id, cascade="all, delete-orphan", backref='user') + }) + sess = create_session(autoflush=True) + u = User(name='ed') + u.addresses.append(Address(email_address='a')) + 
u.addresses.append(Address(email_address='b')) + u.addresses.append(Address(email_address='c')) + u.addresses.append(Address(email_address='d')) + u.addresses.append(Address(email_address='e')) + u.addresses.append(Address(email_address='f')) + sess.add(u) + + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='c'), + Address(email_address='d'), Address(email_address='e'), Address(email_address='f')] == sess.query(Address).all() + + assert Address(email_address='c') == u.addresses[2] + + try: + del u.addresses[3] + assert False + except TypeError, e: + assert "doesn't support item deletion" in str(e), str(e) + + for a in u.addresses.filter(Address.email_address.in_(['c', 'e', 'f'])): + u.addresses.remove(a) + + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == list(u.addresses) + + assert [Address(email_address='a'), Address(email_address='b'), Address(email_address='d')] == sess.query(Address).all() + + sess.delete(u) + sess.close() + + +def _create_backref_test(autoflush, saveuser): + + @testing.resolve_artifact_names + def test_backref(self): + mapper(User, users, properties={ + 'addresses':dynamic_loader(mapper(Address, addresses), backref='user') + }) + sess = create_session(autoflush=autoflush) + + u = User(name='buffy') + + a = Address(email_address='foo@bar.com') + a.user = u + + if saveuser: + sess.add(u) + else: + sess.add(a) + + if not autoflush: + sess.flush() + + assert u in sess + assert a in sess + + self.assert_(list(u.addresses) == [a]) + + a.user = None + if not autoflush: + self.assert_(list(u.addresses) == [a]) + + if not autoflush: + sess.flush() + self.assert_(list(u.addresses) == []) + + test_backref = function_named( + test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""), + (saveuser and "_saveuser" or "_savead"))) + setattr(SessionTest, test_backref.__name__, test_backref) + +for autoflush in (False, True): + for saveuser in (False, True): + 
_create_backref_test(autoflush, saveuser) + +class DontDereferenceTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(40)), + Column('fullname', String(100)), + Column('password', String(15))) + + Table('addresses', metadata, + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + Column('user_id', Integer, ForeignKey('users.id'))) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class User(_base.ComparableEntity): + pass + + class Address(_base.ComparableEntity): + pass + + mapper(User, users, properties={ + 'addresses': relation(Address, backref='user', lazy='dynamic') + }) + mapper(Address, addresses) + + @testing.resolve_artifact_names + def test_no_deref(self): + session = create_session() + user = User() + user.name = 'joe' + user.fullname = 'Joe User' + user.password = 'Joe\'s secret' + address = Address() + address.email_address = 'joe@joesdomain.example' + address.user = user + session.add(user) + session.flush() + session.expunge_all() + + def query1(): + session = create_session(testing.db) + user = session.query(User).first() + return user.addresses.all() + + def query2(): + session = create_session(testing.db) + return session.query(User).first().addresses.all() + + def query3(): + session = create_session(testing.db) + user = session.query(User).first() + return session.query(User).first().addresses.all() + + eq_(query1(), [Address(email_address='joe@joesdomain.example')]) + eq_(query2(), [Address(email_address='joe@joesdomain.example')]) + eq_(query3(), [Address(email_address='joe@joesdomain.example')]) + + diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py new file mode 100644 index 000000000..384e0472f --- /dev/null +++ b/test/orm/test_eager_relations.py @@ -0,0 +1,1608 @@ +"""basic tests of eager loaded attributes""" + +from 
sqlalchemy.test.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy.orm import eagerload, deferred, undefer +from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import CompiledSQL +from test.orm import _base, _fixtures +import datetime + +class EagerTest(_fixtures.FixtureTest): + run_inserts = 'once' + run_deletes = None + + @testing.resolve_artifact_names + def test_basic(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=Address.id) + }) + sess = create_session() + q = sess.query(User) + + assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() + eq_(self.static.user_address_result, q.order_by(User.id).all()) + + @testing.resolve_artifact_names + def test_late_compile(self): + m = mapper(User, users) + sess = create_session() + sess.query(User).all() + m.add_property("addresses", relation(mapper(Address, addresses))) + + sess.expunge_all() + def go(): + eq_( + [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])], + sess.query(User).options(eagerload('addresses')).filter(User.id==7).all() + ) + self.assert_sql_count(testing.db, go, 1) + + + @testing.resolve_artifact_names + def test_no_orphan(self): + """An eagerly loaded child object is not marked as an orphan""" + mapper(User, users, properties={ + 'addresses':relation(Address, cascade="all,delete-orphan", lazy=False) + }) + mapper(Address, addresses) + + sess = create_session() + user = sess.query(User).get(7) + assert getattr(User, 'addresses').hasparent(sa.orm.attributes.instance_state(user.addresses[0]), optimistic=True) + assert not 
sa.orm.class_mapper(Address)._is_orphan(sa.orm.attributes.instance_state(user.addresses[0])) + + @testing.resolve_artifact_names + def test_orderby(self): + mapper(User, users, properties = { + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.email_address), + }) + q = create_session().query(User) + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=2, email_address='ed@wood.com') + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == q.order_by(User.id).all() + + @testing.resolve_artifact_names + def test_orderby_multi(self): + mapper(User, users, properties = { + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=[addresses.c.email_address, addresses.c.id]), + }) + q = create_session().query(User) + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=2, email_address='ed@wood.com') + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == q.order_by(User.id).all() + + @testing.resolve_artifact_names + def test_orderby_related(self): + """A regular mapper select on a single table can order by a relation to a second table""" + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=False, order_by=addresses.c.id), + )) + + q = create_session().query(User) + l = q.filter(User.id==Address.user_id).order_by(Address.email_address).all() + + assert [ + User(id=8, addresses=[ + Address(id=2, email_address='ed@wood.com'), + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=7, addresses=[ + Address(id=1) + ]), + ] == l + + 
@testing.resolve_artifact_names + def test_orderby_desc(self): + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=False, + order_by=[sa.desc(addresses.c.email_address)]), + )) + sess = create_session() + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=2, email_address='ed@wood.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=3, email_address='ed@bettyboop.com'), + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == sess.query(User).order_by(User.id).all() + + @testing.resolve_artifact_names + def test_deferred_fk_col(self): + User, Address, Dingaling = self.classes.get_all( + 'User', 'Address', 'Dingaling') + users, addresses, dingalings = self.tables.get_all( + 'users', 'addresses', 'dingalings') + + mapper(Address, addresses, properties={ + 'user_id':deferred(addresses.c.user_id), + 'user':relation(User, lazy=False) + }) + mapper(User, users) + + sess = create_session() + + for q in [ + sess.query(Address).filter(Address.id.in_([1, 4, 5])), + sess.query(Address).filter(Address.id.in_([1, 4, 5])).limit(3) + ]: + sess.expunge_all() + eq_(q.all(), + [Address(id=1, user=User(id=7)), + Address(id=4, user=User(id=8)), + Address(id=5, user=User(id=9))] + ) + + a = sess.query(Address).filter(Address.id==1).first() + def go(): + eq_(a.user_id, 7) + # assert that the eager loader added 'user_id' to the row and deferred + # loading of that col was disabled + self.assert_sql_count(testing.db, go, 0) + + # do the mapping in reverse + # (we would have just used an "addresses" backref but the test + # fixtures then require the whole backref to be set up, lazy loaders + # trigger, etc.) 
+ sa.orm.clear_mappers() + + mapper(Address, addresses, properties={ + 'user_id':deferred(addresses.c.user_id), + }) + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False)}) + + for q in [ + sess.query(User).filter(User.id==7), + sess.query(User).filter(User.id==7).limit(1) + ]: + sess.expunge_all() + eq_(q.all(), + [User(id=7, addresses=[Address(id=1)])] + ) + + sess.expunge_all() + u = sess.query(User).get(7) + def go(): + assert u.addresses[0].user_id==7 + # assert that the eager loader didn't have to affect 'user_id' here + # and that its still deferred + self.assert_sql_count(testing.db, go, 1) + + sa.orm.clear_mappers() + + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False)}) + mapper(Address, addresses, properties={ + 'user_id':deferred(addresses.c.user_id), + 'dingalings':relation(Dingaling, lazy=False)}) + mapper(Dingaling, dingalings, properties={ + 'address_id':deferred(dingalings.c.address_id)}) + sess.expunge_all() + def go(): + u = sess.query(User).get(8) + eq_(User(id=8, + addresses=[Address(id=2, dingalings=[Dingaling(id=1)]), + Address(id=3), + Address(id=4)]), + u) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_many_to_many(self): + Keyword, Item = self.Keyword, self.Item + keywords, item_keywords, items = self.tables.get_all( + 'keywords', 'item_keywords', 'items') + + mapper(Keyword, keywords) + mapper(Item, items, properties = dict( + keywords = relation(Keyword, secondary=item_keywords, + lazy=False, order_by=keywords.c.id))) + + q = create_session().query(Item).order_by(Item.id) + def go(): + assert self.static.item_keyword_result == q.all() + self.assert_sql_count(testing.db, go, 1) + + def go(): + eq_(self.static.item_keyword_result[0:2], + q.join('keywords').filter(Keyword.name == 'red').all()) + self.assert_sql_count(testing.db, go, 1) + + def go(): + eq_(self.static.item_keyword_result[0:2], + (q.join('keywords', aliased=True). 
+ filter(Keyword.name == 'red')).all()) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_eager_option(self): + Keyword, Item = self.Keyword, self.Item + keywords, item_keywords, items = self.tables.get_all( + 'keywords', 'item_keywords', 'items') + + mapper(Keyword, keywords) + mapper(Item, items, properties = dict( + keywords = relation(Keyword, secondary=item_keywords, lazy=True, + order_by=keywords.c.id))) + + q = create_session().query(Item) + + def go(): + eq_(self.static.item_keyword_result[0:2], + (q.options(eagerload('keywords')). + join('keywords').filter(keywords.c.name == 'red')).order_by(Item.id).all()) + + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_cyclical(self): + """A circular eager relationship breaks the cycle with a lazy loader""" + User, Address = self.User, self.Address + users, addresses = self.tables.get_all('users', 'addresses') + + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=False, + backref=sa.orm.backref('user', lazy=False), order_by=Address.id) + )) + assert sa.orm.class_mapper(User).get_property('addresses').lazy is False + assert sa.orm.class_mapper(Address).get_property('user').lazy is False + + sess = create_session() + eq_(self.static.user_address_result, sess.query(User).order_by(User.id).all()) + + @testing.resolve_artifact_names + def test_double(self): + """Eager loading with two relations simultaneously, from the same table, using aliases.""" + User, Address, Order = self.classes.get_all( + 'User', 'Address', 'Order') + users, addresses, orders = self.tables.get_all( + 'users', 'addresses', 'orders') + + openorders = sa.alias(orders, 'openorders') + closedorders = sa.alias(orders, 'closedorders') + + mapper(Address, addresses) + mapper(Order, orders) + + open_mapper = mapper(Order, openorders, non_primary=True) + closed_mapper = mapper(Order, closedorders, non_primary=True) + + 
mapper(User, users, properties = dict( + addresses = relation(Address, lazy=False, order_by=addresses.c.id), + open_orders = relation( + open_mapper, + primaryjoin=sa.and_(openorders.c.isopen == 1, + users.c.id==openorders.c.user_id), + lazy=False, order_by=openorders.c.id), + closed_orders = relation( + closed_mapper, + primaryjoin=sa.and_(closedorders.c.isopen == 0, + users.c.id==closedorders.c.user_id), + lazy=False, order_by=closedorders.c.id))) + + q = create_session().query(User).order_by(User.id) + + def go(): + assert [ + User( + id=7, + addresses=[Address(id=1)], + open_orders = [Order(id=3)], + closed_orders = [Order(id=1), Order(id=5)] + ), + User( + id=8, + addresses=[Address(id=2), Address(id=3), Address(id=4)], + open_orders = [], + closed_orders = [] + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders = [Order(id=4)], + closed_orders = [Order(id=2)] + ), + User(id=10) + + ] == q.all() + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_double_same_mappers(self): + """Eager loading with two relations simulatneously, from the same table, using aliases.""" + User, Address, Order = self.classes.get_all( + 'User', 'Address', 'Order') + users, addresses, orders = self.tables.get_all( + 'users', 'addresses', 'orders') + + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items': relation(Item, secondary=order_items, lazy=False, + order_by=items.c.id)}) + mapper(Item, items) + mapper(User, users, properties=dict( + addresses=relation(Address, lazy=False, order_by=addresses.c.id), + open_orders=relation( + Order, + primaryjoin=sa.and_(orders.c.isopen == 1, + users.c.id==orders.c.user_id), + lazy=False, order_by=orders.c.id), + closed_orders=relation( + Order, + primaryjoin=sa.and_(orders.c.isopen == 0, + users.c.id==orders.c.user_id), + lazy=False, order_by=orders.c.id))) + q = create_session().query(User).order_by(User.id) + + def go(): + assert [ + User(id=7, + addresses=[ + Address(id=1)], 
+ open_orders=[Order(id=3, + items=[ + Item(id=3), + Item(id=4), + Item(id=5)])], + closed_orders=[Order(id=1, + items=[ + Item(id=1), + Item(id=2), + Item(id=3)]), + Order(id=5, + items=[ + Item(id=5)])]), + User(id=8, + addresses=[ + Address(id=2), + Address(id=3), + Address(id=4)], + open_orders = [], + closed_orders = []), + User(id=9, + addresses=[ + Address(id=5)], + open_orders=[ + Order(id=4, + items=[ + Item(id=1), + Item(id=5)])], + closed_orders=[ + Order(id=2, + items=[ + Item(id=1), + Item(id=2), + Item(id=3)])]), + User(id=10) + ] == q.all() + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_no_false_hits(self): + """Eager loaders don't interpret main table columns as part of their eager load.""" + User, Address, Order = self.classes.get_all( + 'User', 'Address', 'Order') + users, addresses, orders = self.tables.get_all( + 'users', 'addresses', 'orders') + + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False), + 'orders':relation(Order, lazy=False) + }) + mapper(Address, addresses) + mapper(Order, orders) + + allusers = create_session().query(User).all() + + # using a textual select, the columns will be 'id' and 'name'. the + # eager loaders have aliases which should not hit on those columns, + # they should be required to locate only their aliased/fully table + # qualified column name. 
+ noeagers = create_session().query(User).from_statement("select * from users").all() + assert 'orders' not in noeagers[0].__dict__ + assert 'addresses' not in noeagers[0].__dict__ + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_limit(self): + """Limit operations combined with lazy-load relationships.""" + User, Item, Address, Order = self.classes.get_all( + 'User', 'Item', 'Address', 'Order') + users, items, order_items, orders, addresses = self.tables.get_all( + 'users', 'items', 'order_items', 'orders', 'addresses') + + mapper(Item, items) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) + }) + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id), + 'orders':relation(Order, lazy=True) + }) + + sess = create_session() + q = sess.query(User) + + if testing.against('mysql'): + l = q.limit(2).all() + assert self.static.user_all_result[:2] == l + else: + l = q.order_by(User.id).limit(2).offset(1).all() + print self.static.user_all_result[1:3] + print l + assert self.static.user_all_result[1:3] == l + + @testing.resolve_artifact_names + def test_distinct(self): + # this is an involved 3x union of the users table to get a lot of rows. + # then see if the "distinct" works its way out. you actually get the same + # result with or without the distinct, just via less or more rows. 
+ u2 = users.alias('u2') + s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=False), + }) + + sess = create_session() + q = sess.query(User) + + def go(): + l = q.filter(s.c.u2_id==User.id).distinct().all() + assert self.static.user_address_result == l + self.assert_sql_count(testing.db, go, 1) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_limit_2(self): + mapper(Keyword, keywords) + mapper(Item, items, properties = dict( + keywords = relation(Keyword, secondary=item_keywords, lazy=False, order_by=[keywords.c.id]), + )) + + sess = create_session() + q = sess.query(Item) + l = q.filter((Item.description=='item 2') | (Item.description=='item 5') | (Item.description=='item 3')).\ + order_by(Item.id).limit(2).all() + + assert self.static.item_keyword_result[1:3] == l + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_limit_3(self): + """test that the ORDER BY is propagated from the inner select to the outer select, when using the + 'wrapped' select statement resulting from the combination of eager loading and limit/offset clauses.""" + + mapper(Item, items) + mapper(Order, orders, properties = dict( + items = relation(Item, secondary=order_items, lazy=False) + )) + + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=False, order_by=addresses.c.id), + orders = relation(Order, lazy=False, order_by=orders.c.id), + )) + sess = create_session() + + q = sess.query(User) + + if not testing.against('maxdb', 'mssql'): + l = q.join('orders').order_by(Order.user_id.desc()).limit(2).offset(1) + assert [ + User(id=9, + orders=[Order(id=2), Order(id=4)], + addresses=[Address(id=5)] + ), + User(id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + addresses=[Address(id=1)] + ) + ] == 
l.all() + + l = q.join('addresses').order_by(Address.email_address.desc()).limit(1).offset(0) + assert [ + User(id=7, + orders=[Order(id=1), Order(id=3), Order(id=5)], + addresses=[Address(id=1)] + ) + ] == l.all() + + @testing.resolve_artifact_names + def test_limit_4(self): + # tests the LIMIT/OFFSET aliasing on a mapper against a select. original issue from ticket #904 + sel = sa.select([users, addresses.c.email_address], users.c.id==addresses.c.user_id).alias('useralias') + mapper(User, sel, properties={ + 'orders':relation(Order, primaryjoin=sel.c.id==orders.c.user_id, lazy=False) + }) + mapper(Order, orders) + + sess = create_session() + eq_(sess.query(User).first(), + User(name=u'jack',orders=[ + Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), + Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), + Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5)], + email_address=u'jack@bean.com',id=7) + ) + + @testing.resolve_artifact_names + def test_one_to_many_scalar(self): + mapper(User, users, properties = dict( + address = relation(mapper(Address, addresses), lazy=False, uselist=False) + )) + q = create_session().query(User) + + def go(): + l = q.filter(users.c.id == 7).all() + assert [User(id=7, address=Address(id=1))] == l + self.assert_sql_count(testing.db, go, 1) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_many_to_one(self): + mapper(Address, addresses, properties = dict( + user = relation(mapper(User, users), lazy=False) + )) + sess = create_session() + q = sess.query(Address) + + def go(): + a = q.filter(addresses.c.id==1).one() + assert a.user is not None + u1 = sess.query(User).get(7) + assert a.user is u1 + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_many_to_one_null(self): + """test that a many-to-one eager load which loads None does + not later trigger a lazy load. 
+ + """ + + # use a primaryjoin intended to defeat SA's usage of + # query.get() for a many-to-one lazyload + mapper(Order, orders, properties = dict( + address = relation(mapper(Address, addresses), + primaryjoin=and_( + addresses.c.id==orders.c.address_id, + addresses.c.email_address != None + ), + + lazy=False) + )) + sess = create_session() + + def go(): + o1 = sess.query(Order).options(lazyload('address')).filter(Order.id==5).one() + eq_(o1.address, None) + self.assert_sql_count(testing.db, go, 2) + + sess.expunge_all() + def go(): + o1 = sess.query(Order).filter(Order.id==5).one() + eq_(o1.address, None) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_one_and_many(self): + """tests eager load for a parent object with a child object that + contains a many-to-many relationship to a third object.""" + + mapper(User, users, properties={ + 'orders':relation(Order, lazy=False, order_by=orders.c.id) + }) + mapper(Item, items) + mapper(Order, orders, properties = dict( + items = relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) + )) + + q = create_session().query(User) + + l = q.filter("users.id in (7, 8, 9)").order_by("users.id") + + def go(): + assert self.static.user_order_result[0:3] == l.all() + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_double_with_aggregate(self): + max_orders_by_user = sa.select([sa.func.max(orders.c.id).label('order_id')], group_by=[orders.c.user_id]).alias('max_orders_by_user') + + max_orders = orders.select(orders.c.id==max_orders_by_user.c.order_id).alias('max_orders') + + mapper(Order, orders) + mapper(User, users, properties={ + 'orders':relation(Order, backref='user', lazy=False), + 'max_order':relation(mapper(Order, max_orders, non_primary=True), lazy=False, uselist=False) + }) + q = create_session().query(User) + + def go(): + assert [ + User(id=7, orders=[ + Order(id=1), + Order(id=3), + Order(id=5), + ], + 
max_order=Order(id=5) + ), + User(id=8, orders=[]), + User(id=9, orders=[Order(id=2),Order(id=4)], + max_order=Order(id=4) + ), + User(id=10), + ] == q.all() + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_wide(self): + mapper(Order, orders, properties={'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id)}) + mapper(Item, items) + mapper(User, users, properties = dict( + addresses = relation(mapper(Address, addresses), lazy = False, order_by=addresses.c.id), + orders = relation(Order, lazy = False, order_by=orders.c.id), + )) + q = create_session().query(User) + l = q.all() + assert self.static.user_all_result == q.order_by(User.id).all() + + @testing.resolve_artifact_names + def test_against_select(self): + """test eager loading of a mapper which is against a select""" + + s = sa.select([orders], orders.c.isopen==1).alias('openorders') + + mapper(Order, s, properties={ + 'user':relation(User, lazy=False) + }) + mapper(User, users) + mapper(Item, items) + + q = create_session().query(Order) + assert [ + Order(id=3, user=User(id=7)), + Order(id=4, user=User(id=9)) + ] == q.all() + + q = q.select_from(s.join(order_items).join(items)).filter(~Item.id.in_([1, 2, 5])) + assert [ + Order(id=3, user=User(id=7)), + ] == q.all() + + @testing.resolve_artifact_names + def test_aliasing(self): + """test that eager loading uses aliases to insulate the eager load from regular criterion against those tables.""" + + mapper(User, users, properties = dict( + addresses = relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id) + )) + q = create_session().query(User) + l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id).order_by(User.id) + assert self.static.user_address_result[1:2] == l.all() + +class AddEntityTest(_fixtures.FixtureTest): + run_inserts = 'once' + run_deletes = None + + @testing.resolve_artifact_names + def _assert_result(self): + return [ + ( + 
User(id=7, + addresses=[Address(id=1)] + ), + Order(id=1, + items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=3, + items=[Item(id=3), Item(id=4), Item(id=5)] + ), + ), + ( + User(id=7, + addresses=[Address(id=1)] + ), + Order(id=5, + items=[Item(id=5)] + ), + ), + ( + User(id=9, + addresses=[Address(id=5)] + ), + Order(id=2, + items=[Item(id=1), Item(id=2), Item(id=3)] + ), + ), + ( + User(id=9, + addresses=[Address(id=5)] + ), + Order(id=4, + items=[Item(id=1), Item(id=5)] + ), + ) + ] + + @testing.resolve_artifact_names + def test_mapper_configured(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False), + 'orders':relation(Order) + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id) + }) + mapper(Item, items) + + + sess = create_session() + oalias = sa.orm.aliased(Order) + def go(): + ret = sess.query(User, oalias).join(('orders', oalias)).order_by(User.id, oalias.id).all() + eq_(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_options(self): + mapper(User, users, properties={ + 'addresses':relation(Address), + 'orders':relation(Order) + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, order_by=items.c.id) + }) + mapper(Item, items) + + sess = create_session() + + oalias = sa.orm.aliased(Order) + def go(): + ret = sess.query(User, oalias).options(eagerload('addresses')).join(('orders', oalias)).order_by(User.id, oalias.id).all() + eq_(ret, self._assert_result()) + self.assert_sql_count(testing.db, go, 6) + + sess.expunge_all() + def go(): + ret = sess.query(User, oalias).options(eagerload('addresses'), eagerload(oalias.items)).join(('orders', oalias)).order_by(User.id, oalias.id).all() + eq_(ret, self._assert_result()) + 
self.assert_sql_count(testing.db, go, 1) + +class OrderBySecondaryTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('m2m', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + Column('bid', Integer, ForeignKey('b.id'))) + + Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + + @classmethod + def fixtures(cls): + return dict( + a=(('id', 'data'), + (1, 'a1'), + (2, 'a2')), + + b=(('id', 'data'), + (1, 'b1'), + (2, 'b2'), + (3, 'b3'), + (4, 'b4')), + + m2m=(('id', 'aid', 'bid'), + (2, 1, 1), + (4, 2, 4), + (1, 1, 3), + (6, 2, 2), + (3, 1, 2), + (5, 2, 3))) + + @testing.resolve_artifact_names + def test_ordering(self): + class A(_base.ComparableEntity):pass + class B(_base.ComparableEntity):pass + + mapper(A, a, properties={ + 'bs':relation(B, secondary=m2m, lazy=False, order_by=m2m.c.id) + }) + mapper(B, b) + + sess = create_session() + eq_(sess.query(A).all(), [A(data='a1', bs=[B(data='b3'), B(data='b1'), B(data='b2')]), A(bs=[B(data='b4'), B(data='b3'), B(data='b2')])]) + + +class SelfReferentialEagerTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('nodes', metadata, + Column('id', Integer, sa.Sequence('node_id_seq', optional=True), + primary_key=True), + Column('parent_id', Integer, ForeignKey('nodes.id')), + Column('data', String(30))) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_basic(self): + class Node(_base.ComparableEntity): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False, join_depth=3, order_by=nodes.c.id) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + 
n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.add(n1) + sess.flush() + sess.expunge_all() + def go(): + d = sess.query(Node).filter_by(data='n1').all()[0] + assert Node(data='n1', children=[ + Node(data='n11'), + Node(data='n12', children=[ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ]), + Node(data='n13') + ]) == d + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + def go(): + d = sess.query(Node).filter_by(data='n1').first() + assert Node(data='n1', children=[ + Node(data='n11'), + Node(data='n12', children=[ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ]), + Node(data='n13') + ]) == d + self.assert_sql_count(testing.db, go, 1) + + + @testing.resolve_artifact_names + def test_lazy_fallback_doesnt_affect_eager(self): + class Node(_base.ComparableEntity): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False, join_depth=1, order_by=nodes.c.id) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.add(n1) + sess.flush() + sess.expunge_all() + + # eager load with join depth 1. when eager load of 'n1' hits the + # children of 'n12', no columns are present, eager loader degrades to + # lazy loader; fine. but then, 'n12' is *also* in the first level of + # columns since we're loading the whole table. when those rows + # arrive, now we *can* eager load its children and an eager collection + # should be initialized. essentially the 'n12' instance is present in + # not just two different rows but two distinct sets of columns in this + # result set. 
+ def go(): + allnodes = sess.query(Node).order_by(Node.data).all() + n12 = allnodes[2] + assert n12.data == 'n12' + assert [ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ] == list(n12.children) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_with_deferred(self): + class Node(_base.ComparableEntity): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False, join_depth=3, order_by=nodes.c.id), + 'data':deferred(nodes.c.data) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + sess.add(n1) + sess.flush() + sess.expunge_all() + + def go(): + eq_( + Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), + sess.query(Node).order_by(Node.id).first(), + ) + self.assert_sql_count(testing.db, go, 4) + + sess.expunge_all() + + def go(): + assert Node(data='n1', children=[Node(data='n11'), Node(data='n12')]) == sess.query(Node).options(undefer('data')).order_by(Node.id).first() + self.assert_sql_count(testing.db, go, 3) + + sess.expunge_all() + + def go(): + assert Node(data='n1', children=[Node(data='n11'), Node(data='n12')]) == sess.query(Node).options(undefer('data'), undefer('children.data')).first() + self.assert_sql_count(testing.db, go, 1) + + + @testing.resolve_artifact_names + def test_options(self): + class Node(_base.ComparableEntity): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True, order_by=nodes.c.id) + }, order_by=nodes.c.id) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.add(n1) + sess.flush() + sess.expunge_all() + def go(): + d = 
sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() + assert Node(data='n1', children=[ + Node(data='n11'), + Node(data='n12', children=[ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ]), + Node(data='n13') + ]) == d + self.assert_sql_count(testing.db, go, 2) + + def go(): + d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() + + # test that the query isn't wrapping the initial query for eager loading. + self.assert_sql_execution(testing.db, go, + CompiledSQL( + "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes " + "WHERE nodes.data = :data_1 ORDER BY nodes.id LIMIT 1 OFFSET 0", + {'data_1': 'n1'} + ) + ) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_no_depth(self): + class Node(_base.ComparableEntity): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.add(n1) + sess.flush() + sess.expunge_all() + def go(): + d = sess.query(Node).filter_by(data='n1').first() + assert Node(data='n1', children=[ + Node(data='n11'), + Node(data='n12', children=[ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ]), + Node(data='n13') + ]) == d + self.assert_sql_count(testing.db, go, 3) + +class MixedSelfReferentialEagerTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('a_table', metadata, + Column('id', Integer, primary_key=True) + ) + + Table('b_table', metadata, + Column('id', Integer, primary_key=True), + Column('parent_b1_id', Integer, ForeignKey('b_table.id')), + Column('parent_a_id', 
Integer, ForeignKey('a_table.id')), + Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) + + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class A(_base.ComparableEntity): + pass + class B(_base.ComparableEntity): + pass + + mapper(A,a_table) + mapper(B,b_table,properties = { + 'parent_b1': relation(B, + remote_side = [b_table.c.id], + primaryjoin = (b_table.c.parent_b1_id ==b_table.c.id), + order_by = b_table.c.id + ), + 'parent_z': relation(A,lazy = True), + 'parent_b2': relation(B, + remote_side = [b_table.c.id], + primaryjoin = (b_table.c.parent_b2_id ==b_table.c.id), + order_by = b_table.c.id + ) + }); + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3)) + b_table.insert().execute( + dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None), + dict(id=2, parent_a_id=1, parent_b1_id=1, parent_b2_id=None), + dict(id=3, parent_a_id=1, parent_b1_id=1, parent_b2_id=2), + dict(id=4, parent_a_id=3, parent_b1_id=1, parent_b2_id=None), + dict(id=5, parent_a_id=3, parent_b1_id=None, parent_b2_id=2), + dict(id=6, parent_a_id=1, parent_b1_id=1, parent_b2_id=3), + dict(id=7, parent_a_id=2, parent_b1_id=None, parent_b2_id=3), + dict(id=8, parent_a_id=2, parent_b1_id=1, parent_b2_id=2), + dict(id=9, parent_a_id=None, parent_b1_id=1, parent_b2_id=None), + dict(id=10, parent_a_id=3, parent_b1_id=7, parent_b2_id=2), + dict(id=11, parent_a_id=3, parent_b1_id=1, parent_b2_id=8), + dict(id=12, parent_a_id=2, parent_b1_id=5, parent_b2_id=2), + dict(id=13, parent_a_id=3, parent_b1_id=4, parent_b2_id=4), + dict(id=14, parent_a_id=3, parent_b1_id=7, parent_b2_id=2), + ) + + @testing.resolve_artifact_names + def test_eager_load(self): + session = create_session() + def go(): + eq_( + session.query(B).options(eagerload('parent_b1'),eagerload('parent_b2'),eagerload('parent_z')). 
+ filter(B.id.in_([2, 8, 11])).order_by(B.id).all(), + [ + B(id=2, parent_z=A(id=1), parent_b1=B(id=1), parent_b2=None), + B(id=8, parent_z=A(id=2), parent_b1=B(id=1), parent_b2=B(id=2)), + B(id=11, parent_z=A(id=3), parent_b1=B(id=1), parent_b2=B(id=8)) + ] + ) + self.assert_sql_count(testing.db, go, 1) + +class SelfReferentialM2MEagerTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('widget', metadata, + Column('id', Integer, primary_key=True), + Column('name', sa.Unicode(40), nullable=False, unique=True), + ) + + Table('widget_rel', metadata, + Column('parent_id', Integer, ForeignKey('widget.id')), + Column('child_id', Integer, ForeignKey('widget.id')), + sa.UniqueConstraint('parent_id', 'child_id'), + ) + + @testing.resolve_artifact_names + def test_basic(self): + class Widget(_base.ComparableEntity): + pass + + mapper(Widget, widget, properties={ + 'children': relation(Widget, secondary=widget_rel, + primaryjoin=widget_rel.c.parent_id==widget.c.id, + secondaryjoin=widget_rel.c.child_id==widget.c.id, + lazy=False, join_depth=1, + ) + }) + + sess = create_session() + w1 = Widget(name=u'w1') + w2 = Widget(name=u'w2') + w1.children.append(w2) + sess.add(w1) + sess.flush() + sess.expunge_all() + + assert [Widget(name='w1', children=[Widget(name='w2')])] == sess.query(Widget).filter(Widget.name==u'w1').all() + +class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + 'orders':relation(Order, backref='user'), # o2m, m2o + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m + }) + mapper(Item, items, properties={ + 'keywords':relation(Keyword, secondary=item_keywords) #m2m + }) + mapper(Keyword, 
keywords) + + @testing.resolve_artifact_names + def test_two_entities(self): + sess = create_session() + + # two FROM clauses + def go(): + eq_( + [ + (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])), + (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])), + ], + sess.query(User, Order).filter(User.id==Order.user_id).\ + options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\ + order_by(User.id, Order.id).all(), + ) + self.assert_sql_count(testing.db, go, 1) + + # one FROM clause + def go(): + eq_( + [ + (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])), + (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])), + ], + sess.query(User, Order).join(User.orders).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\ + order_by(User.id, Order.id).all(), + ) + self.assert_sql_count(testing.db, go, 1) + + @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs") + @testing.resolve_artifact_names + def test_two_entities_with_joins(self): + sess = create_session() + + # two FROM clauses where there's a join on each one + def go(): + u1 = aliased(User) + o1 = aliased(Order) + eq_( + [ + ( + User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), + Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]), + User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), + Order(description=u'order 3', isopen=1, items=[Item(description=u'item 3'), Item(description=u'item 4'), Item(description=u'item 5')]) + ), + + ( + User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), + Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]), + 
User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), + Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')]) + ), + + ( + User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), + Order(description=u'order 4', isopen=1, items=[Item(description=u'item 1'), Item(description=u'item 5')]), + User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), + Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')]) + ), + ], + sess.query(User, Order, u1, o1).\ + join((Order, User.orders)).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\ + join((o1, u1.orders)).options(eagerload(u1.addresses), eagerload(o1.items)).filter(u1.id==7).\ + filter(Order.id 10) + assert res2.count() == 19 + + @testing.resolve_artifact_names + def test_options(self): + query = create_session().query(Foo) + class ext1(sa.orm.MapperExtension): + def populate_instance(self, mapper, selectcontext, row, instance, **flags): + instance.TEST = "hello world" + return sa.orm.EXT_CONTINUE + assert query.options(sa.orm.extension(ext1()))[0].TEST == "hello world" + + @testing.resolve_artifact_names + def test_order_by(self): + query = create_session().query(Foo) + assert query.order_by([Foo.bar])[0].bar == 0 + assert query.order_by([sa.desc(Foo.bar)])[0].bar == 99 + + @testing.resolve_artifact_names + def test_offset(self): + query = create_session().query(Foo) + assert list(query.order_by([Foo.bar]).offset(10))[0].bar == 10 + + @testing.resolve_artifact_names + def test_offset(self): + query = create_session().query(Foo) + assert len(list(query.limit(10))) == 10 + + +class GenerativeTest2(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('Table1', metadata, + Column('id', Integer, primary_key=True)) + Table('Table2', metadata, + Column('t1id', Integer, ForeignKey("Table1.id"), + primary_key=True), + Column('num', Integer, 
primary_key=True)) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class Obj1(_base.BasicEntity): + pass + class Obj2(_base.BasicEntity): + pass + + mapper(Obj1, Table1) + mapper(Obj2, Table2) + + @classmethod + def fixtures(cls): + return dict( + Table1=(('id',), + (1,), + (2,), + (3,), + (4,)), + Table2=(('num', 't1id'), + (1, 1), + (2, 1), + (3, 1), + (4, 2), + (5, 2), + (6, 3))) + + @testing.resolve_artifact_names + def test_distinct_count(self): + query = create_session().query(Obj1) + eq_(query.count(), 4) + + res = query.filter(sa.and_(Table1.c.id == Table2.c.t1id, + Table2.c.t1id == 1)) + eq_(res.count(), 3) + res = query.filter(sa.and_(Table1.c.id == Table2.c.t1id, + Table2.c.t1id == 1)).distinct() + eq_(res.count(), 1) + + +class RelationsTest(_fixtures.FixtureTest): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users, properties={ + 'orders':relation(mapper(Order, orders, properties={ + 'addresses':relation(mapper(Address, addresses))}))}) + + + @testing.resolve_artifact_names + def test_join(self): + """Query.join""" + + session = create_session() + q = (session.query(User).join(['orders', 'addresses']). + filter(Address.id == 1)) + eq_([User(id=7)], q.all()) + + @testing.resolve_artifact_names + def test_outer_join(self): + """Query.outerjoin""" + + session = create_session() + q = (session.query(User).outerjoin(['orders', 'addresses']). + filter(sa.or_(Order.id == None, Address.id == 1))) + eq_(set([User(id=7), User(id=8), User(id=10)]), + set(q.all())) + + @testing.resolve_artifact_names + def test_outer_join_count(self): + """test the join and outerjoin functions on Query""" + + session = create_session() + + q = (session.query(User).outerjoin(['orders', 'addresses']). 
+ filter(sa.or_(Order.id == None, Address.id == 1))) + eq_(q.count(), 4) + + @testing.resolve_artifact_names + def test_from(self): + session = create_session() + + sel = users.outerjoin(orders).outerjoin( + addresses, orders.c.address_id == addresses.c.id) + q = (session.query(User).select_from(sel). + filter(sa.or_(Order.id == None, Address.id == 1))) + eq_(set([User(id=7), User(id=8), User(id=10)]), + set(q.all())) + + +class CaseSensitiveTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('Table1', metadata, + Column('ID', Integer, primary_key=True)) + Table('Table2', metadata, + Column('T1ID', Integer, ForeignKey("Table1.ID"), + primary_key=True), + Column('NUM', Integer, primary_key=True)) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class Obj1(_base.BasicEntity): + pass + class Obj2(_base.BasicEntity): + pass + + mapper(Obj1, Table1) + mapper(Obj2, Table2) + + @classmethod + def fixtures(cls): + return dict( + Table1=(('ID',), + (1,), + (2,), + (3,), + (4,)), + Table2=(('NUM', 'T1ID'), + (1, 1), + (2, 1), + (3, 1), + (4, 2), + (5, 2), + (6, 3))) + + @testing.resolve_artifact_names + def test_distinct_count(self): + q = create_session(bind=testing.db).query(Obj1) + assert q.count() == 4 + res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)) + assert res.count() == 3 + res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)).distinct() + eq_(res.count(), 1) + + diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py new file mode 100644 index 000000000..b4c8f8601 --- /dev/null +++ b/test/orm/test_instrumentation.py @@ -0,0 +1,766 @@ + +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy import MetaData, Integer, ForeignKey, util +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, 
attributes, class_mapper, clear_mappers +from sqlalchemy.test.testing import eq_, ne_ +from sqlalchemy.util import function_named +from test.orm import _base + + +def modifies_instrumentation_finders(fn): + def decorated(*args, **kw): + pristine = attributes.instrumentation_finders[:] + try: + fn(*args, **kw) + finally: + del attributes.instrumentation_finders[:] + attributes.instrumentation_finders.extend(pristine) + return function_named(decorated, fn.func_name) + +def with_lookup_strategy(strategy): + def decorate(fn): + def wrapped(*args, **kw): + try: + attributes._install_lookup_strategy(strategy) + return fn(*args, **kw) + finally: + attributes._install_lookup_strategy(sa.util.symbol('native')) + return function_named(wrapped, fn.func_name) + return decorate + + +class InitTest(_base.ORMTest): + def fixture(self): + return Table('t', MetaData(), + Column('id', Integer, primary_key=True), + Column('type', Integer), + Column('x', Integer), + Column('y', Integer)) + + def register(self, cls, canary): + original_init = cls.__init__ + attributes.register_class(cls) + ne_(cls.__init__, original_init) + manager = attributes.manager_of_class(cls) + def on_init(state, instance, args, kwargs): + canary.append((cls, 'on_init', type(instance))) + manager.events.add_listener('on_init', on_init) + + def test_ai(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + + obj = A() + eq_(inits, [(A, '__init__')]) + + def test_A(self): + inits = [] + + class A(object): pass + self.register(A, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A)]) + + def test_Ai(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + def test_ai_B(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + + class B(A): pass + self.register(B, inits) + + obj = A() + eq_(inits, [(A, 
'__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (A, '__init__')]) + + def test_ai_Bi(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + + class B(A): + def __init__(self): + inits.append((B, '__init__')) + super(B, self).__init__() + self.register(B, inits) + + obj = A() + eq_(inits, [(A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')]) + + def test_Ai_bi(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): + def __init__(self): + inits.append((B, '__init__')) + super(B, self).__init__() + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')]) + + def test_Ai_Bi(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): + def __init__(self): + inits.append((B, '__init__')) + super(B, self).__init__() + self.register(B, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')]) + + def test_Ai_B(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): pass + self.register(B, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (A, '__init__')]) + + def test_Ai_Bi_Ci(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): + def __init__(self): + inits.append((B, '__init__')) + super(B, self).__init__() + self.register(B, inits) + + class C(B): + def __init__(self): + inits.append((C, '__init__')) + super(C, 
self).__init__() + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')]) + + del inits[:] + obj = C() + eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'), + (A, '__init__')]) + + def test_Ai_bi_Ci(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): + def __init__(self): + inits.append((B, '__init__')) + super(B, self).__init__() + + class C(B): + def __init__(self): + inits.append((C, '__init__')) + super(C, self).__init__() + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')]) + + del inits[:] + obj = C() + eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'), + (A, '__init__')]) + + def test_Ai_b_Ci(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): pass + + class C(B): + def __init__(self): + inits.append((C, '__init__')) + super(C, self).__init__() + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(A, 'on_init', B), (A, '__init__')]) + + del inits[:] + obj = C() + eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')]) + + def test_Ai_B_Ci(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): pass + self.register(B, inits) + + class C(B): + def __init__(self): + inits.append((C, '__init__')) + super(C, self).__init__() + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (A, '__init__')]) + + del inits[:] + obj = C() + 
eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')]) + + def test_Ai_B_C(self): + inits = [] + + class A(object): + def __init__(self): + inits.append((A, '__init__')) + self.register(A, inits) + + class B(A): pass + self.register(B, inits) + + class C(B): pass + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A), (A, '__init__')]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (A, '__init__')]) + + del inits[:] + obj = C() + eq_(inits, [(C, 'on_init', C), (A, '__init__')]) + + def test_A_Bi_C(self): + inits = [] + + class A(object): pass + self.register(A, inits) + + class B(A): + def __init__(self): + inits.append((B, '__init__')) + self.register(B, inits) + + class C(B): pass + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A)]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B), (B, '__init__')]) + + del inits[:] + obj = C() + eq_(inits, [(C, 'on_init', C), (B, '__init__')]) + + def test_A_B_Ci(self): + inits = [] + + class A(object): pass + self.register(A, inits) + + class B(A): pass + self.register(B, inits) + + class C(B): + def __init__(self): + inits.append((C, '__init__')) + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A)]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B)]) + + del inits[:] + obj = C() + eq_(inits, [(C, 'on_init', C), (C, '__init__')]) + + def test_A_B_C(self): + inits = [] + + class A(object): pass + self.register(A, inits) + + class B(A): pass + self.register(B, inits) + + class C(B): pass + self.register(C, inits) + + obj = A() + eq_(inits, [(A, 'on_init', A)]) + + del inits[:] + + obj = B() + eq_(inits, [(B, 'on_init', B)]) + + del inits[:] + obj = C() + eq_(inits, [(C, 'on_init', C)]) + + def test_defaulted_init(self): + class X(object): + def __init__(self_, a, b=123, c='abc'): + self_.a = a + self_.b = b + self_.c = c + attributes.register_class(X) + + o = X('foo') + eq_(o.a, 'foo') + eq_(o.b, 123) + eq_(o.c, 'abc') 
+ + class Y(object): + unique = object() + + class OutOfScopeForEval(object): + def __repr__(self_): + # misleading repr + return '123' + + outofscope = OutOfScopeForEval() + + def __init__(self_, u=unique, o=outofscope): + self_.u = u + self_.o = o + + attributes.register_class(Y) + + o = Y() + assert o.u is Y.unique + assert o.o is Y.outofscope + + +class MapperInitTest(_base.ORMTest): + + def fixture(self): + return Table('t', MetaData(), + Column('id', Integer, primary_key=True), + Column('type', Integer), + Column('x', Integer), + Column('y', Integer)) + + def test_partially_mapped_inheritance(self): + class A(object): + pass + + class B(A): + pass + + class C(B): + def __init__(self, x): + pass + + m = mapper(A, self.fixture()) + + # B is not mapped in the current implementation + assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, B) + + # C is not mapped in the current implementation + assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, C) + +class InstrumentationCollisionTest(_base.ORMTest): + def test_none(self): + class A(object): pass + attributes.register_class(A) + + mgr_factory = lambda cls: attributes.ClassManager(cls) + class B(object): + __sa_instrumentation_manager__ = staticmethod(mgr_factory) + attributes.register_class(B) + + class C(object): + __sa_instrumentation_manager__ = attributes.ClassManager + attributes.register_class(C) + + def test_single_down(self): + class A(object): pass + attributes.register_class(A) + + mgr_factory = lambda cls: attributes.ClassManager(cls) + class B(A): + __sa_instrumentation_manager__ = staticmethod(mgr_factory) + + assert_raises(TypeError, attributes.register_class, B) + + def test_single_up(self): + + class A(object): pass + # delay registration + + mgr_factory = lambda cls: attributes.ClassManager(cls) + class B(A): + __sa_instrumentation_manager__ = staticmethod(mgr_factory) + attributes.register_class(B) + assert_raises(TypeError, attributes.register_class, A) + + def 
test_diamond_b1(self): + mgr_factory = lambda cls: attributes.ClassManager(cls) + + class A(object): pass + class B1(A): pass + class B2(A): + __sa_instrumentation_manager__ = mgr_factory + class C(object): pass + + assert_raises(TypeError, attributes.register_class, B1) + + def test_diamond_b2(self): + mgr_factory = lambda cls: attributes.ClassManager(cls) + + class A(object): pass + class B1(A): pass + class B2(A): + __sa_instrumentation_manager__ = mgr_factory + class C(object): pass + + assert_raises(TypeError, attributes.register_class, B2) + + def test_diamond_c_b(self): + mgr_factory = lambda cls: attributes.ClassManager(cls) + + class A(object): pass + class B1(A): pass + class B2(A): + __sa_instrumentation_manager__ = mgr_factory + class C(object): pass + + attributes.register_class(C) + assert_raises(TypeError, attributes.register_class, B1) + + +class OnLoadTest(_base.ORMTest): + """Check that Events.on_load is not hit in regular attributes operations.""" + + def test_basic(self): + import pickle + + global A + class A(object): + pass + + def canary(instance): assert False + + try: + attributes.register_class(A) + manager = attributes.manager_of_class(A) + manager.events.add_listener('on_load', canary) + + a = A() + p_a = pickle.dumps(a) + re_a = pickle.loads(p_a) + finally: + del A + + @classmethod + def teardown_class(cls): + clear_mappers() + attributes._install_lookup_strategy(util.symbol('native')) + + +class ExtendedEventsTest(_base.ORMTest): + """Allow custom Events implementations.""" + + @modifies_instrumentation_finders + def test_subclassed(self): + class MyEvents(attributes.Events): + pass + class MyClassManager(attributes.ClassManager): + event_registry_factory = MyEvents + + attributes.instrumentation_finders.insert(0, lambda cls: MyClassManager) + + class A(object): pass + + attributes.register_class(A) + manager = attributes.manager_of_class(A) + assert isinstance(manager.events, MyEvents) + + + +class 
NativeInstrumentationTest(_base.ORMTest): + @with_lookup_strategy(sa.util.symbol('native')) + def test_register_reserved_attribute(self): + class T(object): pass + + attributes.register_class(T) + manager = attributes.manager_of_class(T) + + sa = attributes.ClassManager.STATE_ATTR + ma = attributes.ClassManager.MANAGER_ATTR + + fails = lambda method, attr: assert_raises( + KeyError, getattr(manager, method), attr, property()) + + fails('install_member', sa) + fails('install_member', ma) + fails('install_descriptor', sa) + fails('install_descriptor', ma) + + @with_lookup_strategy(sa.util.symbol('native')) + def test_mapped_stateattr(self): + t = Table('t', MetaData(), + Column('id', Integer, primary_key=True), + Column(attributes.ClassManager.STATE_ATTR, Integer)) + + class T(object): pass + + assert_raises(KeyError, mapper, T, t) + + @with_lookup_strategy(sa.util.symbol('native')) + def test_mapped_managerattr(self): + t = Table('t', MetaData(), + Column('id', Integer, primary_key=True), + Column(attributes.ClassManager.MANAGER_ATTR, Integer)) + + class T(object): pass + assert_raises(KeyError, mapper, T, t) + + +class MiscTest(_base.ORMTest): + """Seems basic, but not directly covered elsewhere!""" + + def test_compileonattr(self): + t = Table('t', MetaData(), + Column('id', Integer, primary_key=True), + Column('x', Integer)) + class A(object): pass + mapper(A, t) + + a = A() + assert a.id is None + + def test_compileonattr_rel(self): + m = MetaData() + t1 = Table('t1', m, + Column('id', Integer, primary_key=True), + Column('x', Integer)) + t2 = Table('t2', m, + Column('id', Integer, primary_key=True), + Column('t1_id', Integer, ForeignKey('t1.id'))) + class A(object): pass + class B(object): pass + mapper(A, t1, properties=dict(bs=relation(B))) + mapper(B, t2) + + a = A() + assert not a.bs + + def test_uninstrument(self): + class A(object):pass + + manager = attributes.register_class(A) + + assert attributes.manager_of_class(A) is manager + 
attributes.unregister_class(A) + assert attributes.manager_of_class(A) is None + + def test_compileonattr_rel_backref_a(self): + m = MetaData() + t1 = Table('t1', m, + Column('id', Integer, primary_key=True), + Column('x', Integer)) + t2 = Table('t2', m, + Column('id', Integer, primary_key=True), + Column('t1_id', Integer, ForeignKey('t1.id'))) + + class Base(object): + def __init__(self, *args, **kwargs): + pass + + for base in object, Base: + class A(base): pass + class B(base): pass + mapper(A, t1, properties=dict(bs=relation(B, backref='a'))) + mapper(B, t2) + + b = B() + assert b.a is None + a = A() + b.a = a + + session = create_session() + session.add(b) + assert a in session, "base is %s" % base + + def test_compileonattr_rel_backref_b(self): + m = MetaData() + t1 = Table('t1', m, + Column('id', Integer, primary_key=True), + Column('x', Integer)) + t2 = Table('t2', m, + Column('id', Integer, primary_key=True), + Column('t1_id', Integer, ForeignKey('t1.id'))) + + class Base(object): + def __init__(self): pass + class Base_AKW(object): + def __init__(self, *args, **kwargs): pass + + for base in object, Base, Base_AKW: + class A(base): pass + class B(base): pass + mapper(A, t1) + mapper(B, t2, properties=dict(a=relation(A, backref='bs'))) + + a = A() + b = B() + b.a = a + + session = create_session() + session.add(a) + assert b in session, 'base: %s' % base + + +class FinderTest(_base.ORMTest): + def test_standard(self): + class A(object): pass + + attributes.register_class(A) + + eq_(type(attributes.manager_of_class(A)), attributes.ClassManager) + + def test_nativeext_interfaceexact(self): + class A(object): + __sa_instrumentation_manager__ = sa.orm.interfaces.InstrumentationManager + + attributes.register_class(A) + ne_(type(attributes.manager_of_class(A)), attributes.ClassManager) + + def test_nativeext_submanager(self): + class Mine(attributes.ClassManager): pass + class A(object): + __sa_instrumentation_manager__ = Mine + + attributes.register_class(A) + 
eq_(type(attributes.manager_of_class(A)), Mine) + + @modifies_instrumentation_finders + def test_customfinder_greedy(self): + class Mine(attributes.ClassManager): pass + class A(object): pass + def find(cls): + return Mine + + attributes.instrumentation_finders.insert(0, find) + attributes.register_class(A) + eq_(type(attributes.manager_of_class(A)), Mine) + + @modifies_instrumentation_finders + def test_customfinder_pass(self): + class A(object): pass + def find(cls): + return None + + attributes.instrumentation_finders.insert(0, find) + attributes.register_class(A) + eq_(type(attributes.manager_of_class(A)), attributes.ClassManager) + + diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py new file mode 100644 index 000000000..819f29911 --- /dev/null +++ b/test/orm/test_lazy_relations.py @@ -0,0 +1,419 @@ +"""basic tests of lazy loaded attributes""" + +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import datetime +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import attributes +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures + + +class LazyTest(_fixtures.FixtureTest): + run_inserts = 'once' + run_deletes = None + + @testing.resolve_artifact_names + def test_basic(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=True) + }) + sess = create_session() + q = sess.query(User) + assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all() + + @testing.resolve_artifact_names + def test_needs_parent(self): + """test the error raised when parent object is not bound.""" + + mapper(User, users, properties={ + 
'addresses':relation(mapper(Address, addresses), lazy=True) + }) + sess = create_session() + q = sess.query(User) + u = q.filter(users.c.id == 7).first() + sess.expunge(u) + assert_raises(sa_exc.InvalidRequestError, getattr, u, 'addresses') + + @testing.resolve_artifact_names + def test_orderby(self): + mapper(User, users, properties = { + 'addresses':relation(mapper(Address, addresses), lazy=True, order_by=addresses.c.email_address), + }) + q = create_session().query(User) + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=2, email_address='ed@wood.com') + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == q.all() + + @testing.resolve_artifact_names + def test_orderby_secondary(self): + """tests that a regular mapper select on a single table can order by a relation to a second table""" + + mapper(Address, addresses) + + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=True), + )) + q = create_session().query(User) + l = q.filter(users.c.id==addresses.c.user_id).order_by(addresses.c.email_address).all() + assert [ + User(id=8, addresses=[ + Address(id=2, email_address='ed@wood.com'), + Address(id=3, email_address='ed@bettyboop.com'), + Address(id=4, email_address='ed@lala.com'), + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=7, addresses=[ + Address(id=1) + ]), + ] == l + + @testing.resolve_artifact_names + def test_orderby_desc(self): + mapper(Address, addresses) + + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=True, order_by=[sa.desc(addresses.c.email_address)]), + )) + sess = create_session() + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=8, addresses=[ + Address(id=2, email_address='ed@wood.com'), + Address(id=4, email_address='ed@lala.com'), + Address(id=3, email_address='ed@bettyboop.com'), + 
]), + User(id=9, addresses=[ + Address(id=5) + ]), + User(id=10, addresses=[]) + ] == sess.query(User).all() + + @testing.resolve_artifact_names + def test_no_orphan(self): + """test that a lazily loaded child object is not marked as an orphan""" + + mapper(User, users, properties={ + 'addresses':relation(Address, cascade="all,delete-orphan", lazy=True) + }) + mapper(Address, addresses) + + sess = create_session() + user = sess.query(User).get(7) + assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True) + assert not sa.orm.class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0])) + + @testing.resolve_artifact_names + def test_limit(self): + """test limit operations combined with lazy-load relationships.""" + + mapper(Item, items) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=True) + }) + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=True), + 'orders':relation(Order, lazy=True) + }) + + sess = create_session() + q = sess.query(User) + + if testing.against('maxdb', 'mssql'): + l = q.limit(2).all() + assert self.static.user_all_result[:2] == l + else: + l = q.limit(2).offset(1).all() + assert self.static.user_all_result[1:3] == l + + @testing.resolve_artifact_names + def test_distinct(self): + mapper(Item, items) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=True) + }) + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), lazy=True), + 'orders':relation(Order, lazy=True) + }) + + sess = create_session() + q = sess.query(User) + + # use a union all to get a lot of rows to join against + u2 = users.alias('u2') + s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') + print [key for key in s.c.keys()] + l = q.filter(s.c.u2_id==User.id).distinct().all() + assert 
self.static.user_all_result == l + + @testing.resolve_artifact_names + def test_one_to_many_scalar(self): + mapper(User, users, properties = dict( + address = relation(mapper(Address, addresses), lazy=True, uselist=False) + )) + q = create_session().query(User) + l = q.filter(users.c.id == 7).all() + assert [User(id=7, address=Address(id=1))] == l + + @testing.resolve_artifact_names + def test_many_to_one_binds(self): + mapper(Address, addresses, primary_key=[addresses.c.user_id, addresses.c.email_address]) + + mapper(User, users, properties = dict( + address = relation(Address, uselist=False, + primaryjoin=sa.and_(users.c.id==addresses.c.user_id, addresses.c.email_address=='ed@bettyboop.com') + ) + )) + q = create_session().query(User) + eq_( + [ + User(id=7, address=None), + User(id=8, address=Address(id=3)), + User(id=9, address=None), + User(id=10, address=None), + ], + list(q) + ) + + + @testing.resolve_artifact_names + def test_double(self): + """tests lazy loading with two relations simulatneously, from the same table, using aliases. 
""" + + openorders = sa.alias(orders, 'openorders') + closedorders = sa.alias(orders, 'closedorders') + + mapper(Address, addresses) + + mapper(Order, orders) + + open_mapper = mapper(Order, openorders, non_primary=True) + closed_mapper = mapper(Order, closedorders, non_primary=True) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy = True), + open_orders = relation(open_mapper, primaryjoin = sa.and_(openorders.c.isopen == 1, users.c.id==openorders.c.user_id), lazy=True), + closed_orders = relation(closed_mapper, primaryjoin = sa.and_(closedorders.c.isopen == 0, users.c.id==closedorders.c.user_id), lazy=True) + )) + q = create_session().query(User) + + assert [ + User( + id=7, + addresses=[Address(id=1)], + open_orders = [Order(id=3)], + closed_orders = [Order(id=1), Order(id=5)] + ), + User( + id=8, + addresses=[Address(id=2), Address(id=3), Address(id=4)], + open_orders = [], + closed_orders = [] + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders = [Order(id=4)], + closed_orders = [Order(id=2)] + ), + User(id=10) + + ] == q.all() + + sess = create_session() + user = sess.query(User).get(7) + assert [Order(id=1), Order(id=5)] == create_session().query(closed_mapper).with_parent(user, property='closed_orders').all() + assert [Order(id=3)] == create_session().query(open_mapper).with_parent(user, property='open_orders').all() + + @testing.resolve_artifact_names + def test_many_to_many(self): + + mapper(Keyword, keywords) + mapper(Item, items, properties = dict( + keywords = relation(Keyword, secondary=item_keywords, lazy=True), + )) + + q = create_session().query(Item) + assert self.static.item_keyword_result == q.all() + + assert self.static.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all() + + @testing.resolve_artifact_names + def test_uses_get(self): + """test that a simple many-to-one lazyload optimizes to use query.get().""" + + for pj in ( + None, + users.c.id==addresses.c.user_id, + 
addresses.c.user_id==users.c.id + ): + mapper(Address, addresses, properties = dict( + user = relation(mapper(User, users), lazy=True, primaryjoin=pj) + )) + + sess = create_session() + + # load address + a1 = sess.query(Address).filter_by(email_address="ed@wood.com").one() + + # load user that is attached to the address + u1 = sess.query(User).get(8) + + def go(): + # lazy load of a1.user should get it from the session + assert a1.user is u1 + self.assert_sql_count(testing.db, go, 0) + sa.orm.clear_mappers() + + @testing.resolve_artifact_names + def test_many_to_one(self): + mapper(Address, addresses, properties = dict( + user = relation(mapper(User, users), lazy=True) + )) + sess = create_session() + q = sess.query(Address) + a = q.filter(addresses.c.id==1).one() + + assert a.user is not None + + u1 = sess.query(User).get(7) + + assert a.user is u1 + + @testing.resolve_artifact_names + def test_backrefs_dont_lazyload(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user') + }) + mapper(Address, addresses) + sess = create_session() + ad = sess.query(Address).filter_by(id=1).one() + assert ad.user.id == 7 + def go(): + ad.user = None + assert ad.user is None + self.assert_sql_count(testing.db, go, 0) + + u1 = sess.query(User).filter_by(id=7).one() + def go(): + assert ad not in u1.addresses + self.assert_sql_count(testing.db, go, 1) + + sess.expire(u1, ['addresses']) + def go(): + assert ad in u1.addresses + self.assert_sql_count(testing.db, go, 1) + + sess.expire(u1, ['addresses']) + ad2 = Address() + def go(): + ad2.user = u1 + assert ad2.user is u1 + self.assert_sql_count(testing.db, go, 0) + + def go(): + assert ad2 in u1.addresses + self.assert_sql_count(testing.db, go, 1) + + +class M2OGetTest(_fixtures.FixtureTest): + run_inserts = 'once' + run_deletes = None + + @testing.resolve_artifact_names + def test_m2o_noload(self): + """test that a NULL foreign key doesn't trigger a lazy load""" + mapper(User, users) + + 
mapper(Address, addresses, properties={ + 'user':relation(User) + }) + + sess = create_session() + ad1 = Address(email_address='somenewaddress', id=12) + sess.add(ad1) + sess.flush() + sess.expunge_all() + + ad2 = sess.query(Address).get(1) + ad3 = sess.query(Address).get(ad1.id) + def go(): + # one lazy load + assert ad2.user.name == 'jack' + # no lazy load + assert ad3.user is None + self.assert_sql_count(testing.db, go, 1) + +class CorrelatedTest(_base.MappedTest): + + @classmethod + def define_tables(self, meta): + Table('user_t', meta, + Column('id', Integer, primary_key=True), + Column('name', String(50))) + + Table('stuff', meta, + Column('id', Integer, primary_key=True), + Column('date', sa.Date), + Column('user_id', Integer, ForeignKey('user_t.id'))) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + user_t.insert().execute( + {'id':1, 'name':'user1'}, + {'id':2, 'name':'user2'}, + {'id':3, 'name':'user3'}) + + stuff.insert().execute( + {'id':1, 'user_id':1, 'date':datetime.date(2007, 10, 15)}, + {'id':2, 'user_id':1, 'date':datetime.date(2007, 12, 15)}, + {'id':3, 'user_id':1, 'date':datetime.date(2007, 11, 15)}, + {'id':4, 'user_id':2, 'date':datetime.date(2008, 1, 15)}, + {'id':5, 'user_id':3, 'date':datetime.date(2007, 6, 15)}) + + @testing.resolve_artifact_names + def test_correlated_lazyload(self): + class User(_base.ComparableEntity): + pass + + class Stuff(_base.ComparableEntity): + pass + + mapper(Stuff, stuff) + + stuff_view = sa.select([stuff.c.id]).where(stuff.c.user_id==user_t.c.id).correlate(user_t).order_by(sa.desc(stuff.c.date)).limit(1) + + mapper(User, user_t, properties={ + 'stuff':relation(Stuff, primaryjoin=sa.and_(user_t.c.id==stuff.c.user_id, stuff.c.id==(stuff_view.as_scalar()))) + }) + + sess = create_session() + + eq_(sess.query(User).all(), [ + User(name='user1', stuff=[Stuff(date=datetime.date(2007, 12, 15), id=2)]), + User(name='user2', stuff=[Stuff(id=4, date=datetime.date(2008, 1 , 15))]), + 
User(name='user3', stuff=[Stuff(id=5, date=datetime.date(2007, 6, 15))]) + ]) + + diff --git a/test/orm/test_lazytest1.py b/test/orm/test_lazytest1.py new file mode 100644 index 000000000..f76cb3203 --- /dev/null +++ b/test/orm/test_lazytest1.py @@ -0,0 +1,92 @@ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base + + +class LazyTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('infos', metadata, + Column('pk', Integer, primary_key=True), + Column('info', String(128))) + + Table('data', metadata, + Column('data_pk', Integer, primary_key=True), + Column('info_pk', Integer, + ForeignKey('infos.pk')), + Column('timeval', Integer), + Column('data_val', String(128))) + + Table('rels', metadata, + Column('rel_pk', Integer, primary_key=True), + Column('info_pk', Integer, + ForeignKey('infos.pk')), + Column('start', Integer), + Column('finish', Integer)) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + infos.insert().execute( + {'pk':1, 'info':'pk_1_info'}, + {'pk':2, 'info':'pk_2_info'}, + {'pk':3, 'info':'pk_3_info'}, + {'pk':4, 'info':'pk_4_info'}, + {'pk':5, 'info':'pk_5_info'}) + + rels.insert().execute( + {'rel_pk':1, 'info_pk':1, 'start':10, 'finish':19}, + {'rel_pk':2, 'info_pk':1, 'start':100, 'finish':199}, + {'rel_pk':3, 'info_pk':2, 'start':20, 'finish':29}, + {'rel_pk':4, 'info_pk':3, 'start':13, 'finish':23}, + {'rel_pk':5, 'info_pk':5, 'start':15, 'finish':25}) + + data.insert().execute( + {'data_pk':1, 'info_pk':1, 'timeval':11, 'data_val':'11_data'}, + {'data_pk':2, 'info_pk':1, 'timeval':9, 'data_val':'9_data'}, + {'data_pk':3, 'info_pk':1, 'timeval':13, 'data_val':'13_data'}, + {'data_pk':4, 'info_pk':2, 'timeval':23, 'data_val':'23_data'}, + {'data_pk':5, 
'info_pk':2, 'timeval':13, 'data_val':'13_data'}, + {'data_pk':6, 'info_pk':1, 'timeval':15, 'data_val':'15_data'}) + + @testing.resolve_artifact_names + def testone(self): + """A lazy load which has multiple join conditions. + + Including two that are against the same column in the child table. + + """ + class Information(object): + pass + + class Relation(object): + pass + + class Data(object): + pass + + session = create_session() + + mapper(Data, data) + mapper(Relation, rels, properties={ + 'datas': relation(Data, + primaryjoin=sa.and_( + rels.c.info_pk == + data.c.info_pk, + data.c.timeval >= rels.c.start, + data.c.timeval <= rels.c.finish), + foreign_keys=[data.c.info_pk])}) + mapper(Information, infos, properties={ + 'rels': relation(Relation) + }) + + info = session.query(Information).get(1) + assert info + assert len(info.rels) == 2 + assert len(info.rels[0].datas) == 3 + + diff --git a/test/orm/test_manytomany.py b/test/orm/test_manytomany.py new file mode 100644 index 000000000..dcd547f80 --- /dev/null +++ b/test/orm/test_manytomany.py @@ -0,0 +1,330 @@ +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base + + +class M2MTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('place', metadata, + Column('place_id', Integer, sa.Sequence('pid_seq', optional=True), + primary_key=True), + Column('name', String(30), nullable=False)) + + Table('transition', metadata, + Column('transition_id', Integer, + sa.Sequence('tid_seq', optional=True), primary_key=True), + Column('name', String(30), nullable=False)) + + Table('place_thingy', metadata, + Column('thingy_id', Integer, sa.Sequence('thid_seq', optional=True), + primary_key=True), 
+ Column('place_id', Integer, ForeignKey('place.place_id'), + nullable=False), + Column('name', String(30), nullable=False)) + + # association table #1 + Table('place_input', metadata, + Column('place_id', Integer, ForeignKey('place.place_id')), + Column('transition_id', Integer, + ForeignKey('transition.transition_id'))) + + # association table #2 + Table('place_output', metadata, + Column('place_id', Integer, ForeignKey('place.place_id')), + Column('transition_id', Integer, + ForeignKey('transition.transition_id'))) + + Table('place_place', metadata, + Column('pl1_id', Integer, ForeignKey('place.place_id')), + Column('pl2_id', Integer, ForeignKey('place.place_id'))) + + @classmethod + def setup_classes(cls): + class Place(_base.BasicEntity): + def __init__(self, name=None): + self.name = name + def __str__(self): + return "(Place '%s')" % self.name + __repr__ = __str__ + + class PlaceThingy(_base.BasicEntity): + def __init__(self, name=None): + self.name = name + + class Transition(_base.BasicEntity): + def __init__(self, name=None): + self.name = name + self.inputs = [] + self.outputs = [] + def __repr__(self): + return ' '.join((object.__repr__(self), + repr(self.inputs), + repr(self.outputs))) + + @testing.resolve_artifact_names + def test_error(self): + mapper(Place, place, properties={ + 'transitions':relation(Transition, secondary=place_input, backref='places') + }) + mapper(Transition, transition, properties={ + 'places':relation(Place, secondary=place_input, backref='transitions') + }) + assert_raises_message(sa.exc.ArgumentError, "Error creating backref", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_circular(self): + """test a many-to-many relationship from a table to itself.""" + + Place.mapper = mapper(Place, place) + + Place.mapper.add_property('places', relation( + Place.mapper, secondary=place_place, primaryjoin=place.c.place_id==place_place.c.pl1_id, + secondaryjoin=place.c.place_id==place_place.c.pl2_id, + 
order_by=place_place.c.pl2_id, + lazy=True, + )) + + sess = create_session() + p1 = Place('place1') + p2 = Place('place2') + p3 = Place('place3') + p4 = Place('place4') + p5 = Place('place5') + p6 = Place('place6') + p7 = Place('place7') + sess.add_all((p1, p2, p3, p4, p5, p6, p7)) + p1.places.append(p2) + p1.places.append(p3) + p5.places.append(p6) + p6.places.append(p1) + p7.places.append(p1) + p1.places.append(p5) + p4.places.append(p3) + p3.places.append(p4) + sess.flush() + + sess.expunge_all() + l = sess.query(Place).order_by(place.c.place_id).all() + (p1, p2, p3, p4, p5, p6, p7) = l + assert p1.places == [p2,p3,p5] + assert p5.places == [p6] + assert p7.places == [p1] + assert p6.places == [p1] + assert p4.places == [p3] + assert p3.places == [p4] + assert p2.places == [] + + for p in l: + pp = p.places + print "Place " + str(p) +" places " + repr(pp) + + [sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7] + sess.flush() + + @testing.resolve_artifact_names + def test_double(self): + """test that a mapper can have two eager relations to the same table, via + two different association tables. 
aliases are required.""" + + Place.mapper = mapper(Place, place, properties = { + 'thingies':relation(mapper(PlaceThingy, place_thingy), lazy=False) + }) + + Transition.mapper = mapper(Transition, transition, properties = dict( + inputs = relation(Place.mapper, place_output, lazy=False), + outputs = relation(Place.mapper, place_input, lazy=False), + ) + ) + + tran = Transition('transition1') + tran.inputs.append(Place('place1')) + tran.outputs.append(Place('place2')) + tran.outputs.append(Place('place3')) + sess = create_session() + sess.add(tran) + sess.flush() + + sess.expunge_all() + r = sess.query(Transition).all() + self.assert_unordered_result(r, Transition, + {'name': 'transition1', + 'inputs': (Place, [{'name':'place1'}]), + 'outputs': (Place, [{'name':'place2'}, {'name':'place3'}]) + }) + + @testing.resolve_artifact_names + def test_bidirectional(self): + """tests a many-to-many backrefs""" + Place.mapper = mapper(Place, place) + Transition.mapper = mapper(Transition, transition, properties = dict( + inputs = relation(Place.mapper, place_output, lazy=True, backref='inputs'), + outputs = relation(Place.mapper, place_input, lazy=True, backref='outputs'), + ) + ) + + t1 = Transition('transition1') + t2 = Transition('transition2') + t3 = Transition('transition3') + p1 = Place('place1') + p2 = Place('place2') + p3 = Place('place3') + + t1.inputs.append(p1) + t1.inputs.append(p2) + t1.outputs.append(p3) + t2.inputs.append(p1) + p2.inputs.append(t2) + p3.inputs.append(t2) + p1.outputs.append(t1) + sess = create_session() + sess.add_all((t1, t2, t3,p1, p2, p3)) + sess.flush() + + self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])}) + self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])}) + + +class M2MTest2(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('student', metadata, + Column('name', String(20), primary_key=True)) + + 
Table('course', metadata, + Column('name', String(20), primary_key=True)) + + Table('enroll', metadata, + Column('student_id', String(20), ForeignKey('student.name'), + primary_key=True), + Column('course_id', String(20), ForeignKey('course.name'), + primary_key=True)) + + @classmethod + def setup_classes(cls): + class Student(_base.BasicEntity): + def __init__(self, name=''): + self.name = name + class Course(_base.BasicEntity): + def __init__(self, name=''): + self.name = name + + @testing.resolve_artifact_names + def test_circular(self): + + mapper(Student, student) + mapper(Course, course, properties={ + 'students': relation(Student, enroll, backref='courses')}) + + sess = create_session() + s1 = Student('Student1') + c1 = Course('Course1') + c2 = Course('Course2') + c3 = Course('Course3') + s1.courses.append(c1) + s1.courses.append(c2) + c3.students.append(s1) + self.assert_(len(s1.courses) == 3) + self.assert_(len(c1.students) == 1) + sess.add(s1) + sess.flush() + sess.expunge_all() + s = sess.query(Student).filter_by(name='Student1').one() + c = sess.query(Course).filter_by(name='Course3').one() + self.assert_(len(s.courses) == 3) + del s.courses[1] + self.assert_(len(s.courses) == 2) + + @testing.resolve_artifact_names + def test_dupliates_raise(self): + """test constraint error is raised for dupe entries in a list""" + + mapper(Student, student) + mapper(Course, course, properties={ + 'students': relation(Student, enroll, backref='courses')}) + + sess = create_session() + s1 = Student("s1") + c1 = Course('c1') + s1.courses.append(c1) + s1.courses.append(c1) + sess.add(s1) + assert_raises(sa.exc.DBAPIError, sess.flush) + + @testing.resolve_artifact_names + def test_delete(self): + """A many-to-many table gets cleared out with deletion from the backref side""" + + mapper(Student, student) + mapper(Course, course, properties = { + 'students': relation(Student, enroll, lazy=True, + backref='courses')}) + + sess = create_session() + s1 = Student('Student1') + 
c1 = Course('Course1') + c2 = Course('Course2') + c3 = Course('Course3') + s1.courses.append(c1) + s1.courses.append(c2) + c3.students.append(s1) + sess.add(s1) + sess.flush() + sess.delete(s1) + sess.flush() + assert enroll.count().scalar() == 0 + +class M2MTest3(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('c', metadata, + Column('c1', Integer, primary_key = True), + Column('c2', String(20))) + + Table('a', metadata, + Column('a1', Integer, primary_key=True), + Column('a2', String(20)), + Column('c1', Integer, ForeignKey('c.c1'))) + + Table('c2a1', metadata, + Column('c1', Integer, ForeignKey('c.c1')), + Column('a1', Integer, ForeignKey('a.a1'))) + + Table('c2a2', metadata, + Column('c1', Integer, ForeignKey('c.c1')), + Column('a1', Integer, ForeignKey('a.a1'))) + + Table('b', metadata, + Column('b1', Integer, primary_key=True), + Column('a1', Integer, ForeignKey('a.a1')), + Column('b2', sa.Boolean)) + + @testing.resolve_artifact_names + def test_basic(self): + class C(object):pass + class A(object):pass + class B(object):pass + + mapper(B, b) + + mapper(A, a, properties={ + 'tbs': relation(B, primaryjoin=sa.and_(b.c.a1 == a.c.a1, + b.c.b2 == True), + lazy=False)}) + + mapper(C, c, properties={ + 'a1s': relation(A, secondary=c2a1, lazy=False), + 'a2s': relation(A, secondary=c2a2, lazy=False)}) + + assert create_session().query(C).with_labels().statement + + # TODO: seems like just a test for an ancient exception throw. 
+ # how about some data/inserts/queries/assertions for this one + + diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py new file mode 100644 index 000000000..025b96424 --- /dev/null +++ b/test/orm/test_mapper.py @@ -0,0 +1,2475 @@ +"""General mapper operations with an emphasis on selecting/loading.""" + +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing, pickleable +from sqlalchemy import MetaData, Integer, String, ForeignKey, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.engine import default +from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased +from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property +from sqlalchemy.test.testing import eq_, AssertsCompiledSQL +from test.orm import _base, _fixtures + + +class MapperTest(_fixtures.FixtureTest): + + @testing.resolve_artifact_names + def test_prop_shadow(self): + """A backref name may not shadow an existing property name.""" + + mapper(Address, addresses) + mapper(User, users, + properties={ + 'addresses':relation(Address, backref='email_address') + }) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_update_attr_keys(self): + """test that update()/insert() use the correct key when given InstrumentedAttributes.""" + + mapper(User, users, properties={ + 'foobar':users.c.name + }) + + users.insert().values({User.foobar:'name1'}).execute() + eq_(sa.select([User.foobar]).where(User.foobar=='name1').execute().fetchall(), [('name1',)]) + + users.update().values({User.foobar:User.foobar + 'foo'}).execute() + eq_(sa.select([User.foobar]).where(User.foobar=='name1foo').execute().fetchall(), [('name1foo',)]) + + @testing.resolve_artifact_names + def 
test_utils(self): + from sqlalchemy.orm.util import _is_mapped_class, _is_aliased_class + + class Foo(object): + x = "something" + @property + def y(self): + return "somethign else" + m = mapper(Foo, users) + a1 = aliased(Foo) + + f = Foo() + + for fn, arg, ret in [ + (_is_mapped_class, Foo.x, False), + (_is_mapped_class, Foo.y, False), + (_is_mapped_class, Foo, True), + (_is_mapped_class, f, False), + (_is_mapped_class, a1, True), + (_is_mapped_class, m, True), + (_is_aliased_class, a1, True), + (_is_aliased_class, Foo.x, False), + (_is_aliased_class, Foo.y, False), + (_is_aliased_class, Foo, False), + (_is_aliased_class, f, False), + (_is_aliased_class, a1, True), + (_is_aliased_class, m, False), + ]: + assert fn(arg) == ret + + + + @testing.resolve_artifact_names + def test_prop_accessor(self): + mapper(User, users) + assert_raises(NotImplementedError, + getattr, sa.orm.class_mapper(User), 'properties') + + + @testing.resolve_artifact_names + def test_bad_cascade(self): + mapper(Address, addresses) + assert_raises(sa.exc.ArgumentError, + relation, Address, cascade="fake, all, delete-orphan") + + @testing.resolve_artifact_names + def test_exceptions_sticky(self): + """test preservation of mapper compile errors raised during hasattr().""" + + mapper(Address, addresses, properties={ + 'user':relation(User) + }) + + hasattr(Address.user, 'property') + assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) + + @testing.resolve_artifact_names + def test_column_prefix(self): + mapper(User, users, column_prefix='_', properties={ + 'user_name': synonym('_name') + }) + + s = create_session() + u = s.query(User).get(7) + eq_(u._name, 'jack') + eq_(u._id,7) + u2 = s.query(User).filter_by(user_name='jack').one() + assert u is u2 + + @testing.resolve_artifact_names + def test_no_pks_1(self): + s = sa.select([users.c.name]).alias('foo') + assert_raises(sa.exc.ArgumentError, mapper, User, s) + + @testing.emits_warning( + 'mapper 
Mapper|User|Select object creating an alias for ' + 'the given selectable - use Class attributes for queries') + @testing.resolve_artifact_names + def test_no_pks_2(self): + s = sa.select([users.c.name]) + assert_raises(sa.exc.ArgumentError, mapper, User, s) + + @testing.resolve_artifact_names + def test_recompile_on_other_mapper(self): + """A compile trigger on an already-compiled mapper still triggers a check against all mappers.""" + mapper(User, users) + sa.orm.compile_mappers() + assert sa.orm.mapperlib._new_mappers is False + + m = mapper(Address, addresses, properties={ + 'user': relation(User, backref="addresses")}) + + assert m.compiled is False + assert sa.orm.mapperlib._new_mappers is True + u = User() + assert User.addresses + assert sa.orm.mapperlib._new_mappers is False + + @testing.resolve_artifact_names + def test_compile_on_session(self): + m = mapper(User, users) + session = create_session() + session.connection(m) + + @testing.resolve_artifact_names + def test_incomplete_columns(self): + """Loading from a select which does not contain all columns""" + mapper(Address, addresses) + s = create_session() + a = s.query(Address).from_statement( + sa.select([addresses.c.id, addresses.c.user_id])).first() + eq_(a.user_id, 7) + eq_(a.id, 1) + # email address auto-defers + assert 'email_addres' not in a.__dict__ + eq_(a.email_address, 'jack@bean.com') + + @testing.resolve_artifact_names + def test_bad_constructor(self): + """If the construction of a mapped class fails, the instance does not get placed in the session""" + class Foo(object): + def __init__(self, one, two, _sa_session=None): + pass + + mapper(Foo, users, extension=sa.orm.scoped_session( + create_session).extension) + + sess = create_session() + assert_raises(TypeError, Foo, 'one', _sa_session=sess) + eq_(len(list(sess)), 0) + assert_raises(TypeError, Foo, 'one') + Foo('one', 'two', _sa_session=sess) + eq_(len(list(sess)), 1) + + @testing.resolve_artifact_names + def 
test_constructor_exc_1(self): + """Exceptions raised in the mapped class are not masked by sa decorations""" + ex = AssertionError('oops') + sess = create_session() + + class Foo(object): + def __init__(self, **kw): + raise ex + mapper(Foo, users) + + try: + Foo() + assert False + except Exception, e: + assert e is ex + + sa.orm.clear_mappers() + mapper(Foo, users, extension=sa.orm.scoped_session( + create_session).extension) + def bad_expunge(foo): + raise Exception("this exception should be stated as a warning") + + sess.expunge = bad_expunge + assert_raises(sa.exc.SAWarning, Foo, _sa_session=sess) + + @testing.resolve_artifact_names + def test_constructor_exc_2(self): + """TypeError is raised for illegal constructor args, whether or not explicit __init__ is present [ticket:908].""" + + class Foo(object): + def __init__(self): + pass + class Bar(object): + pass + + mapper(Foo, users) + mapper(Bar, addresses) + assert_raises(TypeError, Foo, x=5) + assert_raises(TypeError, Bar, x=5) + + @testing.resolve_artifact_names + def test_props(self): + m = mapper(User, users, properties = { + 'addresses' : relation(mapper(Address, addresses)) + }).compile() + assert User.addresses.property is m.get_property('addresses') + + @testing.resolve_artifact_names + def test_compile_on_prop_1(self): + mapper(User, users, properties = { + 'addresses' : relation(mapper(Address, addresses)) + }) + User.addresses.any(Address.email_address=='foo@bar.com') + + @testing.resolve_artifact_names + def test_compile_on_prop_2(self): + mapper(User, users, properties = { + 'addresses' : relation(mapper(Address, addresses)) + }) + eq_(str(User.id == 3), str(users.c.id==3)) + + @testing.resolve_artifact_names + def test_compile_on_prop_3(self): + class Foo(User):pass + mapper(User, users) + mapper(Foo, addresses, inherits=User) + assert getattr(Foo().__class__, 'name').impl is not None + + @testing.resolve_artifact_names + def test_deferred_subclass_attribute_instrument(self): + class 
Foo(User):pass + mapper(User, users) + compile_mappers() + mapper(Foo, addresses, inherits=User) + assert getattr(Foo().__class__, 'name').impl is not None + + @testing.resolve_artifact_names + def test_compile_on_get_props_1(self): + m =mapper(User, users) + assert not m.compiled + assert list(m.iterate_properties) + assert m.compiled + + @testing.resolve_artifact_names + def test_compile_on_get_props_2(self): + m= mapper(User, users) + assert not m.compiled + assert m.get_property('name') + assert m.compiled + + @testing.resolve_artifact_names + def test_add_property(self): + assert_col = [] + + class User(_base.ComparableEntity): + def _get_name(self): + assert_col.append(('get', self._name)) + return self._name + def _set_name(self, name): + assert_col.append(('set', name)) + self._name = name + name = property(_get_name, _set_name) + + def _uc_name(self): + if self._name is None: + return None + return self._name.upper() + uc_name = property(_uc_name) + uc_name2 = property(_uc_name) + + m = mapper(User, users) + mapper(Address, addresses) + + class UCComparator(sa.orm.PropComparator): + __hash__ = None + + def __eq__(self, other): + cls = self.prop.parent.class_ + col = getattr(cls, 'name') + if other is None: + return col == None + else: + return sa.func.upper(col) == sa.func.upper(other) + + m.add_property('_name', deferred(users.c.name)) + m.add_property('name', synonym('_name')) + m.add_property('addresses', relation(Address)) + m.add_property('uc_name', sa.orm.comparable_property(UCComparator)) + m.add_property('uc_name2', sa.orm.comparable_property( + UCComparator, User.uc_name2)) + + sess = create_session(autocommit=False) + assert sess.query(User).get(7) + + u = sess.query(User).filter_by(name='jack').one() + + def go(): + eq_(len(u.addresses), + len(self.static.user_address_result[0].addresses)) + eq_(u.name, 'jack') + eq_(u.uc_name, 'JACK') + eq_(u.uc_name2, 'JACK') + eq_(assert_col, [('get', 'jack')], str(assert_col)) + self.sql_count_(2, go) + + 
u.name = 'ed' + u3 = User() + u3.name = 'some user' + sess.add(u3) + sess.flush() + sess.rollback() + + @testing.resolve_artifact_names + def test_replace_property(self): + m = mapper(User, users) + m.add_property('_name',users.c.name) + m.add_property('name', synonym('_name', proxy=True)) + + sess = create_session() + u = sess.query(User).filter_by(name='jack').one() + eq_(u._name, 'jack') + eq_(u.name, 'jack') + u.name = 'jacko' + assert m._columntoproperty[users.c.name] is m.get_property('_name') + + sa.orm.clear_mappers() + + m = mapper(User, users) + m.add_property('name', synonym('_name', map_column=True)) + + sess.expunge_all() + u = sess.query(User).filter_by(name='jack').one() + eq_(u._name, 'jack') + eq_(u.name, 'jack') + u.name = 'jacko' + assert m._columntoproperty[users.c.name] is m.get_property('_name') + + @testing.resolve_artifact_names + def test_synonym_replaces_backref(self): + assert_calls = [] + class Address(object): + def _get_user(self): + assert_calls.append("get") + return self._user + def _set_user(self, user): + assert_calls.append("set") + self._user = user + user = property(_get_user, _set_user) + + # synonym is created against nonexistent prop + mapper(Address, addresses, properties={ + 'user':synonym('_user') + }) + sa.orm.compile_mappers() + + # later, backref sets up the prop + mapper(User, users, properties={ + 'addresses':relation(Address, backref='_user') + }) + + sess = create_session() + u1 = sess.query(User).get(7) + u2 = sess.query(User).get(8) + # comparaison ops need to work + a1 = sess.query(Address).filter(Address.user==u1).one() + eq_(a1.id, 1) + a1.user = u2 + assert a1.user is u2 + eq_(assert_calls, ["set", "get"]) + + @testing.resolve_artifact_names + def test_self_ref_synonym(self): + t = Table('nodes', MetaData(), + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('nodes.id'))) + + class Node(object): + pass + + mapper(Node, t, properties={ + '_children':relation(Node, 
backref=backref('_parent', remote_side=t.c.id)), + 'children':synonym('_children'), + 'parent':synonym('_parent') + }) + + n1 = Node() + n2 = Node() + n1.children.append(n2) + assert n2.parent is n2._parent is n1 + assert n1.children[0] is n1._children[0] is n2 + eq_(str(Node.parent == n2), ":param_1 = nodes.parent_id") + + @testing.resolve_artifact_names + def test_illegal_non_primary(self): + mapper(User, users) + mapper(Address, addresses) + try: + mapper(User, users, non_primary=True, properties={ + 'addresses':relation(Address) + }).compile() + assert False + except sa.exc.ArgumentError, e: + assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e) + + @testing.resolve_artifact_names + def test_illegal_non_primary_2(self): + try: + mapper(User, users, non_primary=True) + assert False + except sa.exc.InvalidRequestError, e: + assert "Configure a primary mapper first" in str(e) + + @testing.resolve_artifact_names + def test_prop_filters(self): + t = Table('person', MetaData(), + Column('id', Integer, primary_key=True), + Column('type', String(128)), + Column('name', String(128)), + Column('employee_number', Integer), + Column('boss_id', Integer, ForeignKey('person.id')), + Column('vendor_id', Integer)) + + class Person(object): pass + class Vendor(Person): pass + class Employee(Person): pass + class Manager(Employee): pass + class Hoho(object): pass + class Lala(object): pass + + class HasDef(object): + def name(self): + pass + + p_m = mapper(Person, t, polymorphic_on=t.c.type, + include_properties=('id', 'type', 'name')) + e_m = mapper(Employee, inherits=p_m, polymorphic_identity='employee', + properties={ + 'boss': relation(Manager, backref=backref('peon', ), remote_side=t.c.id) + }, + exclude_properties=('vendor_id',)) + + m_m = mapper(Manager, inherits=e_m, polymorphic_identity='manager', + include_properties=('id', 'type')) + + v_m = mapper(Vendor, inherits=p_m, polymorphic_identity='vendor', + 
exclude_properties=('boss_id', 'employee_number')) + h_m = mapper(Hoho, t, include_properties=('id', 'type', 'name')) + l_m = mapper(Lala, t, exclude_properties=('vendor_id', 'boss_id'), + column_prefix="p_") + + hd_m = mapper(HasDef, t, column_prefix="h_") + + p_m.compile() + #sa.orm.compile_mappers() + + def assert_props(cls, want): + have = set([n for n in dir(cls) if not n.startswith('_')]) + want = set(want) + eq_(have, want) + + def assert_instrumented(cls, want): + have = set([p.key for p in class_mapper(cls).iterate_properties]) + want = set(want) + eq_(have, want) + + assert_props(HasDef, ['h_boss_id', 'h_employee_number', 'h_id', 'name', 'h_name', 'h_vendor_id', 'h_type']) + assert_props(Person, ['id', 'name', 'type']) + assert_instrumented(Person, ['id', 'name', 'type']) + assert_props(Employee, ['boss', 'boss_id', 'employee_number', + 'id', 'name', 'type']) + assert_instrumented(Employee,['boss', 'boss_id', 'employee_number', + 'id', 'name', 'type']) + assert_props(Manager, ['boss', 'boss_id', 'employee_number', 'peon', + 'id', 'name', 'type']) + + # 'peon' and 'type' are both explicitly stated properties + assert_instrumented(Manager, ['peon', 'type', 'id']) + + assert_props(Vendor, ['vendor_id', 'id', 'name', 'type']) + assert_props(Hoho, ['id', 'name', 'type']) + assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type']) + + # excluding the discriminator column is currently not allowed + class Foo(Person): + pass + assert_raises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) ) + + @testing.resolve_artifact_names + def test_mapping_to_join(self): + """Mapping to a join""" + usersaddresses = sa.join(users, addresses, + users.c.id == addresses.c.user_id) + mapper(User, usersaddresses, primary_key=[users.c.id]) + l = create_session().query(User).order_by(users.c.id).all() + eq_(l, self.static.user_result[:3]) + + @testing.resolve_artifact_names + def 
test_mapping_to_join_no_pk(self): + m = mapper(Address, addresses.join(email_bounces)) + m.compile() + assert addresses in m._pks_by_table + assert email_bounces not in m._pks_by_table + + sess = create_session() + a = Address(id=10, email_address='e1') + sess.add(a) + sess.flush() + + eq_(addresses.count().scalar(), 6) + eq_(email_bounces.count().scalar(), 5) + + @testing.resolve_artifact_names + def test_mapping_to_outerjoin(self): + """Mapping to an outer join with a nullable composite primary key.""" + + + mapper(User, users.outerjoin(addresses), + allow_null_pks=True, + primary_key=[users.c.id, addresses.c.id], + properties=dict( + address_id=addresses.c.id)) + + session = create_session() + l = session.query(User).order_by(User.id, User.address_id).all() + + eq_(l, [ + User(id=7, address_id=1), + User(id=8, address_id=2), + User(id=8, address_id=3), + User(id=8, address_id=4), + User(id=9, address_id=5), + User(id=10, address_id=None)]) + + @testing.resolve_artifact_names + def test_custom_join(self): + """select_from totally replace the FROM parameters.""" + + mapper(Item, items) + + mapper(Order, orders, properties=dict( + items=relation(Item, order_items))) + + mapper(User, users, properties=dict( + orders=relation(Order))) + + session = create_session() + l = (session.query(User). + select_from(users.join(orders). + join(order_items). + join(items)). 
+ filter(items.c.description == 'item 4')).all() + + eq_(l, [self.static.user_result[0]]) + + @testing.resolve_artifact_names + def test_cancel_order_by(self): + mapper(User, users, order_by=users.c.name.desc()) + + assert "order by users.name desc" in str(create_session().query(User).statement).lower() + assert "order by" not in str(create_session().query(User).order_by(None).statement).lower() + assert "order by users.name asc" in str(create_session().query(User).order_by(User.name.asc()).statement).lower() + + eq_( + create_session().query(User).all(), + [User(id=7, name=u'jack'), User(id=9, name=u'fred'), User(id=8, name=u'ed'), User(id=10, name=u'chuck')] + ) + + eq_( + create_session().query(User).order_by(User.name).all(), + [User(id=10, name=u'chuck'), User(id=8, name=u'ed'), User(id=9, name=u'fred'), User(id=7, name=u'jack')] + ) + + # 'Raises a "expression evaluation not supported" error at prepare time + @testing.fails_on('firebird', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_function(self): + """Mapping to a SELECT statement that has functions in it.""" + + s = sa.select([users, + (users.c.id * 2).label('concat'), + sa.func.count(addresses.c.id).label('count')], + users.c.id == addresses.c.user_id, + group_by=[c for c in users.c]).alias('myselect') + + mapper(User, s, order_by=s.c.id) + sess = create_session() + l = sess.query(User).all() + + for idx, total in enumerate((14, 16)): + eq_(l[idx].concat, l[idx].id * 2) + eq_(l[idx].concat, total) + + @testing.resolve_artifact_names + def test_count(self): + """The count function on Query.""" + + mapper(User, users) + + session = create_session() + q = session.query(User) + + eq_(q.count(), 4) + eq_(q.filter(User.id.in_([8,9])).count(), 2) + eq_(q.filter(users.c.id.in_([8,9])).count(), 2) + + eq_(session.query(User.id).count(), 4) + eq_(session.query(User.id).filter(User.id.in_((8, 9))).count(), 2) + + @testing.resolve_artifact_names + def test_many_to_many_count(self): + mapper(Keyword, 
keywords) + mapper(Item, items, properties=dict( + keywords = relation(Keyword, item_keywords, lazy=True))) + + session = create_session() + q = (session.query(Item). + join('keywords'). + distinct(). + filter(Keyword.name == "red")) + eq_(q.count(), 2) + + @testing.resolve_artifact_names + def test_override_1(self): + """Overriding a column raises an error.""" + def go(): + mapper(User, users, + properties=dict( + name=relation(mapper(Address, addresses)))) + + assert_raises(sa.exc.ArgumentError, go) + + @testing.resolve_artifact_names + def test_override_2(self): + """exclude_properties cancels the error.""" + + mapper(User, users, + exclude_properties=['name'], + properties=dict( + name=relation(mapper(Address, addresses)))) + + assert bool(User.name) + + @testing.resolve_artifact_names + def test_override_3(self): + """The column being named elsewhere also cancels the error,""" + mapper(User, users, + properties=dict( + name=relation(mapper(Address, addresses)), + foo=users.c.name)) + + @testing.resolve_artifact_names + def test_synonym(self): + + assert_col = [] + class extendedproperty(property): + attribute = 123 + def __getitem__(self, key): + return 'value' + + class User(object): + def _get_name(self): + assert_col.append(('get', self.name)) + return self.name + def _set_name(self, name): + assert_col.append(('set', name)) + self.name = name + uname = extendedproperty(_get_name, _set_name) + + mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=True), + uname = synonym('name'), + adlist = synonym('addresses', proxy=True), + adname = synonym('addresses') + )) + + # ensure the synonym can get at the proxied comparators without + # an explicit compile + User.name == 'ed' + User.adname.any() + + assert hasattr(User, 'adlist') + # as of 0.4.2, synonyms always create a property + assert hasattr(User, 'adname') + + # test compile + assert not isinstance(User.uname == 'jack', bool) + + assert User.uname.property + assert 
User.adlist.property + + sess = create_session() + + # test RowTuple names + row = sess.query(User.id, User.uname).first() + assert row.uname == row[1] + + u = sess.query(User).filter(User.uname=='jack').one() + + fixture = self.static.user_address_result[0].addresses + eq_(u.adlist, fixture) + + addr = sess.query(Address).filter_by(id=fixture[0].id).one() + u = sess.query(User).filter(User.adname.contains(addr)).one() + u2 = sess.query(User).filter(User.adlist.contains(addr)).one() + + assert u is u2 + + assert u not in sess.dirty + u.uname = "some user name" + assert len(assert_col) > 0 + eq_(assert_col, [('set', 'some user name')]) + eq_(u.uname, "some user name") + eq_(assert_col, [('set', 'some user name'), ('get', 'some user name')]) + eq_(u.name, "some user name") + assert u in sess.dirty + + eq_(User.uname.attribute, 123) + eq_(User.uname['key'], 'value') + + @testing.resolve_artifact_names + def test_synonym_column_location(self): + def go(): + mapper(User, users, properties={ + 'not_name':synonym('_name', map_column=True)}) + + assert_raises_message( + sa.exc.ArgumentError, + ("Can't compile synonym '_name': no column on table " + "'users' named 'not_name'"), + go) + + @testing.resolve_artifact_names + def test_column_synonyms(self): + """Synonyms which automatically instrument properties, set up aliased column, etc.""" + + + assert_col = [] + class User(object): + def _get_name(self): + assert_col.append(('get', self._name)) + return self._name + def _set_name(self, name): + assert_col.append(('set', name)) + self._name = name + name = property(_get_name, _set_name) + + mapper(Address, addresses) + mapper(User, users, properties = { + 'addresses':relation(Address, lazy=True), + 'name':synonym('_name', map_column=True) + }) + + # test compile + assert not isinstance(User.name == 'jack', bool) + + assert hasattr(User, 'name') + assert hasattr(User, '_name') + + sess = create_session() + u = sess.query(User).filter(User.name == 'jack').one() + eq_(u.name, 
'jack') + u.name = 'foo' + eq_(u.name, 'foo') + eq_(assert_col, [('get', 'jack'), ('set', 'foo'), ('get', 'foo')]) + + @testing.resolve_artifact_names + def test_comparable(self): + class extendedproperty(property): + attribute = 123 + + def method1(self): + return "method1" + + def __getitem__(self, key): + return 'value' + + class UCComparator(sa.orm.PropComparator): + __hash__ = None + + def method1(self): + return "uccmethod1" + + def method2(self, other): + return "method2" + + def __eq__(self, other): + cls = self.prop.parent.class_ + col = getattr(cls, 'name') + if other is None: + return col == None + else: + return sa.func.upper(col) == sa.func.upper(other) + + def map_(with_explicit_property): + class User(object): + @extendedproperty + def uc_name(self): + if self.name is None: + return None + return self.name.upper() + if with_explicit_property: + args = (UCComparator, User.uc_name) + else: + args = (UCComparator,) + + mapper(User, users, properties=dict( + uc_name = sa.orm.comparable_property(*args))) + return User + + for User in (map_(True), map_(False)): + sess = create_session() + sess.begin() + q = sess.query(User) + + assert hasattr(User, 'name') + assert hasattr(User, 'uc_name') + + eq_(User.uc_name.method1(), "method1") + eq_(User.uc_name.method2('x'), "method2") + + assert_raises_message( + AttributeError, + "Neither 'extendedproperty' object nor 'UCComparator' object has an attribute 'nonexistent'", + getattr, User.uc_name, 'nonexistent') + + # test compile + assert not isinstance(User.uc_name == 'jack', bool) + u = q.filter(User.uc_name=='JACK').one() + + assert u.uc_name == "JACK" + assert u not in sess.dirty + + u.name = "some user name" + eq_(u.name, "some user name") + assert u in sess.dirty + eq_(u.uc_name, "SOME USER NAME") + + sess.flush() + sess.expunge_all() + + q = sess.query(User) + u2 = q.filter(User.name=='some user name').one() + u3 = q.filter(User.uc_name=='SOME USER NAME').one() + + assert u2 is u3 + + 
eq_(User.uc_name.attribute, 123) + eq_(User.uc_name['key'], 'value') + sess.rollback() + + @testing.resolve_artifact_names + def test_comparable_column(self): + class MyComparator(sa.orm.properties.ColumnProperty.Comparator): + def __eq__(self, other): + # lower case comparison + return func.lower(self.__clause_element__()) == func.lower(other) + + def intersects(self, other): + # non-standard comparator + return self.__clause_element__().op('&=')(other) + + mapper(User, users, properties={ + 'name':sa.orm.column_property(users.c.name, comparator_factory=MyComparator) + }) + + assert_raises_message( + AttributeError, + "Neither 'InstrumentedAttribute' object nor 'MyComparator' object has an attribute 'nonexistent'", + getattr, User.name, "nonexistent") + + eq_(str((User.name == 'ed').compile(dialect=sa.engine.default.DefaultDialect())) , "lower(users.name) = lower(:lower_1)") + eq_(str((User.name.intersects('ed')).compile(dialect=sa.engine.default.DefaultDialect())), "users.name &= :name_1") + + + @testing.resolve_artifact_names + def test_reconstructor(self): + recon = [] + + class User(object): + @reconstructor + def reconstruct(self): + recon.append('go') + + mapper(User, users) + + User() + eq_(recon, []) + create_session().query(User).first() + eq_(recon, ['go']) + + @testing.resolve_artifact_names + def test_reconstructor_inheritance(self): + recon = [] + class A(object): + @reconstructor + def reconstruct(self): + recon.append('A') + + class B(A): + @reconstructor + def reconstruct(self): + recon.append('B') + + class C(A): + @reconstructor + def reconstruct(self): + recon.append('C') + + mapper(A, users, polymorphic_on=users.c.name, + polymorphic_identity='jack') + mapper(B, inherits=A, polymorphic_identity='ed') + mapper(C, inherits=A, polymorphic_identity='chuck') + + A() + B() + C() + eq_(recon, []) + + sess = create_session() + sess.query(A).first() + sess.query(B).first() + sess.query(C).first() + eq_(recon, ['A', 'B', 'C']) + + 
@testing.resolve_artifact_names + def test_unmapped_reconstructor_inheritance(self): + recon = [] + class Base(object): + @reconstructor + def reconstruct(self): + recon.append('go') + + class User(Base): + pass + + mapper(User, users) + + User() + eq_(recon, []) + + create_session().query(User).first() + eq_(recon, ['go']) + + @testing.resolve_artifact_names + def test_unmapped_error(self): + mapper(Address, addresses) + sa.orm.clear_mappers() + + mapper(User, users, properties={ + 'addresses':relation(Address) + }) + + assert_raises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_oldstyle_mixin(self): + class OldStyle: + pass + class NewStyle(object): + pass + + class A(NewStyle, OldStyle): + pass + + mapper(A, users) + + class B(OldStyle, NewStyle): + pass + + mapper(B, users) + + +class OptionsTest(_fixtures.FixtureTest): + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_synonym_options(self): + mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=True, + order_by=addresses.c.id), + adlist = synonym('addresses', proxy=True))) + + + def go(): + sess = create_session() + u = (sess.query(User). + order_by(User.id). + options(sa.orm.eagerload('adlist')). + filter_by(name='jack')).one() + eq_(u.adlist, + [self.static.user_address_result[0].addresses[0]]) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_eager_options(self): + """A lazy relation can be upgraded to an eager relation.""" + mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), + order_by=addresses.c.id))) + + sess = create_session() + l = (sess.query(User). + order_by(User.id). 
+ options(sa.orm.eagerload('addresses'))).all() + + def go(): + eq_(l, self.static.user_address_result) + self.sql_count_(0, go) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_eager_options_with_limit(self): + mapper(User, users, properties=dict( + addresses=relation(mapper(Address, addresses), lazy=True))) + + sess = create_session() + u = (sess.query(User). + options(sa.orm.eagerload('addresses')). + filter_by(id=8)).one() + + def go(): + eq_(u.id, 8) + eq_(len(u.addresses), 3) + self.sql_count_(0, go) + + sess.expunge_all() + + u = sess.query(User).filter_by(id=8).one() + eq_(u.id, 8) + eq_(len(u.addresses), 3) + + @testing.fails_on('maxdb', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_lazy_options_with_limit(self): + mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=False))) + + sess = create_session() + u = (sess.query(User). + options(sa.orm.lazyload('addresses')). + filter_by(id=8)).one() + + def go(): + eq_(u.id, 8) + eq_(len(u.addresses), 3) + self.sql_count_(1, go) + + @testing.resolve_artifact_names + def test_eager_degrade(self): + """An eager relation automatically degrades to a lazy relation if eager columns are not available""" + mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=False))) + + sess = create_session() + # first test straight eager load, 1 statement + def go(): + l = sess.query(User).order_by(User.id).all() + eq_(l, self.static.user_address_result) + self.sql_count_(1, go) + + sess.expunge_all() + + # then select just from users. run it into instances. 
+ # then assert the data, which will launch 3 more lazy loads + # (previous users in session fell out of scope and were removed from + # session's identity map) + r = users.select().order_by(users.c.id).execute() + def go(): + l = list(sess.query(User).instances(r)) + eq_(l, self.static.user_address_result) + self.sql_count_(4, go) + + + @testing.resolve_artifact_names + def test_eager_degrade_deep(self): + # test with a deeper set of eager loads. when we first load the three + # users, they will have no addresses or orders. the number of lazy + # loads when traversing the whole thing will be three for the + # addresses and three for the orders. + mapper(Address, addresses) + + mapper(Keyword, keywords) + + mapper(Item, items, properties=dict( + keywords=relation(Keyword, secondary=item_keywords, + lazy=False, + order_by=item_keywords.c.keyword_id))) + + mapper(Order, orders, properties=dict( + items=relation(Item, secondary=order_items, lazy=False, + order_by=order_items.c.item_id))) + + mapper(User, users, properties=dict( + addresses=relation(Address, lazy=False, + order_by=addresses.c.id), + orders=relation(Order, lazy=False, + order_by=orders.c.id))) + + sess = create_session() + + # first test straight eager load, 1 statement + def go(): + l = sess.query(User).order_by(User.id).all() + eq_(l, self.static.user_all_result) + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + + # then select just from users. run it into instances. + # then assert the data, which will launch 6 more lazy loads + r = users.select().execute() + def go(): + l = list(sess.query(User).instances(r)) + eq_(l, self.static.user_all_result) + self.assert_sql_count(testing.db, go, 6) + + @testing.resolve_artifact_names + def test_lazy_options(self): + """An eager relation can be upgraded to a lazy relation.""" + mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=False) + )) + + sess = create_session() + l = (sess.query(User). 
+ order_by(User.id). + options(sa.orm.lazyload('addresses'))).all() + + def go(): + eq_(l, self.static.user_address_result) + self.sql_count_(4, go) + + +class DeepOptionsTest(_fixtures.FixtureTest): + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Keyword, keywords) + + mapper(Item, items, properties=dict( + keywords=relation(Keyword, item_keywords, + order_by=item_keywords.c.item_id))) + + mapper(Order, orders, properties=dict( + items=relation(Item, order_items, + order_by=items.c.id))) + + mapper(User, users, order_by=users.c.id, properties=dict( + orders=relation(Order, order_by=orders.c.id))) + + @testing.resolve_artifact_names + def test_deep_options_1(self): + sess = create_session() + + # eagerload nothing. + u = sess.query(User).all() + def go(): + x = u[0].orders[1].items[0].keywords[1] + self.assert_sql_count(testing.db, go, 3) + + @testing.resolve_artifact_names + def test_deep_options_2(self): + sess = create_session() + + # eagerload orders.items.keywords; eagerload_all() implies eager load + # of orders, orders.items + l = (sess.query(User). + options(sa.orm.eagerload_all('orders.items.keywords'))).all() + def go(): + x = l[0].orders[1].items[0].keywords[1] + self.sql_count_(0, go) + + + @testing.resolve_artifact_names + def test_deep_options_3(self): + sess = create_session() + + # same thing, with separate options calls + q2 = (sess.query(User). + options(sa.orm.eagerload('orders')). + options(sa.orm.eagerload('orders.items')). + options(sa.orm.eagerload('orders.items.keywords'))) + u = q2.all() + def go(): + x = u[0].orders[1].items[0].keywords[1] + self.sql_count_(0, go) + + @testing.resolve_artifact_names + def test_deep_options_4(self): + sess = create_session() + + assert_raises_message( + sa.exc.ArgumentError, + r"Can't find entity Mapper\|Order\|orders in Query. " + r"Current list: \['Mapper\|User\|users'\]", + sess.query(User).options, sa.orm.eagerload(Order.items)) + + # eagerload "keywords" on items. 
it will lazy load "orders", then + # lazy load the "items" on the order, but on "items" it will eager + # load the "keywords" + q3 = sess.query(User).options(sa.orm.eagerload('orders.items.keywords')) + u = q3.all() + def go(): + x = u[0].orders[1].items[0].keywords[1] + self.sql_count_(2, go) + +class ValidatorTest(_fixtures.FixtureTest): + @testing.resolve_artifact_names + def test_scalar(self): + class User(_base.ComparableEntity): + @validates('name') + def validate_name(self, key, name): + assert name != 'fred' + return name + ' modified' + + mapper(User, users) + sess = create_session() + u1 = User(name='ed') + eq_(u1.name, 'ed modified') + assert_raises(AssertionError, setattr, u1, "name", "fred") + eq_(u1.name, 'ed modified') + sess.add(u1) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).filter_by(name='ed modified').one(), User(name='ed')) + + + @testing.resolve_artifact_names + def test_collection(self): + class User(_base.ComparableEntity): + @validates('addresses') + def validate_address(self, key, ad): + assert '@' in ad.email_address + return ad + + mapper(User, users, properties={'addresses':relation(Address)}) + mapper(Address, addresses) + sess = create_session() + u1 = User(name='edward') + assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail')) + u1.addresses.append(Address(id=15, email_address='foo@bar.com')) + sess.add(u1) + sess.flush() + sess.expunge_all() + eq_( + sess.query(User).filter_by(name='edward').one(), + User(name='edward', addresses=[Address(email_address='foo@bar.com')]) + ) + +class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): + @testing.resolve_artifact_names + def test_kwarg_accepted(self): + class DummyComposite(object): + def __init__(self, x, y): + pass + + from sqlalchemy.orm.interfaces import PropComparator + + class MyFactory(PropComparator): + pass + + for args in ( + (column_property, users.c.name), + (deferred, users.c.name), + (synonym, 'name'), + 
(composite, DummyComposite, users.c.id, users.c.name), + (relation, Address), + (backref, 'address'), + (comparable_property, ), + (dynamic_loader, Address) + ): + fn = args[0] + args = args[1:] + fn(comparator_factory=MyFactory, *args) + + @testing.resolve_artifact_names + def test_column(self): + from sqlalchemy.orm.properties import ColumnProperty + + class MyFactory(ColumnProperty.Comparator): + __hash__ = None + def __eq__(self, other): + return func.foobar(self.__clause_element__()) == func.foobar(other) + mapper(User, users, properties={'name':column_property(users.c.name, comparator_factory=MyFactory)}) + self.assert_compile(User.name == 'ed', "foobar(users.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + self.assert_compile(aliased(User).name == 'ed', "foobar(users_1.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + + @testing.resolve_artifact_names + def test_synonym(self): + from sqlalchemy.orm.properties import ColumnProperty + class MyFactory(ColumnProperty.Comparator): + __hash__ = None + def __eq__(self, other): + return func.foobar(self.__clause_element__()) == func.foobar(other) + mapper(User, users, properties={'name':synonym('_name', map_column=True, comparator_factory=MyFactory)}) + self.assert_compile(User.name == 'ed', "foobar(users.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + self.assert_compile(aliased(User).name == 'ed', "foobar(users_1.name) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + + @testing.resolve_artifact_names + def test_relation(self): + from sqlalchemy.orm.properties import PropertyLoader + + class MyFactory(PropertyLoader.Comparator): + __hash__ = None + def __eq__(self, other): + return func.foobar(self.__clause_element__().c.user_id) == func.foobar(other.id) + + class MyFactory2(PropertyLoader.Comparator): + __hash__ = None + def __eq__(self, other): + return func.foobar(self.__clause_element__().c.id) == func.foobar(other.user_id) + + mapper(User, users) + 
mapper(Address, addresses, properties={ + 'user':relation(User, comparator_factory=MyFactory, + backref=backref("addresses", comparator_factory=MyFactory2) + ) + } + ) + self.assert_compile(Address.user == User(id=5), "foobar(addresses.user_id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + self.assert_compile(User.addresses == Address(id=5, user_id=7), "foobar(users.id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + + self.assert_compile(aliased(Address).user == User(id=5), "foobar(addresses_1.user_id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + self.assert_compile(aliased(User).addresses == Address(id=5, user_id=7), "foobar(users_1.id) = foobar(:foobar_1)", dialect=default.DefaultDialect()) + + +class DeferredTest(_fixtures.FixtureTest): + + @testing.resolve_artifact_names + def test_basic(self): + """A basic deferred load.""" + + mapper(Order, orders, order_by=orders.c.id, properties={ + 'description': deferred(orders.c.description)}) + + o = Order() + self.assert_(o.description is None) + + q = create_session().query(Order) + def go(): + l = q.all() + o2 = l[2] + x = o2.description + + self.sql_eq_(go, [ + ("SELECT orders.id AS orders_id, " + "orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.isopen AS orders_isopen " + "FROM orders ORDER BY orders.id", {}), + ("SELECT orders.description AS orders_description " + "FROM orders WHERE orders.id = :param_1", + {'param_1':3})]) + + @testing.resolve_artifact_names + def test_unsaved(self): + """Deferred loading does not kick in when just PK cols are set.""" + + mapper(Order, orders, properties={ + 'description': deferred(orders.c.description)}) + + sess = create_session() + o = Order() + sess.add(o) + o.id = 7 + def go(): + o.description = "some description" + self.sql_count_(0, go) + + @testing.resolve_artifact_names + def test_synonym_group_bug(self): + mapper(Order, orders, properties={ + 'isopen':synonym('_isopen', map_column=True), + 
'description':deferred(orders.c.description, group='foo') + }) + + sess = create_session() + o1 = sess.query(Order).get(1) + eq_(o1.description, "order 1") + + @testing.resolve_artifact_names + def test_unsaved_2(self): + mapper(Order, orders, properties={ + 'description': deferred(orders.c.description)}) + + sess = create_session() + o = Order() + sess.add(o) + def go(): + o.description = "some description" + self.sql_count_(0, go) + + @testing.resolve_artifact_names + def test_unsaved_group(self): + """Deferred loading doesnt kick in when just PK cols are set""" + + mapper(Order, orders, order_by=orders.c.id, properties=dict( + description=deferred(orders.c.description, group='primary'), + opened=deferred(orders.c.isopen, group='primary'))) + + sess = create_session() + o = Order() + sess.add(o) + o.id = 7 + def go(): + o.description = "some description" + self.sql_count_(0, go) + + @testing.resolve_artifact_names + def test_unsaved_group_2(self): + mapper(Order, orders, order_by=orders.c.id, properties=dict( + description=deferred(orders.c.description, group='primary'), + opened=deferred(orders.c.isopen, group='primary'))) + + sess = create_session() + o = Order() + sess.add(o) + def go(): + o.description = "some description" + self.sql_count_(0, go) + + @testing.resolve_artifact_names + def test_save(self): + m = mapper(Order, orders, properties={ + 'description': deferred(orders.c.description)}) + + sess = create_session() + o2 = sess.query(Order).get(2) + o2.isopen = 1 + sess.flush() + + @testing.resolve_artifact_names + def test_group(self): + """Deferred load with a group""" + mapper(Order, orders, properties={ + 'userident': deferred(orders.c.user_id, group='primary'), + 'addrident': deferred(orders.c.address_id, group='primary'), + 'description': deferred(orders.c.description, group='primary'), + 'opened': deferred(orders.c.isopen, group='primary') + }) + + sess = create_session() + q = sess.query(Order).order_by(Order.id) + def go(): + l = q.all() + o2 = 
l[2] + eq_(o2.opened, 1) + eq_(o2.userident, 7) + eq_(o2.description, 'order 3') + + self.sql_eq_(go, [ + ("SELECT orders.id AS orders_id " + "FROM orders ORDER BY orders.id", {}), + ("SELECT orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen " + "FROM orders WHERE orders.id = :param_1", + {'param_1':3})]) + + o2 = q.all()[2] + eq_(o2.description, 'order 3') + assert o2 not in sess.dirty + o2.description = 'order 3' + def go(): + sess.flush() + self.sql_count_(0, go) + + @testing.resolve_artifact_names + def test_preserve_changes(self): + """A deferred load operation doesn't revert modifications on attributes""" + mapper(Order, orders, properties = { + 'userident': deferred(orders.c.user_id, group='primary'), + 'description': deferred(orders.c.description, group='primary'), + 'opened': deferred(orders.c.isopen, group='primary') + }) + sess = create_session() + o = sess.query(Order).get(3) + assert 'userident' not in o.__dict__ + o.description = 'somenewdescription' + eq_(o.description, 'somenewdescription') + def go(): + eq_(o.opened, 1) + self.assert_sql_count(testing.db, go, 1) + eq_(o.description, 'somenewdescription') + assert o in sess.dirty + + @testing.resolve_artifact_names + def test_commits_state(self): + """ + When deferred elements are loaded via a group, they get the proper + CommittedState and don't result in changes being committed + + """ + mapper(Order, orders, properties = { + 'userident':deferred(orders.c.user_id, group='primary'), + 'description':deferred(orders.c.description, group='primary'), + 'opened':deferred(orders.c.isopen, group='primary')}) + + sess = create_session() + o2 = sess.query(Order).get(3) + + # this will load the group of attributes + eq_(o2.description, 'order 3') + assert o2 not in sess.dirty + # this will mark it as 'dirty', but nothing actually changed + o2.description = 'order 3' + # therefore the flush() 
shouldnt actually issue any SQL + self.assert_sql_count(testing.db, sess.flush, 0) + + @testing.resolve_artifact_names + def test_options(self): + """Options on a mapper to create deferred and undeferred columns""" + + mapper(Order, orders) + + sess = create_session() + q = sess.query(Order).order_by(Order.id).options(defer('user_id')) + + def go(): + q.all()[0].user_id + + self.sql_eq_(go, [ + ("SELECT orders.id AS orders_id, " + "orders.address_id AS orders_address_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen " + "FROM orders ORDER BY orders.id", {}), + ("SELECT orders.user_id AS orders_user_id " + "FROM orders WHERE orders.id = :param_1", + {'param_1':1})]) + sess.expunge_all() + + q2 = q.options(sa.orm.undefer('user_id')) + self.sql_eq_(q2.all, [ + ("SELECT orders.id AS orders_id, " + "orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen " + "FROM orders ORDER BY orders.id", + {})]) + + @testing.resolve_artifact_names + def test_undefer_group(self): + mapper(Order, orders, properties={ + 'userident':deferred(orders.c.user_id, group='primary'), + 'description':deferred(orders.c.description, group='primary'), + 'opened':deferred(orders.c.isopen, group='primary')}) + + sess = create_session() + q = sess.query(Order).order_by(Order.id) + def go(): + l = q.options(sa.orm.undefer_group('primary')).all() + o2 = l[2] + eq_(o2.opened, 1) + eq_(o2.userident, 7) + eq_(o2.description, 'order 3') + + self.sql_eq_(go, [ + ("SELECT orders.user_id AS orders_user_id, " + "orders.description AS orders_description, " + "orders.isopen AS orders_isopen, " + "orders.id AS orders_id, " + "orders.address_id AS orders_address_id " + "FROM orders ORDER BY orders.id", + {})]) + + @testing.resolve_artifact_names + def test_locates_col(self): + """Manually adding a column to the result undefers the column.""" + + mapper(Order, orders, 
properties={ + 'description':deferred(orders.c.description)}) + + sess = create_session() + o1 = sess.query(Order).order_by(Order.id).first() + def go(): + eq_(o1.description, 'order 1') + self.sql_count_(1, go) + + sess = create_session() + o1 = (sess.query(Order). + order_by(Order.id). + add_column(orders.c.description).first())[0] + def go(): + eq_(o1.description, 'order 1') + self.sql_count_(0, go) + + @testing.resolve_artifact_names + def test_deep_options(self): + mapper(Item, items, properties=dict( + description=deferred(items.c.description))) + mapper(Order, orders, properties=dict( + items=relation(Item, secondary=order_items))) + mapper(User, users, properties=dict( + orders=relation(Order, order_by=orders.c.id))) + + sess = create_session() + q = sess.query(User).order_by(User.id) + l = q.all() + item = l[0].orders[1].items[1] + def go(): + eq_(item.description, 'item 4') + self.sql_count_(1, go) + eq_(item.description, 'item 4') + + sess.expunge_all() + l = q.options(sa.orm.undefer('orders.items.description')).all() + item = l[0].orders[1].items[1] + def go(): + eq_(item.description, 'item 4') + self.sql_count_(0, go) + eq_(item.description, 'item 4') + +class DeferredPopulationTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table("thing", metadata, + Column("id", Integer, primary_key=True), + Column("name", String(20))) + + Table("human", metadata, + Column("id", Integer, primary_key=True), + Column("thing_id", Integer, ForeignKey("thing.id")), + Column("name", String(20))) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class Human(_base.BasicEntity): pass + class Thing(_base.BasicEntity): pass + + mapper(Human, human, properties={"thing": relation(Thing)}) + mapper(Thing, thing, properties={"name": deferred(thing.c.name)}) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + thing.insert().execute([ + {"id": 1, "name": "Chair"}, + ]) + + human.insert().execute([ + 
{"id": 1, "thing_id": 1, "name": "Clark Kent"}, + ]) + + def _test(self, thing): + assert "name" in attributes.instance_state(thing).dict + + @testing.resolve_artifact_names + def test_no_previous_query(self): + session = create_session() + thing = session.query(Thing).options(sa.orm.undefer("name")).first() + self._test(thing) + + @testing.resolve_artifact_names + def test_query_twice_with_clear(self): + session = create_session() + result = session.query(Thing).first() + session.expunge_all() + thing = session.query(Thing).options(sa.orm.undefer("name")).first() + self._test(thing) + + @testing.resolve_artifact_names + def test_query_twice_no_clear(self): + session = create_session() + result = session.query(Thing).first() + thing = session.query(Thing).options(sa.orm.undefer("name")).first() + self._test(thing) + + @testing.resolve_artifact_names + def test_eagerload_with_clear(self): + session = create_session() + human = session.query(Human).options(sa.orm.eagerload("thing")).first() + session.expunge_all() + thing = session.query(Thing).options(sa.orm.undefer("name")).first() + self._test(thing) + + @testing.resolve_artifact_names + def test_eagerload_no_clear(self): + session = create_session() + human = session.query(Human).options(sa.orm.eagerload("thing")).first() + thing = session.query(Thing).options(sa.orm.undefer("name")).first() + self._test(thing) + + @testing.resolve_artifact_names + def test_join_with_clear(self): + session = create_session() + result = session.query(Human).add_entity(Thing).join("thing").first() + session.expunge_all() + thing = session.query(Thing).options(sa.orm.undefer("name")).first() + self._test(thing) + + @testing.resolve_artifact_names + def test_join_no_clear(self): + session = create_session() + result = session.query(Human).add_entity(Thing).join("thing").first() + thing = session.query(Thing).options(sa.orm.undefer("name")).first() + self._test(thing) + + +class CompositeTypesTest(_base.MappedTest): + + @classmethod + 
def define_tables(cls, metadata): + Table('graphs', metadata, + Column('id', Integer, primary_key=True), + Column('version_id', Integer, primary_key=True, nullable=True), + Column('name', String(30))) + + Table('edges', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('graph_id', Integer, nullable=False), + Column('graph_version_id', Integer, nullable=False), + Column('x1', Integer), + Column('y1', Integer), + Column('x2', Integer), + Column('y2', Integer), + sa.ForeignKeyConstraint( + ['graph_id', 'graph_version_id'], + ['graphs.id', 'graphs.version_id'])) + + Table('foobars', metadata, + Column('id', Integer, primary_key=True), + Column('x1', Integer, default=2), + Column('x2', Integer), + Column('x3', Integer, default=15), + Column('x4', Integer) + ) + + @testing.resolve_artifact_names + def test_basic(self): + class Point(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return [self.x, self.y] + __hash__ = None + def __eq__(self, other): + return isinstance(other, Point) and other.x == self.x and other.y == self.y + def __ne__(self, other): + return not isinstance(other, Point) or not self.__eq__(other) + + class Graph(object): + pass + class Edge(object): + def __init__(self, start, end): + self.start = start + self.end = end + + mapper(Graph, graphs, properties={ + 'edges':relation(Edge) + }) + mapper(Edge, edges, properties={ + 'start':sa.orm.composite(Point, edges.c.x1, edges.c.y1), + 'end': sa.orm.composite(Point, edges.c.x2, edges.c.y2) + }) + + sess = create_session() + g = Graph() + g.id = 1 + g.version_id=1 + g.edges.append(Edge(Point(3, 4), Point(5, 6))) + g.edges.append(Edge(Point(14, 5), Point(2, 7))) + sess.add(g) + sess.flush() + + sess.expunge_all() + g2 = sess.query(Graph).get([g.id, g.version_id]) + for e1, e2 in zip(g.edges, g2.edges): + eq_(e1.start, e2.start) + eq_(e1.end, e2.end) + + g2.edges[1].end = Point(18, 4) + sess.flush() + 
sess.expunge_all() + e = sess.query(Edge).get(g2.edges[1].id) + eq_(e.end, Point(18, 4)) + + e.end.x = 19 + e.end.y = 5 + sess.flush() + sess.expunge_all() + eq_(sess.query(Edge).get(g2.edges[1].id).end, Point(19, 5)) + + g.edges[1].end = Point(19, 5) + + sess.expunge_all() + def go(): + g2 = (sess.query(Graph). + options(sa.orm.eagerload('edges'))).get([g.id, g.version_id]) + for e1, e2 in zip(g.edges, g2.edges): + eq_(e1.start, e2.start) + eq_(e1.end, e2.end) + self.assert_sql_count(testing.db, go, 1) + + # test comparison of CompositeProperties to their object instances + g = sess.query(Graph).get([1, 1]) + assert sess.query(Edge).filter(Edge.start==Point(3, 4)).one() is g.edges[0] + + assert sess.query(Edge).filter(Edge.start!=Point(3, 4)).first() is g.edges[1] + + eq_(sess.query(Edge).filter(Edge.start==None).all(), []) + + # query by columns + eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)]) + + e = g.edges[1] + e.end.x = e.end.y = None + sess.flush() + eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)]) + + + @testing.resolve_artifact_names + def test_pk(self): + """Using a composite type as a primary key""" + + class Version(object): + def __init__(self, id, version): + self.id = id + self.version = version + def __composite_values__(self): + return (self.id, self.version) + __hash__ = None + def __eq__(self, other): + return other.id == self.id and other.version == self.version + def __ne__(self, other): + return not self.__eq__(other) + + class Graph(object): + def __init__(self, version): + self.version = version + + mapper(Graph, graphs, allow_null_pks=True, properties={ + 'version':sa.orm.composite(Version, graphs.c.id, + graphs.c.version_id)}) + + sess = create_session() + g = Graph(Version(1, 1)) + sess.add(g) + sess.flush() + + sess.expunge_all() + g2 = sess.query(Graph).get([1, 1]) + eq_(g.version, g2.version) + sess.expunge_all() + + g2 = sess.query(Graph).get(Version(1, 1)) + eq_(g.version, 
g2.version) + + # test pk mutation + @testing.fails_on('mssql', 'Cannot update identity columns.') + def update_pk(): + g2.version = Version(2, 1) + sess.flush() + g3 = sess.query(Graph).get(Version(2, 1)) + eq_(g2.version, g3.version) + update_pk() + + # test pk with one column NULL + # TODO: can't seem to get NULL in for a PK value + # in either mysql or postgres, autoincrement=False etc. + # notwithstanding + @testing.fails_on_everything_except("sqlite") + def go(): + g = Graph(Version(2, None)) + sess.add(g) + sess.flush() + sess.expunge_all() + g2 = sess.query(Graph).filter_by(version=Version(2, None)).one() + eq_(g.version, g2.version) + go() + + @testing.resolve_artifact_names + def test_attributes_with_defaults(self): + class Foobar(object): + pass + + class FBComposite(object): + def __init__(self, x1, x2, x3, x4): + self.x1 = x1 + self.x2 = x2 + self.x3 = x3 + self.x4 = x4 + def __composite_values__(self): + return self.x1, self.x2, self.x3, self.x4 + __hash__ = None + def __eq__(self, other): + return other.x1 == self.x1 and other.x2 == self.x2 and other.x3 == self.x3 and other.x4 == self.x4 + def __ne__(self, other): + return not self.__eq__(other) + + mapper(Foobar, foobars, properties=dict( + foob=sa.orm.composite(FBComposite, foobars.c.x1, foobars.c.x2, foobars.c.x3, foobars.c.x4) + )) + + sess = create_session() + f1 = Foobar() + f1.foob = FBComposite(None, 5, None, None) + sess.add(f1) + sess.flush() + + assert f1.foob == FBComposite(2, 5, 15, None) + + + f2 = Foobar() + sess.add(f2) + sess.flush() + assert f2.foob == FBComposite(2, None, 15, None) + + + @testing.resolve_artifact_names + def test_set_composite_values(self): + class Foobar(object): + pass + + class FBComposite(object): + def __init__(self, x1, x2, x3, x4): + self.x1val = x1 + self.x2val = x2 + self.x3 = x3 + self.x4 = x4 + def __composite_values__(self): + return self.x1val, self.x2val, self.x3, self.x4 + def __set_composite_values__(self, x1, x2, x3, x4): + self.x1val = x1 + 
self.x2val = x2 + self.x3 = x3 + self.x4 = x4 + __hash__ = None + def __eq__(self, other): + return other.x1val == self.x1val and other.x2val == self.x2val and other.x3 == self.x3 and other.x4 == self.x4 + def __ne__(self, other): + return not self.__eq__(other) + + mapper(Foobar, foobars, properties=dict( + foob=sa.orm.composite(FBComposite, foobars.c.x1, foobars.c.x2, foobars.c.x3, foobars.c.x4) + )) + + sess = create_session() + f1 = Foobar() + f1.foob = FBComposite(None, 5, None, None) + sess.add(f1) + sess.flush() + + assert f1.foob == FBComposite(2, 5, 15, None) + + @testing.resolve_artifact_names + def test_save_null(self): + """test saving a null composite value + + See google groups thread for more context: + http://groups.google.com/group/sqlalchemy/browse_thread/thread/0c6580a1761b2c29 + + """ + class Point(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return [self.x, self.y] + __hash__ = None + def __eq__(self, other): + return other.x == self.x and other.y == self.y + def __ne__(self, other): + return not self.__eq__(other) + + class Graph(object): + pass + class Edge(object): + def __init__(self, start, end): + self.start = start + self.end = end + + mapper(Graph, graphs, properties={ + 'edges':relation(Edge) + }) + mapper(Edge, edges, properties={ + 'start':sa.orm.composite(Point, edges.c.x1, edges.c.y1), + 'end':sa.orm.composite(Point, edges.c.x2, edges.c.y2) + }) + + sess = create_session() + g = Graph() + g.id = 1 + g.version_id=1 + e = Edge(None, None) + g.edges.append(e) + + sess.add(g) + sess.flush() + + sess.expunge_all() + + g2 = sess.query(Graph).get([1, 1]) + assert g2.edges[-1].start.x is None + assert g2.edges[-1].start.y is None + + +class NoLoadTest(_fixtures.FixtureTest): + run_inserts = 'once' + run_deletes = None + + @testing.resolve_artifact_names + def test_basic(self): + """A basic one-to-many lazy load""" + m = mapper(User, users, properties=dict( + addresses = 
relation(mapper(Address, addresses), lazy=None) + )) + q = create_session().query(m) + l = [None] + def go(): + x = q.filter(User.id == 7).all() + x[0].addresses + l[0] = x + self.assert_sql_count(testing.db, go, 1) + + self.assert_result(l[0], User, + {'id' : 7, 'addresses' : (Address, [])}, + ) + + @testing.resolve_artifact_names + def test_options(self): + m = mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=None) + )) + q = create_session().query(m).options(sa.orm.lazyload('addresses')) + l = [None] + def go(): + x = q.filter(User.id == 7).all() + x[0].addresses + l[0] = x + self.sql_count_(2, go) + + self.assert_result(l[0], User, + {'id' : 7, 'addresses' : (Address, [{'id' : 1}])}, + ) + + +class MapperExtensionTest(_fixtures.FixtureTest): + run_inserts = None + + def extension(self): + methods = [] + + class Ext(sa.orm.MapperExtension): + def instrument_class(self, mapper, cls): + methods.append('instrument_class') + return sa.orm.EXT_CONTINUE + + def init_instance(self, mapper, class_, oldinit, instance, args, kwargs): + methods.append('init_instance') + return sa.orm.EXT_CONTINUE + + def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): + methods.append('init_failed') + return sa.orm.EXT_CONTINUE + + def translate_row(self, mapper, context, row): + methods.append('translate_row') + return sa.orm.EXT_CONTINUE + + def create_instance(self, mapper, selectcontext, row, class_): + methods.append('create_instance') + return sa.orm.EXT_CONTINUE + + def reconstruct_instance(self, mapper, instance): + methods.append('reconstruct_instance') + return sa.orm.EXT_CONTINUE + + def append_result(self, mapper, selectcontext, row, instance, result, **flags): + methods.append('append_result') + return sa.orm.EXT_CONTINUE + + def populate_instance(self, mapper, selectcontext, row, instance, **flags): + methods.append('populate_instance') + return sa.orm.EXT_CONTINUE + + def before_insert(self, mapper, connection, 
instance): + methods.append('before_insert') + return sa.orm.EXT_CONTINUE + + def after_insert(self, mapper, connection, instance): + methods.append('after_insert') + return sa.orm.EXT_CONTINUE + + def before_update(self, mapper, connection, instance): + methods.append('before_update') + return sa.orm.EXT_CONTINUE + + def after_update(self, mapper, connection, instance): + methods.append('after_update') + return sa.orm.EXT_CONTINUE + + def before_delete(self, mapper, connection, instance): + methods.append('before_delete') + return sa.orm.EXT_CONTINUE + + def after_delete(self, mapper, connection, instance): + methods.append('after_delete') + return sa.orm.EXT_CONTINUE + + return Ext, methods + + @testing.resolve_artifact_names + def test_basic(self): + """test that common user-defined methods get called.""" + Ext, methods = self.extension() + + mapper(User, users, extension=Ext()) + sess = create_session() + u = User(name='u1') + sess.add(u) + sess.flush() + u = sess.query(User).populate_existing().get(u.id) + sess.expunge_all() + u = sess.query(User).get(u.id) + u.name = 'u1 changed' + sess.flush() + sess.delete(u) + sess.flush() + eq_(methods, + ['instrument_class', 'init_instance', 'before_insert', + 'after_insert', 'translate_row', 'populate_instance', + 'append_result', 'translate_row', 'create_instance', + 'populate_instance', 'reconstruct_instance', 'append_result', + 'before_update', 'after_update', 'before_delete', 'after_delete']) + + @testing.resolve_artifact_names + def test_inheritance(self): + Ext, methods = self.extension() + + class AdminUser(User): + pass + + mapper(User, users, extension=Ext()) + mapper(AdminUser, addresses, inherits=User) + + sess = create_session() + am = AdminUser(name='au1', email_address='au1@e1') + sess.add(am) + sess.flush() + am = sess.query(AdminUser).populate_existing().get(am.id) + sess.expunge_all() + am = sess.query(AdminUser).get(am.id) + am.name = 'au1 changed' + sess.flush() + sess.delete(am) + sess.flush() + 
eq_(methods, + ['instrument_class', 'instrument_class', 'init_instance', + 'before_insert', 'after_insert', 'translate_row', + 'populate_instance', 'append_result', 'translate_row', + 'create_instance', 'populate_instance', 'reconstruct_instance', + 'append_result', 'before_update', 'after_update', 'before_delete', + 'after_delete']) + + @testing.resolve_artifact_names + def test_after_with_no_changes(self): + """after_update is called even if no columns were updated.""" + + Ext, methods = self.extension() + + mapper(Item, items, extension=Ext() , properties={ + 'keywords': relation(Keyword, secondary=item_keywords)}) + mapper(Keyword, keywords, extension=Ext()) + + sess = create_session() + i1 = Item(description="i1") + k1 = Keyword(name="k1") + sess.add(i1) + sess.add(k1) + sess.flush() + eq_(methods, + ['instrument_class', 'instrument_class', 'init_instance', + 'init_instance', 'before_insert', 'after_insert', + 'before_insert', 'after_insert']) + + del methods[:] + i1.keywords.append(k1) + sess.flush() + eq_(methods, ['before_update', 'after_update']) + + + @testing.resolve_artifact_names + def test_inheritance_with_dupes(self): + """Inheritance with the same extension instance on both mappers.""" + Ext, methods = self.extension() + + class AdminUser(User): + pass + + ext = Ext() + mapper(User, users, extension=ext) + mapper(AdminUser, addresses, inherits=User, extension=ext) + + sess = create_session() + am = AdminUser(name="au1", email_address="au1@e1") + sess.add(am) + sess.flush() + am = sess.query(AdminUser).populate_existing().get(am.id) + sess.expunge_all() + am = sess.query(AdminUser).get(am.id) + am.name = 'au1 changed' + sess.flush() + sess.delete(am) + sess.flush() + eq_(methods, + ['instrument_class', 'instrument_class', 'init_instance', + 'before_insert', 'after_insert', 'translate_row', + 'populate_instance', 'append_result', 'translate_row', + 'create_instance', 'populate_instance', 'reconstruct_instance', + 'append_result', 'before_update', 
'after_update', 'before_delete', + 'after_delete']) + + @testing.resolve_artifact_names + def test_create_instance(self): + class CreateUserExt(sa.orm.MapperExtension): + def create_instance(self, mapper, selectcontext, row, class_): + return User.__new__(User) + + mapper(User, users, extension=CreateUserExt()) + sess = create_session() + u1 = User() + u1.name = 'ed' + sess.add(u1) + sess.flush() + sess.expunge_all() + assert sess.query(User).first() + + +class RequirementsTest(_base.MappedTest): + """Tests the contract for user classes.""" + + @classmethod + def define_tables(cls, metadata): + Table('ht1', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('value', String(10))) + Table('ht2', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('ht1_id', Integer, ForeignKey('ht1.id')), + Column('value', String(10))) + Table('ht3', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('value', String(10))) + Table('ht4', metadata, + Column('ht1_id', Integer, ForeignKey('ht1.id'), + primary_key=True), + Column('ht3_id', Integer, ForeignKey('ht3.id'), + primary_key=True)) + Table('ht5', metadata, + Column('ht1_id', Integer, ForeignKey('ht1.id'), + primary_key=True)) + Table('ht6', metadata, + Column('ht1a_id', Integer, ForeignKey('ht1.id'), + primary_key=True), + Column('ht1b_id', Integer, ForeignKey('ht1.id'), + primary_key=True), + Column('value', String(10))) + + @testing.resolve_artifact_names + def test_baseclass(self): + class OldStyle: + pass + + assert_raises(sa.exc.ArgumentError, mapper, OldStyle, ht1) + + assert_raises(sa.exc.ArgumentError, mapper, 123) + + class NoWeakrefSupport(str): + pass + + # TODO: is weakref support detectable without an instance? 
+ #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2) + + @testing.resolve_artifact_names + def test_comparison_overrides(self): + """Simple tests to ensure users can supply comparison __methods__. + + The suite-level test --options are better suited to detect + problems- they add selected __methods__ across the board on all + ORM tests. This test simply shoves a variety of operations + through the ORM to catch basic regressions early in a standard + test run. + """ + + # adding these methods directly to each class to avoid decoration + # by the testlib decorators. + class _Base(object): + def __init__(self, value='abc'): + self.value = value + def __nonzero__(self): + return False + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + return False + + class H1(_Base): + pass + class H2(_Base): + pass + class H3(_Base): + pass + class H6(_Base): + pass + + mapper(H1, ht1, properties={ + 'h2s': relation(H2, backref='h1'), + 'h3s': relation(H3, secondary=ht4, backref='h1s'), + 'h1s': relation(H1, secondary=ht5, backref='parent_h1'), + 't6a': relation(H6, backref='h1a', + primaryjoin=ht1.c.id==ht6.c.ht1a_id), + 't6b': relation(H6, backref='h1b', + primaryjoin=ht1.c.id==ht6.c.ht1b_id), + }) + mapper(H2, ht2) + mapper(H3, ht3) + mapper(H6, ht6) + + s = create_session() + for i in range(3): + h1 = H1() + s.add(h1) + + h1.h2s.append(H2()) + h1.h3s.extend([H3(), H3()]) + h1.h1s.append(H1()) + + s.flush() + eq_(ht1.count().scalar(), 4) + + h6 = H6() + h6.h1a = h1 + h6.h1b = h1 + + h6 = H6() + h6.h1a = h1 + h6.h1b = x = H1() + assert x in s + + h6.h1b.h2s.append(H2()) + + s.flush() + + h1.h2s.extend([H2(), H2()]) + s.flush() + + h1s = s.query(H1).options(sa.orm.eagerload('h2s')).all() + eq_(len(h1s), 5) + + self.assert_unordered_result(h1s, H1, + {'h2s': []}, + {'h2s': []}, + {'h2s': (H2, [{'value': 'abc'}, + {'value': 'abc'}, + {'value': 'abc'}])}, + {'h2s': []}, 
+ {'h2s': (H2, [{'value': 'abc'}])}) + + h1s = s.query(H1).options(sa.orm.eagerload('h3s')).all() + + eq_(len(h1s), 5) + h1s = s.query(H1).options(sa.orm.eagerload_all('t6a.h1b'), + sa.orm.eagerload('h2s'), + sa.orm.eagerload_all('h3s.h1s')).all() + eq_(len(h1s), 5) + + +class MagicNamesTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('cartographers', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('alias', String(50)), + Column('quip', String(100))) + Table('maps', metadata, + Column('id', Integer, primary_key=True), + Column('cart_id', Integer, + ForeignKey('cartographers.id')), + Column('state', String(2)), + Column('data', sa.Text)) + + @classmethod + def setup_classes(cls): + class Cartographer(_base.BasicEntity): + pass + + class Map(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_mappish(self): + mapper(Cartographer, cartographers, properties=dict( + query=cartographers.c.quip)) + mapper(Map, maps, properties=dict( + mapper=relation(Cartographer, backref='maps'))) + + c = Cartographer(name='Lenny', alias='The Dude', + query='Where be dragons?') + m = Map(state='AK', mapper=c) + + sess = create_session() + sess.add(c) + sess.flush() + sess.expunge_all() + + for C, M in ((Cartographer, Map), + (sa.orm.aliased(Cartographer), sa.orm.aliased(Map))): + c1 = (sess.query(C). + filter(C.alias=='The Dude'). + filter(C.query=='Where be dragons?')).one() + m1 = sess.query(M).filter(M.mapper==c1).one() + + @testing.resolve_artifact_names + def test_direct_stateish(self): + for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR, + sa.orm.attributes.ClassManager.MANAGER_ATTR): + t = Table('t', sa.MetaData(), + Column('id', Integer, primary_key=True), + Column(reserved, Integer)) + class T(object): + pass + + assert_raises_message( + KeyError, + ('%r: requested attribute name conflicts with ' + 'instrumentation attribute of the same name.' 
% reserved), + mapper, T, t) + + @testing.resolve_artifact_names + def test_indirect_stateish(self): + for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR, + sa.orm.attributes.ClassManager.MANAGER_ATTR): + class M(object): + pass + + assert_raises_message( + KeyError, + ('requested attribute name conflicts with ' + 'instrumentation attribute of the same name'), + mapper, M, maps, properties={ + reserved: maps.c.state}) + + + diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py new file mode 100644 index 000000000..70097cbee --- /dev/null +++ b/test/orm/test_merge.py @@ -0,0 +1,735 @@ +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy.util import OrderedSet +from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property +from sqlalchemy.test.testing import eq_, ne_ +from test.orm import _base, _fixtures + + +class MergeTest(_fixtures.FixtureTest): + """Session..merge() functionality""" + + run_inserts = None + + def on_load_tracker(self, cls, canary=None): + if canary is None: + def canary(instance): + canary.called += 1 + canary.called = 0 + + manager = sa.orm.attributes.manager_of_class(cls) + manager.events.add_listener('on_load', canary) + + return canary + + @testing.resolve_artifact_names + def test_transient_to_pending(self): + mapper(User, users) + sess = create_session() + on_load = self.on_load_tracker(User) + + u = User(id=7, name='fred') + eq_(on_load.called, 0) + u2 = sess.merge(u) + eq_(on_load.called, 1) + assert u2 in sess + eq_(u2, User(id=7, name='fred')) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).first(), User(id=7, name='fred')) + + @testing.resolve_artifact_names + def test_transient_to_pending_collection(self): + mapper(User, users, properties={ + 'addresses': relation(Address, backref='user', + collection_class=OrderedSet)}) + mapper(Address, addresses) + on_load = 
self.on_load_tracker(User) + self.on_load_tracker(Address, on_load) + + u = User(id=7, name='fred', addresses=OrderedSet([ + Address(id=1, email_address='fred1'), + Address(id=2, email_address='fred2'), + ])) + eq_(on_load.called, 0) + + sess = create_session() + sess.merge(u) + eq_(on_load.called, 3) + + merged_users = [e for e in sess if isinstance(e, User)] + eq_(len(merged_users), 1) + assert merged_users[0] is not u + + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).one(), + User(id=7, name='fred', addresses=OrderedSet([ + Address(id=1, email_address='fred1'), + Address(id=2, email_address='fred2'), + ])) + ) + + @testing.resolve_artifact_names + def test_transient_to_persistent(self): + mapper(User, users) + on_load = self.on_load_tracker(User) + + sess = create_session() + u = User(id=7, name='fred') + sess.add(u) + sess.flush() + sess.expunge_all() + + eq_(on_load.called, 0) + + _u2 = u2 = User(id=7, name='fred jones') + eq_(on_load.called, 0) + u2 = sess.merge(u2) + assert u2 is not _u2 + eq_(on_load.called, 1) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).first(), User(id=7, name='fred jones')) + eq_(on_load.called, 2) + + @testing.resolve_artifact_names + def test_transient_to_persistent_collection(self): + mapper(User, users, properties={ + 'addresses':relation(Address, + backref='user', + collection_class=OrderedSet, + cascade="all, delete-orphan") + }) + mapper(Address, addresses) + + on_load = self.on_load_tracker(User) + self.on_load_tracker(Address, on_load) + + u = User(id=7, name='fred', addresses=OrderedSet([ + Address(id=1, email_address='fred1'), + Address(id=2, email_address='fred2'), + ])) + sess = create_session() + sess.add(u) + sess.flush() + sess.expunge_all() + + eq_(on_load.called, 0) + + u = User(id=7, name='fred', addresses=OrderedSet([ + Address(id=3, email_address='fred3'), + Address(id=4, email_address='fred4'), + ])) + + u = sess.merge(u) + + # 1. merges User object. updates into session. + # 2.,3. 
merges Address ids 3 & 4, saves into session. + # 4.,5. loads pre-existing elements in "addresses" collection, + # marks as deleted, Address ids 1 and 2. + eq_(on_load.called, 5) + + eq_(u, + User(id=7, name='fred', addresses=OrderedSet([ + Address(id=3, email_address='fred3'), + Address(id=4, email_address='fred4'), + ])) + ) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).one(), + User(id=7, name='fred', addresses=OrderedSet([ + Address(id=3, email_address='fred3'), + Address(id=4, email_address='fred4'), + ])) + ) + + @testing.resolve_artifact_names + def test_detached_to_persistent_collection(self): + mapper(User, users, properties={ + 'addresses':relation(Address, + backref='user', + collection_class=OrderedSet)}) + mapper(Address, addresses) + on_load = self.on_load_tracker(User) + self.on_load_tracker(Address, on_load) + + a = Address(id=1, email_address='fred1') + u = User(id=7, name='fred', addresses=OrderedSet([ + a, + Address(id=2, email_address='fred2'), + ])) + sess = create_session() + sess.add(u) + sess.flush() + sess.expunge_all() + + u.name='fred jones' + u.addresses.add(Address(id=3, email_address='fred3')) + u.addresses.remove(a) + + eq_(on_load.called, 0) + u = sess.merge(u) + eq_(on_load.called, 4) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).first(), + User(id=7, name='fred jones', addresses=OrderedSet([ + Address(id=2, email_address='fred2'), + Address(id=3, email_address='fred3')]))) + + @testing.resolve_artifact_names + def test_unsaved_cascade(self): + """Merge of a transient entity with two child transient entities, with a bidirectional relation.""" + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), + cascade="all", backref="user") + }) + on_load = self.on_load_tracker(User) + self.on_load_tracker(Address, on_load) + sess = create_session() + + u = User(id=7, name='fred') + a1 = Address(email_address='foo@bar.com') + a2 = Address(email_address='hoho@bar.com') + 
u.addresses.append(a1) + u.addresses.append(a2) + + u2 = sess.merge(u) + eq_(on_load.called, 3) + + eq_(u, + User(id=7, name='fred', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address='hoho@bar.com')])) + eq_(u2, + User(id=7, name='fred', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address='hoho@bar.com')])) + + sess.flush() + sess.expunge_all() + u2 = sess.query(User).get(7) + + eq_(u2, User(id=7, name='fred', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address='hoho@bar.com')])) + eq_(on_load.called, 6) + + @testing.resolve_artifact_names + def test_merge_empty_attributes(self): + mapper(User, dingalings) + u1 = User(id=1) + sess = create_session() + sess.merge(u1) + sess.flush() + assert u1.address_id is u1.data is None + + @testing.resolve_artifact_names + def test_attribute_cascade(self): + """Merge of a persistent entity with two child persistent entities.""" + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), backref='user') + }) + on_load = self.on_load_tracker(User) + self.on_load_tracker(Address, on_load) + + sess = create_session() + + # set up data and save + u = User(id=7, name='fred', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address = 'hoho@la.com')]) + sess.add(u) + sess.flush() + + # assert data was saved + sess2 = create_session() + u2 = sess2.query(User).get(7) + eq_(u2, + User(id=7, name='fred', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address='hoho@la.com')])) + + # make local changes to data + u.name = 'fred2' + u.addresses[1].email_address = 'hoho@lalala.com' + + eq_(on_load.called, 3) + + # new session, merge modified data into session + sess3 = create_session() + u3 = sess3.merge(u) + eq_(on_load.called, 6) + + # ensure local changes are pending + eq_(u3, User(id=7, name='fred2', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address='hoho@lalala.com')])) + + # 
save merged data + sess3.flush() + + # assert modified/merged data was saved + sess.expunge_all() + u = sess.query(User).get(7) + eq_(u, User(id=7, name='fred2', addresses=[ + Address(email_address='foo@bar.com'), + Address(email_address='hoho@lalala.com')])) + eq_(on_load.called, 9) + + # merge persistent object into another session + sess4 = create_session() + u = sess4.merge(u) + assert len(u.addresses) + for a in u.addresses: + assert a.user is u + def go(): + sess4.flush() + # no changes; therefore flush should do nothing + self.assert_sql_count(testing.db, go, 0) + eq_(on_load.called, 12) + + # test with "dontload" merge + sess5 = create_session() + u = sess5.merge(u, dont_load=True) + assert len(u.addresses) + for a in u.addresses: + assert a.user is u + def go(): + sess5.flush() + # no changes; therefore flush should do nothing + # but also, dont_load wipes out any difference in committed state, + # so no flush at all + self.assert_sql_count(testing.db, go, 0) + eq_(on_load.called, 15) + + sess4 = create_session() + u = sess4.merge(u, dont_load=True) + # post merge change + u.addresses[1].email_address='afafds' + def go(): + sess4.flush() + # afafds change flushes + self.assert_sql_count(testing.db, go, 1) + eq_(on_load.called, 18) + + sess5 = create_session() + u2 = sess5.query(User).get(u.id) + eq_(u2.name, 'fred2') + eq_(u2.addresses[1].email_address, 'afafds') + eq_(on_load.called, 21) + + @testing.resolve_artifact_names + def test_one_to_many_cascade(self): + + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses))}) + + on_load = self.on_load_tracker(User) + self.on_load_tracker(Address, on_load) + + sess = create_session() + u = User(name='fred') + a1 = Address(email_address='foo@bar') + a2 = Address(email_address='foo@quux') + u.addresses.extend([a1, a2]) + + sess.add(u) + sess.flush() + + eq_(on_load.called, 0) + + sess2 = create_session() + u2 = sess2.query(User).get(u.id) + eq_(on_load.called, 1) + + 
u.addresses[1].email_address = 'addr 2 modified' + sess2.merge(u) + eq_(u2.addresses[1].email_address, 'addr 2 modified') + eq_(on_load.called, 3) + + sess3 = create_session() + u3 = sess3.query(User).get(u.id) + eq_(on_load.called, 4) + + u.name = 'also fred' + sess3.merge(u) + eq_(on_load.called, 6) + eq_(u3.name, 'also fred') + + @testing.resolve_artifact_names + def test_many_to_many_cascade(self): + + mapper(Order, orders, properties={ + 'items':relation(mapper(Item, items), secondary=order_items)}) + + on_load = self.on_load_tracker(Order) + self.on_load_tracker(Item, on_load) + + sess = create_session() + + i1 = Item() + i1.description='item 1' + + i2 = Item() + i2.description = 'item 2' + + o = Order() + o.description = 'order description' + o.items.append(i1) + o.items.append(i2) + + sess.add(o) + sess.flush() + + eq_(on_load.called, 0) + + sess2 = create_session() + o2 = sess2.query(Order).get(o.id) + eq_(on_load.called, 1) + + o.items[1].description = 'item 2 modified' + sess2.merge(o) + eq_(o2.items[1].description, 'item 2 modified') + eq_(on_load.called, 3) + + sess3 = create_session() + o3 = sess3.query(Order).get(o.id) + eq_( on_load.called, 4) + + o.description = 'desc modified' + sess3.merge(o) + eq_(on_load.called, 6) + eq_(o3.description, 'desc modified') + + @testing.resolve_artifact_names + def test_one_to_one_cascade(self): + + mapper(User, users, properties={ + 'address':relation(mapper(Address, addresses),uselist = False) + }) + on_load = self.on_load_tracker(User) + self.on_load_tracker(Address, on_load) + sess = create_session() + + u = User() + u.id = 7 + u.name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.address = a1 + + sess.add(u) + sess.flush() + + eq_(on_load.called, 0) + + sess2 = create_session() + u2 = sess2.query(User).get(7) + eq_(on_load.called, 1) + u2.name = 'fred2' + u2.address.email_address = 'hoho@lalala.com' + eq_(on_load.called, 2) + + u3 = sess.merge(u2) + eq_(on_load.called, 2) + assert u3 is u + + 
@testing.resolve_artifact_names + def test_transient_dontload(self): + mapper(User, users) + + sess = create_session() + u = User() + assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) + + + @testing.resolve_artifact_names + def test_dontload_with_backrefs(self): + """dontload populates relations in both directions without requiring a load""" + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), backref='user') + }) + + u = User(id=7, name='fred', addresses=[ + Address(email_address='ad1'), + Address(email_address='ad2')]) + sess = create_session() + sess.add(u) + sess.flush() + sess.close() + assert 'user' in u.addresses[1].__dict__ + + sess = create_session() + u2 = sess.merge(u, dont_load=True) + assert 'user' in u2.addresses[1].__dict__ + eq_(u2.addresses[1].user, User(id=7, name='fred')) + + sess.expire(u2.addresses[1], ['user']) + assert 'user' not in u2.addresses[1].__dict__ + sess.close() + + sess = create_session() + u = sess.merge(u2, dont_load=True) + assert 'user' not in u.addresses[1].__dict__ + eq_(u.addresses[1].user, User(id=7, name='fred')) + + + @testing.resolve_artifact_names + def test_dontload_with_eager(self): + """ + + This test illustrates that with dont_load=True, we can't just copy the + committed_state of the merged instance over; since it references + collection objects which themselves are to be merged. This + committed_state would instead need to be piecemeal 'converted' to + represent the correct objects. However, at the moment I'd rather not + support this use case; if you are merging with dont_load=True, you're + typically dealing with caching and the merged objects shouldnt be + 'dirty'. 
+ + """ + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses)) + }) + sess = create_session() + u = User() + u.id = 7 + u.name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.addresses.append(a1) + + sess.add(u) + sess.flush() + + sess2 = create_session() + u2 = sess2.query(User).options(sa.orm.eagerload('addresses')).get(7) + + sess3 = create_session() + u3 = sess3.merge(u2, dont_load=True) + def go(): + sess3.flush() + self.assert_sql_count(testing.db, go, 0) + + @testing.resolve_artifact_names + def test_dont_load_disallows_dirty(self): + """dont_load doesnt support 'dirty' objects right now + + (see test_dont_load_with_eager()). Therefore lets assert it. + + """ + mapper(User, users) + sess = create_session() + u = User() + u.id = 7 + u.name = "fred" + sess.add(u) + sess.flush() + + u.name = 'ed' + sess2 = create_session() + try: + sess2.merge(u, dont_load=True) + assert False + except sa.exc.InvalidRequestError, e: + assert ("merge() with dont_load=True option does not support " + "objects marked as 'dirty'. 
flush() all changes on mapped " + "instances before merging with dont_load=True.") in str(e) + + u2 = sess2.query(User).get(7) + + sess3 = create_session() + u3 = sess3.merge(u2, dont_load=True) + assert not sess3.dirty + def go(): + sess3.flush() + self.assert_sql_count(testing.db, go, 0) + + + @testing.resolve_artifact_names + def test_dont_load_sets_backrefs(self): + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses),backref='user')}) + + sess = create_session() + u = User() + u.id = 7 + u.name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.addresses.append(a1) + + sess.add(u) + sess.flush() + + assert u.addresses[0].user is u + + sess2 = create_session() + u2 = sess2.merge(u, dont_load=True) + assert not sess2.dirty + def go(): + assert u2.addresses[0].user is u2 + self.assert_sql_count(testing.db, go, 0) + + @testing.resolve_artifact_names + def test_dont_load_preserves_parents(self): + """Merge with dont_load does not trigger a 'delete-orphan' operation. + + merge with dont_load sets attributes without using events. this means + the 'hasparent' flag is not propagated to the newly merged instance. + in fact this works out OK, because the '_state.parents' collection on + the newly merged instance is empty; since the mapper doesn't see an + active 'False' setting in this collection when _is_orphan() is called, + it does not count as an orphan (i.e. this is the 'optimistic' logic in + mapper._is_orphan().) 
+ + """ + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses), + backref='user', cascade="all, delete-orphan")}) + sess = create_session() + u = User() + u.id = 7 + u.name = "fred" + a1 = Address() + a1.email_address='foo@bar.com' + u.addresses.append(a1) + sess.add(u) + sess.flush() + + assert u.addresses[0].user is u + + sess2 = create_session() + u2 = sess2.merge(u, dont_load=True) + assert not sess2.dirty + a2 = u2.addresses[0] + a2.email_address='somenewaddress' + assert not sa.orm.object_mapper(a2)._is_orphan( + sa.orm.attributes.instance_state(a2)) + sess2.flush() + sess2.expunge_all() + + eq_(sess2.query(User).get(u2.id).addresses[0].email_address, + 'somenewaddress') + + # this use case is not supported; this is with a pending Address on + # the pre-merged object, and we currently dont support 'dirty' objects + # being merged with dont_load=True. in this case, the empty + # '_state.parents' collection would be an issue, since the optimistic + # flag is False in _is_orphan() for pending instances. 
so if we start + # supporting 'dirty' with dont_load=True, this test will need to pass + sess = create_session() + u = sess.query(User).get(7) + u.addresses.append(Address()) + sess2 = create_session() + try: + u2 = sess2.merge(u, dont_load=True) + assert False + + # if dont_load is changed to support dirty objects, this code + # needs to pass + a2 = u2.addresses[0] + a2.email_address='somenewaddress' + assert not sa.orm.object_mapper(a2)._is_orphan( + sa.orm.attributes.instance_state(a2)) + sess2.flush() + sess2.expunge_all() + eq_(sess2.query(User).get(u2.id).addresses[0].email_address, + 'somenewaddress') + except sa.exc.InvalidRequestError, e: + assert "dont_load=True option does not support" in str(e) + + @testing.resolve_artifact_names + def test_synonym_comparable(self): + class User(object): + + class Comparator(PropComparator): + pass + + def _getValue(self): + return self._value + + def _setValue(self, value): + setattr(self, '_value', value) + + value = property(_getValue, _setValue) + + mapper(User, users, properties={ + 'uid':synonym('id'), + 'foobar':comparable_property(User.Comparator,User.value), + }) + + sess = create_session() + u = User() + u.name = 'ed' + sess.add(u) + sess.flush() + sess.expunge(u) + sess.merge(u) + + @testing.resolve_artifact_names + def test_cascade_doesnt_blowaway_manytoone(self): + """a merge test that was fixed by [ticket:1202]""" + + s = create_session(autoflush=True) + mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses),backref='user')}) + + a1 = Address(user=s.merge(User(id=1, name='ed')), email_address='x') + before_id = id(a1.user) + a2 = Address(user=s.merge(User(id=1, name='jack')), email_address='x') + after_id = id(a1.user) + other_id = id(a2.user) + eq_(before_id, other_id) + eq_(after_id, other_id) + eq_(before_id, after_id) + eq_(a1.user, a2.user) + + @testing.resolve_artifact_names + def test_cascades_dont_autoflush(self): + sess = create_session(autoflush=True) + m = 
mapper(User, users, properties={ + 'addresses':relation(mapper(Address, addresses),backref='user')}) + user = User(id=8, name='fred', addresses=[Address(email_address='user')]) + merged_user = sess.merge(user) + assert merged_user in sess.new + sess.flush() + assert merged_user not in sess.new + + @testing.resolve_artifact_names + def test_cascades_dont_autoflush_2(self): + mapper(User, users, properties={ + 'addresses':relation(Address, + backref='user', + cascade="all, delete-orphan") + }) + mapper(Address, addresses) + + u = User(id=7, name='fred', addresses=[ + Address(id=1, email_address='fred1'), + ]) + sess = create_session(autoflush=True, autocommit=False) + sess.add(u) + sess.commit() + + sess.expunge_all() + + u = User(id=7, name='fred', addresses=[ + Address(id=1, email_address='fred1'), + Address(id=2, email_address='fred2'), + ]) + sess.merge(u) + assert sess.autoflush + sess.commit() + + + + diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py new file mode 100644 index 000000000..1376c402e --- /dev/null +++ b/test/orm/test_naturalpks.py @@ -0,0 +1,482 @@ +""" +Primary key changing capabilities and passive/non-passive cascading updates. 
+ +""" +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base + +class NaturalPKTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + users = Table('users', metadata, + Column('username', String(50), primary_key=True), + Column('fullname', String(100)), + test_needs_fk=True) + + addresses = Table('addresses', metadata, + Column('email', String(50), primary_key=True), + Column('username', String(50), ForeignKey('users.username', onupdate="cascade")), + test_needs_fk=True) + + items = Table('items', metadata, + Column('itemname', String(50), primary_key=True), + Column('description', String(100)), + test_needs_fk=True) + + users_to_items = Table('users_to_items', metadata, + Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True), + Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), + test_needs_fk=True) + + @classmethod + def setup_classes(cls): + class User(_base.ComparableEntity): + pass + class Address(_base.ComparableEntity): + pass + class Item(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_entity(self): + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.add(u1) + sess.flush() + assert sess.query(User).get('jack') is u1 + + u1.username = 'ed' + sess.flush() + + def go(): + assert sess.query(User).get('ed') is u1 + self.assert_sql_count(testing.db, go, 0) + + assert sess.query(User).get('jack') is None + + sess.expunge_all() + u1 = sess.query(User).get('ed') + eq_(User(username='ed', fullname='jack'), u1) + + 
@testing.resolve_artifact_names + def test_load_after_expire(self): + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.add(u1) + sess.flush() + assert sess.query(User).get('jack') is u1 + + users.update(values={User.username:'jack'}).execute(username='ed') + + # expire/refresh works off of primary key. the PK is gone + # in this case so theres no way to look it up. criterion- + # based session invalidation could solve this [ticket:911] + sess.expire(u1) + assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username') + + sess.expunge_all() + assert sess.query(User).get('jack') is None + assert sess.query(User).get('ed').fullname == 'jack' + + @testing.resolve_artifact_names + def test_flush_new_pk_after_expire(self): + mapper(User, users) + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.add(u1) + sess.flush() + assert sess.query(User).get('jack') is u1 + + sess.expire(u1) + u1.username = 'ed' + sess.flush() + sess.expunge_all() + assert sess.query(User).get('ed').fullname == 'jack' + + + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + def test_onetomany_passive(self): + self._test_onetomany(True) + + def test_onetomany_nonpassive(self): + self._test_onetomany(False) + + @testing.resolve_artifact_names + def _test_onetomany(self, passive_updates): + mapper(User, users, properties={ + 'addresses':relation(Address, passive_updates=passive_updates) + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + u1.addresses.append(Address(email='jack1')) + u1.addresses.append(Address(email='jack2')) + sess.add(u1) + sess.flush() + + assert sess.query(Address).get('jack1') is u1.addresses[0] + + u1.username = 'ed' + sess.flush() + assert u1.addresses[0].username == 'ed' + + sess.expunge_all() + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + u1 = sess.query(User).get('ed') 
+ u1.username = 'jack' + def go(): + sess.flush() + if not passive_updates: + self.assert_sql_count(testing.db, go, 4) # test passive_updates=False; load addresses, update user, update 2 addresses + else: + self.assert_sql_count(testing.db, go, 1) # test passive_updates=True; update user + sess.expunge_all() + assert User(username='jack', addresses=[Address(username='jack'), Address(username='jack')]) == sess.query(User).get('jack') + + u1 = sess.query(User).get('jack') + u1.addresses = [] + u1.username = 'fred' + sess.flush() + sess.expunge_all() + assert sess.query(Address).get('jack1').username is None + u1 = sess.query(User).get('fred') + eq_(User(username='fred', fullname='jack'), u1) + + + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + def test_manytoone_passive(self): + self._test_manytoone(True) + + def test_manytoone_nonpassive(self): + self._test_manytoone(False) + + @testing.resolve_artifact_names + def _test_manytoone(self, passive_updates): + mapper(User, users) + mapper(Address, addresses, properties={ + 'user':relation(User, passive_updates=passive_updates) + }) + + sess = create_session() + a1 = Address(email='jack1') + a2 = Address(email='jack2') + + u1 = User(username='jack', fullname='jack') + a1.user = u1 + a2.user = u1 + sess.add(a1) + sess.add(a2) + sess.flush() + + u1.username = 'ed' + + def go(): + sess.flush() + if passive_updates: + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 3) + + def go(): + sess.flush() + self.assert_sql_count(testing.db, go, 0) + + assert a1.username == a2.username == 'ed' + sess.expunge_all() + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + def test_onetoone_passive(self): + self._test_onetoone(True) + + def test_onetoone_nonpassive(self): + self._test_onetoone(False) + + @testing.resolve_artifact_names + def _test_onetoone(self, 
passive_updates): + mapper(User, users, properties={ + "address":relation(Address, passive_updates=passive_updates, uselist=False) + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + sess.add(u1) + sess.flush() + + a1 = Address(email='jack1') + u1.address = a1 + sess.add(a1) + sess.flush() + + u1.username = 'ed' + + def go(): + sess.flush() + if passive_updates: + sess.expire(u1, ['address']) + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 2) + + def go(): + sess.flush() + self.assert_sql_count(testing.db, go, 0) + + sess.expunge_all() + eq_([Address(username='ed')], sess.query(Address).all()) + + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + def test_bidirectional_passive(self): + self._test_bidirectional(True) + + def test_bidirectional_nonpassive(self): + self._test_bidirectional(False) + + @testing.resolve_artifact_names + def _test_bidirectional(self, passive_updates): + mapper(User, users) + mapper(Address, addresses, properties={ + 'user':relation(User, passive_updates=passive_updates, + backref='addresses')}) + + sess = create_session() + a1 = Address(email='jack1') + a2 = Address(email='jack2') + + u1 = User(username='jack', fullname='jack') + a1.user = u1 + a2.user = u1 + sess.add(a1) + sess.add(a2) + sess.flush() + + u1.username = 'ed' + (ad1, ad2) = sess.query(Address).all() + eq_([Address(username='jack'), Address(username='jack')], [ad1, ad2]) + def go(): + sess.flush() + if passive_updates: + sess.expire(u1, ['addresses']) + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 3) + eq_([Address(username='ed'), Address(username='ed')], [ad1, ad2]) + sess.expunge_all() + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + u1 = sess.query(User).get('ed') + assert len(u1.addresses) == 2 # load addresses + u1.username = 'fred' + def go(): + sess.flush() + # 
check that the passive_updates is on on the other side + if passive_updates: + sess.expire(u1, ['addresses']) + self.assert_sql_count(testing.db, go, 1) + else: + self.assert_sql_count(testing.db, go, 3) + sess.expunge_all() + eq_([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) + + + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + def test_manytomany_passive(self): + self._test_manytomany(True) + + @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count') + def test_manytomany_nonpassive(self): + self._test_manytomany(False) + + @testing.resolve_artifact_names + def _test_manytomany(self, passive_updates): + mapper(User, users, properties={ + 'items':relation(Item, secondary=users_to_items, backref='users', + passive_updates=passive_updates)}) + mapper(Item, items) + + sess = create_session() + u1 = User(username='jack') + u2 = User(username='fred') + i1 = Item(itemname='item1') + i2 = Item(itemname='item2') + + u1.items.append(i1) + u1.items.append(i2) + i2.users.append(u2) + sess.add(u1) + sess.add(u2) + sess.flush() + + r = sess.query(Item).all() + # ComparableEntity can't handle a comparison with the backrefs + # involved.... 
+ eq_(Item(itemname='item1'), r[0]) + eq_(['jack'], [u.username for u in r[0].users]) + eq_(Item(itemname='item2'), r[1]) + eq_(['jack', 'fred'], [u.username for u in r[1].users]) + + u2.username='ed' + def go(): + sess.flush() + go() + def go(): + sess.flush() + self.assert_sql_count(testing.db, go, 0) + + sess.expunge_all() + r = sess.query(Item).all() + eq_(Item(itemname='item1'), r[0]) + eq_(['jack'], [u.username for u in r[0].users]) + eq_(Item(itemname='item2'), r[1]) + eq_(['ed', 'jack'], sorted([u.username for u in r[1].users])) + + sess.expunge_all() + u2 = sess.query(User).get(u2.username) + u2.username='wendy' + sess.flush() + r = sess.query(Item).with_parent(u2).all() + eq_(Item(itemname='item2'), r[0]) + + +class SelfRefTest(_base.MappedTest): + __unsupported_on__ = 'mssql' # mssql doesn't allow ON UPDATE on self-referential keys + + @classmethod + def define_tables(cls, metadata): + Table('nodes', metadata, + Column('name', String(50), primary_key=True), + Column('parent', String(50), + ForeignKey('nodes.name', onupdate='cascade'))) + + @classmethod + def setup_classes(cls): + class Node(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_onetomany(self): + mapper(Node, nodes, properties={ + 'children': relation(Node, + backref=sa.orm.backref('parentnode', + remote_side=nodes.c.name, + passive_updates=False), + passive_updates=False)}) + + sess = create_session() + n1 = Node(name='n1') + n1.children.append(Node(name='n11')) + n1.children.append(Node(name='n12')) + n1.children.append(Node(name='n13')) + sess.add(n1) + sess.flush() + + n1.name = 'new n1' + sess.flush() + eq_(n1.children[1].parent, 'new n1') + eq_(['new n1', 'new n1', 'new n1'], + [n.parent + for n in sess.query(Node).filter( + Node.name.in_(['n11', 'n12', 'n13']))]) + + +class NonPKCascadeTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('username', 
String(50), unique=True), + Column('fullname', String(100)), + test_needs_fk=True) + + Table('addresses', metadata, + Column('id', Integer, primary_key=True), + Column('email', String(50)), + Column('username', String(50), + ForeignKey('users.username', onupdate="cascade")), + test_needs_fk=True + ) + + @classmethod + def setup_classes(cls): + class User(_base.ComparableEntity): + pass + class Address(_base.ComparableEntity): + pass + + @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + def test_onetomany_passive(self): + self._test_onetomany(True) + + def test_onetomany_nonpassive(self): + self._test_onetomany(False) + + @testing.resolve_artifact_names + def _test_onetomany(self, passive_updates): + mapper(User, users, properties={ + 'addresses':relation(Address, passive_updates=passive_updates)}) + mapper(Address, addresses) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + u1.addresses.append(Address(email='jack1')) + u1.addresses.append(Address(email='jack2')) + sess.add(u1) + sess.flush() + a1 = u1.addresses[0] + + eq_(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) + + assert sess.query(Address).get(a1.id) is u1.addresses[0] + + u1.username = 'ed' + sess.flush() + assert u1.addresses[0].username == 'ed' + eq_(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) + + sess.expunge_all() + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + + u1 = sess.query(User).get(u1.id) + u1.username = 'jack' + def go(): + sess.flush() + if not passive_updates: + self.assert_sql_count(testing.db, go, 4) # test passive_updates=False; load addresses, update user, update 2 addresses + else: + self.assert_sql_count(testing.db, go, 1) # test passive_updates=True; update user + sess.expunge_all() + assert User(username='jack', addresses=[Address(username='jack'), Address(username='jack')]) == sess.query(User).get(u1.id) + sess.expunge_all() + + u1 = 
sess.query(User).get(u1.id) + u1.addresses = [] + u1.username = 'fred' + sess.flush() + sess.expunge_all() + a1 = sess.query(Address).get(a1.id) + eq_(a1.username, None) + + eq_(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) + + u1 = sess.query(User).get(u1.id) + eq_(User(username='fred', fullname='jack'), u1) + + diff --git a/test/orm/test_onetoone.py b/test/orm/test_onetoone.py new file mode 100644 index 000000000..0d66915ea --- /dev/null +++ b/test/orm/test_onetoone.py @@ -0,0 +1,76 @@ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base + + +class O2OTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('jack', metadata, + Column('id', Integer, primary_key=True), + Column('number', String(50)), + Column('status', String(20)), + Column('subroom', String(5))) + + Table('port', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30)), + Column('description', String(100)), + Column('jack_id', Integer, ForeignKey("jack.id"))) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class Jack(_base.BasicEntity): + pass + class Port(_base.BasicEntity): + pass + + + @testing.resolve_artifact_names + def test_basic(self): + mapper(Port, port) + mapper(Jack, jack, + order_by=[jack.c.number], + properties=dict( + port=relation(Port, backref='jack', + uselist=False, + )), + ) + + session = create_session() + + j = Jack(number='101') + session.add(j) + p = Port(name='fa0/1') + session.add(p) + + j.port=p + session.flush() + jid = j.id + pid = p.id + + j=session.query(Jack).get(jid) + p=session.query(Port).get(pid) + assert p.jack is not None + assert p.jack is j + assert j.port is not None + p.jack = None + assert j.port is None + + 
session.expunge_all() + + j = session.query(Jack).get(jid) + p = session.query(Port).get(pid) + + j.port=None + self.assert_(p.jack is None) + session.flush() + + session.delete(j) + session.flush() + diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py new file mode 100644 index 000000000..5343cc15b --- /dev/null +++ b/test/orm/test_pickled.py @@ -0,0 +1,194 @@ +from sqlalchemy.test.testing import eq_ +import pickle +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, attributes +from test.orm import _base, _fixtures + + +User, EmailUser = None, None + +class PickleTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_transient(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(name='ed') + u1.addresses.append(Address(email_address='ed@bar.com')) + + u2 = pickle.loads(pickle.dumps(u1)) + sess.add(u2) + sess.flush() + + sess.expunge_all() + + eq_(u1, sess.query(User).get(u2.id)) + + @testing.resolve_artifact_names + def test_class_deferred_cols(self): + mapper(User, users, properties={ + 'name':sa.orm.deferred(users.c.name), + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses, properties={ + 'email_address':sa.orm.deferred(addresses.c.email_address) + }) + sess = create_session() + u1 = User(name='ed') + u1.addresses.append(Address(email_address='ed@bar.com')) + sess.add(u1) + sess.flush() + sess.expunge_all() + u1 = sess.query(User).get(u1.id) + assert 'name' not in u1.__dict__ + assert 'addresses' not in u1.__dict__ + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + sess2.add(u2) + eq_(u2.name, 'ed') + eq_(u2, User(name='ed', 
addresses=[Address(email_address='ed@bar.com')])) + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + u2 = sess2.merge(u2, dont_load=True) + eq_(u2.name, 'ed') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + + @testing.resolve_artifact_names + def test_instance_deferred_cols(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses) + + sess = create_session() + u1 = User(name='ed') + u1.addresses.append(Address(email_address='ed@bar.com')) + sess.add(u1) + sess.flush() + sess.expunge_all() + + u1 = sess.query(User).options(sa.orm.defer('name'), sa.orm.defer('addresses.email_address')).get(u1.id) + assert 'name' not in u1.__dict__ + assert 'addresses' not in u1.__dict__ + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + sess2.add(u2) + eq_(u2.name, 'ed') + assert 'addresses' not in u2.__dict__ + ad = u2.addresses[0] + assert 'email_address' not in ad.__dict__ + eq_(ad.email_address, 'ed@bar.com') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + + u2 = pickle.loads(pickle.dumps(u1)) + sess2 = create_session() + u2 = sess2.merge(u2, dont_load=True) + eq_(u2.name, 'ed') + assert 'addresses' not in u2.__dict__ + ad = u2.addresses[0] + assert 'email_address' in ad.__dict__ # mapper options dont transmit over merge() right now + eq_(ad.email_address, 'ed@bar.com') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + + @testing.resolve_artifact_names + def test_options_with_descriptors(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref="user") + }) + mapper(Address, addresses) + sess = create_session() + u1 = User(name='ed') + u1.addresses.append(Address(email_address='ed@bar.com')) + sess.add(u1) + sess.flush() + sess.expunge_all() + + for opt in [ + sa.orm.eagerload(User.addresses), + sa.orm.eagerload("addresses"), + sa.orm.defer("name"), + 
sa.orm.defer(User.name), + sa.orm.defer([User.name]), + sa.orm.eagerload("addresses", User.addresses), + sa.orm.eagerload(["addresses", User.addresses]), + ]: + opt2 = pickle.loads(pickle.dumps(opt)) + eq_(opt.key, opt2.key) + + u1 = sess.query(User).options(opt).first() + + u2 = pickle.loads(pickle.dumps(u1)) + + +class PolymorphicDeferredTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30)), + Column('type', String(30))) + Table('email_users', metadata, + Column('id', Integer, ForeignKey('users.id'), primary_key=True), + Column('email_address', String(30))) + + @classmethod + def setup_classes(cls): + global User, EmailUser + class User(_base.BasicEntity): + pass + + class EmailUser(User): + pass + + @classmethod + def teardown_class(cls): + global User, EmailUser + User, EmailUser = None, None + super(PolymorphicDeferredTest, cls).teardown_class() + + @testing.resolve_artifact_names + def test_polymorphic_deferred(self): + mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type) + mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser') + + eu = EmailUser(name="user1", email_address='foo@bar.com') + sess = create_session() + sess.add(eu) + sess.flush() + sess.expunge_all() + + eu = sess.query(User).first() + eu2 = pickle.loads(pickle.dumps(eu)) + sess2 = create_session() + sess2.add(eu2) + assert 'email_address' not in eu2.__dict__ + eq_(eu2.email_address, 'foo@bar.com') + +class CustomSetupTeardownTest(_fixtures.FixtureTest): + @testing.resolve_artifact_names + def test_rebuild_state(self): + """not much of a 'test', but illustrate how to + remove instance-level state before pickling. 
+ + """ + mapper(User, users) + + u1 = User() + attributes.manager_of_class(User).teardown_instance(u1) + assert not u1.__dict__ + u2 = pickle.loads(pickle.dumps(u1)) + attributes.manager_of_class(User).setup_instance(u2) + assert attributes.instance_state(u2) + diff --git a/test/orm/test_query.py b/test/orm/test_query.py new file mode 100644 index 000000000..66c219b10 --- /dev/null +++ b/test/orm/test_query.py @@ -0,0 +1,3024 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +import operator +from sqlalchemy import * +from sqlalchemy import exc as sa_exc, util +from sqlalchemy.sql import compiler, table, column +from sqlalchemy.engine import default +from sqlalchemy.orm import * +from sqlalchemy.orm import attributes + +from sqlalchemy.test.testing import eq_ + +import sqlalchemy as sa +from sqlalchemy.test import testing, AssertsCompiledSQL, Column, engines + +from test.orm import _fixtures +from test.orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \ + Dingaling, item_keywords, dingalings, User, items,\ + orders, Address, users, nodes, \ + order_items, Item, Order, Node + +from test.orm import _base + +from sqlalchemy.orm.util import join, outerjoin, with_parent + +class QueryTest(_fixtures.FixtureTest): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + + @classmethod + def setup_mappers(cls): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', order_by=addresses.c.id), + 'orders':relation(Order, backref='user', order_by=orders.c.id), # o2m, m2o + }) + mapper(Address, addresses, properties={ + 'dingaling':relation(Dingaling, uselist=False, backref="address") #o2o + }) + mapper(Dingaling, dingalings) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m + 'address':relation(Address), # m2o + }) + mapper(Item, items, properties={ + 'keywords':relation(Keyword, secondary=item_keywords) #m2m + }) + 
mapper(Keyword, keywords) + + mapper(Node, nodes, properties={ + 'children':relation(Node, + backref=backref('parent', remote_side=[nodes.c.id]) + ) + }) + + compile_mappers() + +class RowTupleTest(QueryTest): + run_setup_mappers = None + + def test_custom_names(self): + mapper(User, users, properties={ + 'uname':users.c.name + }) + + row = create_session().query(User.id, User.uname).filter(User.id==7).first() + assert row.id == 7 + assert row.uname == 'jack' + +class GetTest(QueryTest): + def test_get(self): + s = create_session() + assert s.query(User).get(19) is None + u = s.query(User).get(7) + u2 = s.query(User).get(7) + assert u is u2 + s.expunge_all() + u2 = s.query(User).get(7) + assert u is not u2 + + def test_no_criterion(self): + """test that get()/load() does not use preexisting filter/etc. criterion""" + + s = create_session() + + q = s.query(User).join('addresses').filter(Address.user_id==8) + assert_raises(sa_exc.InvalidRequestError, q.get, 7) + assert_raises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19) + + # order_by()/get() doesn't raise + s.query(User).order_by(User.id).get(8) + + def test_unique_param_names(self): + class SomeUser(object): + pass + s = users.select(users.c.id!=12).alias('users') + m = mapper(SomeUser, s) + print s.primary_key + print m.primary_key + assert s.primary_key == m.primary_key + + row = s.select(use_labels=True).execute().fetchone() + print row[s.primary_key[0]] + + sess = create_session() + assert sess.query(SomeUser).get(7).name == 'jack' + + def test_load(self): + s = create_session() + + assert s.query(User).populate_existing().get(19) is None + + u = s.query(User).populate_existing().get(7) + u2 = s.query(User).populate_existing().get(7) + assert u is u2 + s.expunge_all() + u2 = s.query(User).populate_existing().get(7) + assert u is not u2 + + u2.name = 'some name' + a = Address(email_address='some other name') + u2.addresses.append(a) + assert u2 in s.dirty + assert a in u2.addresses + + 
s.query(User).populate_existing().get(7) + assert u2 not in s.dirty + assert u2.name =='jack' + assert a not in u2.addresses + + @testing.requires.unicode_connections + def test_unicode(self): + """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail + on postgres, mysql and oracle unless it is converted to an encoded string""" + + metadata = MetaData(engines.utf8_engine()) + table = Table('unicode_data', metadata, + Column('id', Unicode(40), primary_key=True), + Column('data', Unicode(40))) + try: + metadata.create_all() + ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8') + table.insert().execute(id=ustring, data=ustring) + class LocalFoo(Base): + pass + mapper(LocalFoo, table) + eq_(create_session().query(LocalFoo).get(ustring), + LocalFoo(id=ustring, data=ustring)) + finally: + metadata.drop_all() + + def test_populate_existing(self): + s = create_session() + + userlist = s.query(User).all() + + u = userlist[0] + u.name = 'foo' + a = Address(name='ed') + u.addresses.append(a) + + self.assert_(a in u.addresses) + + s.query(User).populate_existing().all() + + self.assert_(u not in s.dirty) + + self.assert_(u.name == 'jack') + + self.assert_(a not in u.addresses) + + u.addresses[0].email_address = 'lala' + u.orders[1].items[2].description = 'item 12' + # test that lazy load doesnt change child items + s.query(User).populate_existing().all() + assert u.addresses[0].email_address == 'lala' + assert u.orders[1].items[2].description == 'item 12' + + # eager load does + s.query(User).options(eagerload('addresses'), eagerload_all('orders.items')).populate_existing().all() + assert u.addresses[0].email_address == 'jack@bean.com' + assert u.orders[1].items[2].description == 'item 5' + + @testing.fails_on_everything_except('sqlite', 'mssql') + def test_query_str(self): + s = create_session() + q = s.query(User).filter(User.id==1) + eq_( + str(q).replace('\n',''), + 'SELECT users.id AS users_id, users.name AS users_name 
FROM users WHERE users.id = ?' + ) + +class InvalidGenerationsTest(QueryTest): + def test_no_limit_offset(self): + s = create_session() + + for q in ( + s.query(User).limit(2), + s.query(User).offset(2), + s.query(User).limit(2).offset(2) + ): + assert_raises(sa_exc.InvalidRequestError, q.join, "addresses") + + assert_raises(sa_exc.InvalidRequestError, q.filter, User.name=='ed') + + assert_raises(sa_exc.InvalidRequestError, q.filter_by, name='ed') + + assert_raises(sa_exc.InvalidRequestError, q.order_by, 'foo') + + assert_raises(sa_exc.InvalidRequestError, q.group_by, 'foo') + + assert_raises(sa_exc.InvalidRequestError, q.having, 'foo') + + def test_no_from(self): + s = create_session() + + q = s.query(User).select_from(users) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) + + q = s.query(User).join('addresses') + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) + + q = s.query(User).order_by(User.id) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) + + # this is fine, however + q.from_self() + + def test_invalid_select_from(self): + s = create_session() + q = s.query(User) + assert_raises(sa_exc.ArgumentError, q.select_from, User.id==5) + assert_raises(sa_exc.ArgumentError, q.select_from, User.id) + + def test_invalid_from_statement(self): + s = create_session() + q = s.query(User) + assert_raises(sa_exc.ArgumentError, q.from_statement, User.id==5) + assert_raises(sa_exc.ArgumentError, q.from_statement, users.join(addresses)) + + def test_invalid_column(self): + s = create_session() + q = s.query(User) + assert_raises(sa_exc.InvalidRequestError, q.add_column, object()) + + def test_mapper_zero(self): + s = create_session() + + q = s.query(User, Address) + assert_raises(sa_exc.InvalidRequestError, q.get, 5) + + def test_from_statement(self): + s = create_session() + + q = s.query(User).filter(User.id==5) + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") + + q = s.query(User).filter_by(id=5) + 
assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") + + q = s.query(User).limit(5) + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") + + q = s.query(User).group_by(User.name) + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") + + q = s.query(User).order_by(User.name) + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") + +class OperatorTest(QueryTest, AssertsCompiledSQL): + """test sql.Comparator implementation for MapperProperties""" + + def _test(self, clause, expected): + self.assert_compile(clause, expected, dialect=default.DefaultDialect()) + + def test_arithmetic(self): + create_session().query(User) + for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), + (operator.sub, '-'), (operator.div, '/'), + ): + for (lhs, rhs, res) in ( + (5, User.id, ':id_1 %s users.id'), + (5, literal(6), ':param_1 %s :param_2'), + (User.id, 5, 'users.id %s :id_1'), + (User.id, literal('b'), 'users.id %s :param_1'), + (User.id, User.id, 'users.id %s users.id'), + (literal(5), 'b', ':param_1 %s :param_2'), + (literal(5), User.id, ':param_1 %s users.id'), + (literal(5), literal(6), ':param_1 %s :param_2'), + ): + self._test(py_op(lhs, rhs), res % sql_op) + + def test_comparison(self): + create_session().query(User) + ualias = aliased(User) + + for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'), + (operator.gt, '>', '<'), + (operator.eq, '=', '='), + (operator.ne, '!=', '!='), + (operator.le, '<=', '>='), + (operator.ge, '>=', '<=')): + for (lhs, rhs, l_sql, r_sql) in ( + ('a', User.id, ':id_1', 'users.id'), + ('a', literal('b'), ':param_2', ':param_1'), # note swap! 
+ (User.id, 'b', 'users.id', ':id_1'), + (User.id, literal('b'), 'users.id', ':param_1'), + (User.id, User.id, 'users.id', 'users.id'), + (literal('a'), 'b', ':param_1', ':param_2'), + (literal('a'), User.id, ':param_1', 'users.id'), + (literal('a'), literal('b'), ':param_1', ':param_2'), + (ualias.id, literal('b'), 'users_1.id', ':param_1'), + (User.id, ualias.name, 'users.id', 'users_1.name'), + (User.name, ualias.name, 'users.name', 'users_1.name'), + (ualias.name, User.name, 'users_1.name', 'users.name'), + ): + + # the compiled clause should match either (e.g.): + # 'a' < 'b' -or- 'b' > 'a'. + compiled = str(py_op(lhs, rhs).compile(dialect=default.DefaultDialect())) + fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) + rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) + + self.assert_(compiled == fwd_sql or compiled == rev_sql, + "\n'" + compiled + "'\n does not match\n'" + + fwd_sql + "'\n or\n'" + rev_sql + "'") + + def test_negated_null(self): + self._test(User.id == None, "users.id IS NULL") + self._test(~(User.id==None), "users.id IS NOT NULL") + self._test(None == User.id, "users.id IS NULL") + self._test(~(None == User.id), "users.id IS NOT NULL") + self._test(Address.user == None, "addresses.user_id IS NULL") + self._test(~(Address.user==None), "addresses.user_id IS NOT NULL") + self._test(None == Address.user, "addresses.user_id IS NULL") + self._test(~(None == Address.user), "addresses.user_id IS NOT NULL") + + def test_relation(self): + self._test(User.addresses.any(Address.id==17), + "EXISTS (SELECT 1 " + "FROM addresses " + "WHERE users.id = addresses.user_id AND addresses.id = :id_1)" + ) + + u7 = User(id=7) + attributes.instance_state(u7).commit_all(attributes.instance_dict(u7)) + + self._test(Address.user == u7, ":param_1 = addresses.user_id") + + self._test(Address.user != u7, "addresses.user_id != :user_id_1 OR addresses.user_id IS NULL") + + self._test(Address.user == None, "addresses.user_id IS NULL") + + self._test(Address.user != None, 
"addresses.user_id IS NOT NULL") + + def test_selfref_relation(self): + nalias = aliased(Node) + + # auto self-referential aliasing + self._test( + Node.children.any(Node.data=='n1'), + "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " + "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" + ) + + # needs autoaliasing + self._test( + Node.children==None, + "NOT (EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id))" + ) + + self._test( + Node.parent==None, + "nodes.parent_id IS NULL" + ) + + self._test( + nalias.parent==None, + "nodes_1.parent_id IS NULL" + ) + + self._test( + nalias.children==None, + "NOT (EXISTS (SELECT 1 FROM nodes WHERE nodes_1.id = nodes.parent_id))" + ) + + self._test( + nalias.children.any(Node.data=='some data'), + "EXISTS (SELECT 1 FROM nodes WHERE " + "nodes_1.id = nodes.parent_id AND nodes.data = :data_1)") + + # fails, but I think I want this to fail + #self._test( + # Node.children.any(nalias.data=='some data'), + # "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " + # "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" + # ) + + self._test( + nalias.parent.has(Node.data=='some data'), + "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.data = :data_1)" + ) + + self._test( + Node.parent.has(Node.data=='some data'), + "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes_1.id = nodes.parent_id AND nodes_1.data = :data_1)" + ) + + self._test( + Node.parent == Node(id=7), + ":param_1 = nodes.parent_id" + ) + + self._test( + nalias.parent == Node(id=7), + ":param_1 = nodes_1.parent_id" + ) + + self._test( + nalias.parent != Node(id=7), + 'nodes_1.parent_id != :parent_id_1 OR nodes_1.parent_id IS NULL' + ) + + self._test( + nalias.children.contains(Node(id=7)), "nodes_1.id = :param_1" + ) + + def test_op(self): + self._test(User.name.op('ilike')('17'), "users.name ilike :name_1") + + def test_in(self): + self._test(User.id.in_(['a', 'b']), + "users.id IN (:id_1, :id_2)") + + def 
test_in_on_relation_not_supported(self): + assert_raises(NotImplementedError, Address.user.in_, [User(id=5)]) + + def test_between(self): + self._test(User.id.between('a', 'b'), + "users.id BETWEEN :id_1 AND :id_2") + + def test_selfref_between(self): + ualias = aliased(User) + self._test(User.id.between(ualias.id, ualias.id), "users.id BETWEEN users_1.id AND users_1.id") + self._test(ualias.id.between(User.id, User.id), "users_1.id BETWEEN users.id AND users.id") + + def test_clauses(self): + for (expr, compare) in ( + (func.max(User.id), "max(users.id)"), + (User.id.desc(), "users.id DESC"), + (between(5, User.id, Address.id), ":param_1 BETWEEN users.id AND addresses.id"), + # this one would require adding compile() to InstrumentedScalarAttribute. do we want this ? + #(User.id, "users.id") + ): + c = expr.compile(dialect=default.DefaultDialect()) + assert str(c) == compare, "%s != %s" % (str(c), compare) + + +class RawSelectTest(QueryTest, AssertsCompiledSQL): + """compare a bunch of select() tests with the equivalent Query using straight table/columns. + + Results should be the same as Query should act as a select() pass-thru for ClauseElement entities. + + """ + def test_select(self): + sess = create_session() + + self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement, + "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1") + + self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement, + "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users") + + # a little tedious here, adding labels to work around Query's auto-labelling. + # also correlate needed explicitly. hmmm..... + # TODO: can we detect only one table in the "froms" and then turn off use_labels ? 
+ s = sess.query(addresses.c.id.label('id'), addresses.c.email_address.label('email')).\ + filter(addresses.c.user_id==users.c.id).correlate(users).statement.alias() + + self.assert_compile(sess.query(users, s.c.email).select_from(users.join(s, s.c.id==users.c.id)).with_labels().statement, + "SELECT users.id AS users_id, users.name AS users_name, anon_1.email AS anon_1_email " + "FROM users JOIN (SELECT addresses.id AS id, addresses.email_address AS email FROM addresses " + "WHERE addresses.user_id = users.id) AS anon_1 ON anon_1.id = users.id", + dialect=default.DefaultDialect() + ) + + x = func.lala(users.c.id).label('foo') + self.assert_compile(sess.query(x).filter(x==5).statement, + "SELECT lala(users.id) AS foo FROM users WHERE lala(users.id) = :param_1", dialect=default.DefaultDialect()) + +class ExpressionTest(QueryTest, AssertsCompiledSQL): + + def test_deferred_instances(self): + session = create_session() + s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).statement + + l = list(session.query(User).instances(s.execute(emailad = 'jack@bean.com'))) + eq_([User(id=7)], l) + + def test_scalar_subquery(self): + session = create_session() + + q = session.query(User.id).filter(User.id==7).subquery() + + q = session.query(User).filter(User.id==q) + + eq_(User(id=7), q.one()) + + + def test_in(self): + session = create_session() + s = session.query(User.id).join(User.addresses).group_by(User.id).having(func.count(Address.id) > 2) + eq_( + session.query(User).filter(User.id.in_(s)).all(), + [User(id=8)] + ) + + def test_union(self): + s = create_session() + + q1 = s.query(User).filter(User.name=='ed').with_labels() + q2 = s.query(User).filter(User.name=='fred').with_labels() + eq_( + s.query(User).from_statement(union(q1, q2).order_by('users_name')).all(), + [User(name='ed'), User(name='fred')] + ) + + def test_select(self): + s = create_session() + + # this is actually not legal on most DBs since the 
subquery has no alias + q1 = s.query(User).filter(User.name=='ed') + self.assert_compile( + select([q1]), + "SELECT id, name FROM (SELECT users.id AS id, users.name AS name FROM users WHERE users.name = :name_1)", + dialect=default.DefaultDialect() + ) + + def test_join(self): + s = create_session() + + # TODO: do we want aliased() to detect a query and convert to subquery() + # automatically ? + q1 = s.query(Address).filter(Address.email_address=='jack@bean.com') + adalias = aliased(Address, q1.subquery()) + eq_( + s.query(User, adalias).join((adalias, User.id==adalias.user_id)).all(), + [(User(id=7,name=u'jack'), Address(email_address=u'jack@bean.com',user_id=7,id=1))] + ) + +# more slice tests are available in test/orm/generative.py +class SliceTest(QueryTest): + def test_first(self): + assert User(id=7) == create_session().query(User).first() + + assert create_session().query(User).filter(User.id==27).first() is None + + @testing.fails_on_everything_except('sqlite') + def test_limit_offset_applies(self): + """Test that the expected LIMIT/OFFSET is applied for slices. + + The LIMIT/OFFSET syntax differs slightly on all databases, and + query[x:y] executes immediately, so we are asserting against + SQL strings using sqlite's syntax. 
+ + """ + sess = create_session() + q = sess.query(User) + + self.assert_sql(testing.db, lambda: q[10:20], [ + ("SELECT users.id AS users_id, users.name AS users_name FROM users LIMIT 10 OFFSET 10", {}) + ]) + + self.assert_sql(testing.db, lambda: q[:20], [ + ("SELECT users.id AS users_id, users.name AS users_name FROM users LIMIT 20 OFFSET 0", {}) + ]) + + self.assert_sql(testing.db, lambda: q[5:], [ + ("SELECT users.id AS users_id, users.name AS users_name FROM users LIMIT -1 OFFSET 5", {}) + ]) + + self.assert_sql(testing.db, lambda: q[2:2], []) + + self.assert_sql(testing.db, lambda: q[-2:-5], []) + + self.assert_sql(testing.db, lambda: q[-5:-2], [ + ("SELECT users.id AS users_id, users.name AS users_name FROM users", {}) + ]) + + self.assert_sql(testing.db, lambda: q[-5:], [ + ("SELECT users.id AS users_id, users.name AS users_name FROM users", {}) + ]) + + self.assert_sql(testing.db, lambda: q[:], [ + ("SELECT users.id AS users_id, users.name AS users_name FROM users", {}) + ]) + + + +class FilterTest(QueryTest): + def test_basic(self): + assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).all() + + @testing.fails_on('maxdb', 'FIXME: unknown') + def test_limit(self): + assert [User(id=8), User(id=9)] == create_session().query(User).order_by(User.id).limit(2).offset(1).all() + + assert [User(id=8), User(id=9)] == list(create_session().query(User).order_by(User.id)[1:3]) + + assert User(id=8) == create_session().query(User).order_by(User.id)[1] + + assert [] == create_session().query(User).order_by(User.id)[3:3] + assert [] == create_session().query(User).order_by(User.id)[0:0] + + + def test_one_filter(self): + assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all() + + def test_contains(self): + """test comparing a collection to an object instance.""" + + sess = create_session() + address = sess.query(Address).get(3) + assert [User(id=8)] == 
sess.query(User).filter(User.addresses.contains(address)).all() + + try: + sess.query(User).filter(User.addresses == address) + assert False + except sa_exc.InvalidRequestError: + assert True + + assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all() + + try: + assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() + assert False + except sa_exc.InvalidRequestError: + assert True + + #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() + + def test_any(self): + sess = create_session() + + assert [User(id=8), User(id=9)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).all() + + assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'), id=4)).all() + + assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).\ + filter(User.addresses.any(id=4)).all() + + assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all() + + # test that any() doesn't overcorrelate + assert [User(id=7), User(id=8)] == sess.query(User).join("addresses").filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all() + + # test that the contents are not adapted by the aliased join + assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all() + + assert [User(id=10)] == sess.query(User).outerjoin("addresses", aliased=True).filter(~User.addresses.any()).all() + + @testing.crashes('maxdb', 'can dump core') + def test_has(self): + sess = create_session() + assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all() + + assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all() + + assert 
[Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + + # test has() doesn't overcorrelate + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + + # test has() doesnt' get subquery contents adapted by aliased join + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + + dingaling = sess.query(Dingaling).get(2) + assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all() + + def test_contains_m2m(self): + sess = create_session() + item = sess.query(Item).get(3) + assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).all() + + assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all() + + item2 = sess.query(Item).get(5) + assert [Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).filter(Order.items.contains(item2)).all() + + + def test_comparison(self): + """test scalar comparison to an object instance""" + + sess = create_session() + user = sess.query(User).get(8) + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user==user).all() + + assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all() + + # generates an IS NULL + assert [] == sess.query(Address).filter(Address.user == None).all() + + assert [Order(id=5)] == sess.query(Order).filter(Order.address == None).all() + + # o2o + dingaling = sess.query(Dingaling).get(2) + assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all() + + # m2m + eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) + eq_(sess.query(Item).filter(Item.keywords!=None).all(), 
[Item(id=1),Item(id=2), Item(id=3)]) + + def test_filter_by(self): + sess = create_session() + user = sess.query(User).get(8) + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter_by(user=user).all() + + # many to one generates IS NULL + assert [] == sess.query(Address).filter_by(user = None).all() + + # one to many generates WHERE NOT EXISTS + assert [User(name='chuck')] == sess.query(User).filter_by(addresses = None).all() + + def test_none_comparison(self): + sess = create_session() + + # o2o + eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) + eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) + + # m2o + eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) + eq_([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all()) + + # o2m + eq_([User(id=10)], sess.query(User).filter(User.addresses==None).all()) + eq_([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all()) + + +class FromSelfTest(QueryTest, AssertsCompiledSQL): + def test_filter(self): + + assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().all() + + assert [User(id=8), User(id=9)] == create_session().query(User).order_by(User.id).slice(1,3)._from_self().all() + assert [User(id=8)] == list(create_session().query(User).filter(User.id.in_([8,9]))._from_self().order_by(User.id)[0:1]) + + def test_join(self): + assert [ + (User(id=8), Address(id=2)), + (User(id=8), Address(id=3)), + (User(id=8), Address(id=4)), + (User(id=9), Address(id=5)) + ] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().\ + join('addresses').add_entity(Address).order_by(User.id, Address.id).all() + + def test_group_by(self): + eq_( + create_session().query(Address.user_id, 
func.count(Address.id).label('count')).\ + group_by(Address.user_id).order_by(Address.user_id).all(), + [(7, 1), (8, 3), (9, 1)] + ) + + eq_( + create_session().query(Address.user_id, Address.id).\ + from_self(Address.user_id, func.count(Address.id)).\ + group_by(Address.user_id).order_by(Address.user_id).all(), + [(7, 1), (8, 3), (9, 1)] + ) + + def test_no_eagerload(self): + """test that eagerloads are pushed outwards and not rendered in subqueries.""" + + s = create_session() + + self.assert_compile( + s.query(User).options(eagerload(User.addresses)).from_self().statement, + "SELECT anon_1.users_id, anon_1.users_name, addresses_1.id, addresses_1.user_id, "\ + "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) AS anon_1 "\ + "LEFT OUTER JOIN addresses AS addresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" + ) + + def test_aliases(self): + """test that aliased objects are accessible externally to a from_self() call.""" + + s = create_session() + + ualias = aliased(User) + eq_( + s.query(User, ualias).filter(User.id > ualias.id).from_self(User.name, ualias.name). + order_by(User.name, ualias.name).all(), + [ + (u'chuck', u'ed'), + (u'chuck', u'fred'), + (u'chuck', u'jack'), + (u'ed', u'jack'), + (u'fred', u'ed'), + (u'fred', u'jack') + ] + ) + + eq_( + s.query(User, ualias).filter(User.id > ualias.id).from_self(User.name, ualias.name).filter(ualias.name=='ed')\ + .order_by(User.name, ualias.name).all(), + [(u'chuck', u'ed'), (u'fred', u'ed')] + ) + + eq_( + s.query(User, ualias).filter(User.id > ualias.id).from_self(ualias.name, Address.email_address). 
+ join(ualias.addresses).order_by(ualias.name, Address.email_address).all(), + [ + (u'ed', u'fred@fred.com'), + (u'jack', u'ed@bettyboop.com'), + (u'jack', u'ed@lala.com'), + (u'jack', u'ed@wood.com'), + (u'jack', u'fred@fred.com')] + ) + + + def test_multiple_entities(self): + sess = create_session() + + eq_( + sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(), + [ + (User(id=8), Address(id=2)), + (User(id=9), Address(id=5)) + ] + ) + + eq_( + sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(), + + # order_by(User.id, Address.id).first(), + (User(id=8, addresses=[Address(), Address(), Address()]), Address(id=2)), + ) + +class SetOpsTest(QueryTest, AssertsCompiledSQL): + + def test_union(self): + s = create_session() + + fred = s.query(User).filter(User.name=='fred') + ed = s.query(User).filter(User.name=='ed') + jack = s.query(User).filter(User.name=='jack') + + eq_(fred.union(ed).order_by(User.name).all(), + [User(name='ed'), User(name='fred')] + ) + + eq_(fred.union(ed, jack).order_by(User.name).all(), + [User(name='ed'), User(name='fred'), User(name='jack')] + ) + + @testing.fails_on('mysql', "mysql doesn't support intersect") + def test_intersect(self): + s = create_session() + + fred = s.query(User).filter(User.name=='fred') + ed = s.query(User).filter(User.name=='ed') + jack = s.query(User).filter(User.name=='jack') + eq_(fred.intersect(ed, jack).all(), + [] + ) + + eq_(fred.union(ed).intersect(ed.union(jack)).all(), + [User(name='ed')] + ) + + def test_eager_load(self): + s = create_session() + + fred = s.query(User).filter(User.name=='fred') + ed = s.query(User).filter(User.name=='ed') + jack = s.query(User).filter(User.name=='jack') + + def go(): + eq_( + fred.union(ed).order_by(User.name).options(eagerload(User.addresses)).all(), + [ + User(name='ed', addresses=[Address(), Address(), Address()]), + 
User(name='fred', addresses=[Address()]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + + +class AggregateTest(QueryTest): + + def test_sum(self): + sess = create_session() + orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) + eq_(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,)) + eq_(orders.value(func.sum(Order.user_id * Order.address_id)), 79) + + def test_apply(self): + sess = create_session() + assert sess.query(func.sum(Order.user_id * Order.address_id)).filter(Order.id.in_([2, 3, 4])).one() == (79,) + + def test_having(self): + sess = create_session() + assert [User(name=u'ed',id=8)] == sess.query(User).order_by(User.id).group_by(User).join('addresses').having(func.count(Address.id)> 2).all() + + assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).order_by(User.id).group_by(User).join('addresses').having(func.count(Address.id)< 2).all() + +class CountTest(QueryTest): + def test_basic(self): + s = create_session() + + eq_(s.query(User).count(), 4) + + eq_(s.query(User).filter(users.c.name.endswith('ed')).count(), 2) + + def test_multiple_entity(self): + s = create_session() + q = s.query(User, Address) + eq_(q.count(), 20) # cartesian product + + q = s.query(User, Address).join(User.addresses) + eq_(q.count(), 5) + + def test_nested(self): + s = create_session() + q = s.query(User, Address).limit(2) + eq_(q.count(), 2) + + q = s.query(User, Address).limit(100) + eq_(q.count(), 20) + + q = s.query(User, Address).join(User.addresses).limit(100) + eq_(q.count(), 5) + + def test_cols(self): + """test that column-based queries always nest.""" + + s = create_session() + + q = s.query(func.count(distinct(User.name))) + eq_(q.count(), 1) + + q = s.query(func.count(distinct(User.name))).distinct() + eq_(q.count(), 1) + + q = s.query(User.name) + eq_(q.count(), 4) + + q = s.query(User.name, Address) + eq_(q.count(), 20) + + q = s.query(Address.user_id) + eq_(q.count(), 5) + eq_(q.distinct().count(), 3) + + 
+class DistinctTest(QueryTest): + def test_basic(self): + assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all() + assert [User(id=7), User(id=9), User(id=8),User(id=10)] == create_session().query(User).distinct().order_by(desc(User.name)).all() + + def test_joined(self): + """test that orderbys from a joined table get placed into the columns clause when DISTINCT is used""" + + sess = create_session() + q = sess.query(User).join('addresses').distinct().order_by(desc(Address.email_address)) + + assert [User(id=7), User(id=9), User(id=8)] == q.all() + + sess.expunge_all() + + # test that it works on embedded eagerload/LIMIT subquery + q = sess.query(User).join('addresses').distinct().options(eagerload('addresses')).order_by(desc(Address.email_address)).limit(2) + + def go(): + assert [ + User(id=7, addresses=[ + Address(id=1) + ]), + User(id=9, addresses=[ + Address(id=5) + ]), + ] == q.all() + self.assert_sql_count(testing.db, go, 1) + + +class YieldTest(QueryTest): + def test_basic(self): + import gc + sess = create_session() + q = iter(sess.query(User).yield_per(1).from_statement("select * from users")) + + ret = [] + eq_(len(sess.identity_map), 0) + ret.append(q.next()) + ret.append(q.next()) + eq_(len(sess.identity_map), 2) + ret.append(q.next()) + ret.append(q.next()) + eq_(len(sess.identity_map), 4) + try: + q.next() + assert False + except StopIteration: + pass + +class TextTest(QueryTest): + def test_fulltext(self): + assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).from_statement("select * from users order by id").all() + + assert User(id=7) == create_session().query(User).from_statement("select * from users order by id").first() + assert None == create_session().query(User).from_statement("select * from users where name='nonexistent'").first() + + def test_fragment(self): + assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (8, 9)").all() + + 
assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all() + + assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all() + + def test_binds(self): + assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() + + def test_as_column(self): + s = create_session() + assert_raises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name")) + + eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')]) + +class ParentTest(QueryTest): + def test_o2m(self): + sess = create_session() + q = sess.query(User) + + u1 = q.filter_by(name='jack').one() + + # test auto-lookup of property + o = sess.query(Order).with_parent(u1).all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + + # test with explicit property + o = sess.query(Order).with_parent(u1, property='orders').all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + + o = sess.query(Order).filter(with_parent(u1, User.orders)).all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + + # test static method + @testing.uses_deprecated(".*Use sqlalchemy.orm.with_parent") + def go(): + o = Query.query_from_parent(u1, property='orders', session=sess).all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + go() + + # test generative criterion + o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all() + assert [Order(description="order 3"), Order(description="order 5")] == o + + # test against None for parent? 
this can't be done with the current API since we don't know + # what mapper to use + #assert sess.query(Order).with_parent(None, property='addresses').all() == [Order(description="order 5")] + + def test_noparent(self): + sess = create_session() + q = sess.query(User) + + u1 = q.filter_by(name='jack').one() + + try: + q = sess.query(Item).with_parent(u1) + assert False + except sa_exc.InvalidRequestError, e: + assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'" + + def test_m2m(self): + sess = create_session() + i1 = sess.query(Item).filter_by(id=2).one() + k = sess.query(Keyword).with_parent(i1).all() + assert [Keyword(name='red'), Keyword(name='small'), Keyword(name='square')] == k + + +class JoinTest(QueryTest): + + def test_overlapping_paths(self): + for aliased in (True,False): + # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack) + result = create_session().query(User).join(['orders', 'items'], aliased=aliased).filter_by(id=3).join(['orders','address'], aliased=aliased).filter_by(id=1).all() + assert [User(id=7, name='jack')] == result + + def test_overlapping_paths_outerjoin(self): + result = create_session().query(User).outerjoin(['orders', 'items']).filter_by(id=3).outerjoin(['orders','address']).filter_by(id=1).all() + assert [User(id=7, name='jack')] == result + + def test_from_joinpoint(self): + sess = create_session() + + for oalias,ialias in [(True, True), (False, False), (True, False), (False, True)]: + eq_( + sess.query(User).join('orders', aliased=oalias).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description == 'item 4').all(), + [User(name='jack')] + ) + + # use middle criterion + eq_( + sess.query(User).join('orders', aliased=oalias).filter(Order.user_id==9).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description=='item 4').all(), + [] + ) + + orderalias = aliased(Order) + itemalias = aliased(Item) + 
eq_( + sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(itemalias.description == 'item 4').all(), + [User(name='jack')] + ) + eq_( + sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(orderalias.user_id==9).filter(itemalias.description=='item 4').all(), + [] + ) + + def test_backwards_join(self): + # a more controversial feature. join from + # User->Address, but the onclause is Address.user. + + sess = create_session() + + eq_( + sess.query(User).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), + [User(id=8,name=u'ed')] + ) + + # its actually not so controversial if you view it in terms + # of multiple entities. + eq_( + sess.query(User, Address).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), + [(User(id=8,name=u'ed'), Address(email_address='ed@wood.com'))] + ) + + # this was the controversial part. now, raise an error if the feature is abused. + # before the error raise was added, this would silently work..... 
+ assert_raises( + sa_exc.InvalidRequestError, + sess.query(User).join, (Address, Address.user), + ) + + # but this one would silently fail + adalias = aliased(Address) + assert_raises( + sa_exc.InvalidRequestError, + sess.query(User).join, (adalias, Address.user), + ) + + def test_multiple_with_aliases(self): + sess = create_session() + + ualias = aliased(User) + oalias1 = aliased(Order) + oalias2 = aliased(Order) + result = sess.query(ualias).join((oalias1, ualias.orders), (oalias2, ualias.orders)).\ + filter(or_(oalias1.user_id==9, oalias2.user_id==7)).all() + eq_(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')]) + + def test_orderby_arg_bug(self): + sess = create_session() + # no arg error + result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all() + + def test_no_onclause(self): + sess = create_session() + + eq_( + sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), + [User(name='jack')] + ) + + eq_( + sess.query(User.name).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), + [('jack',)] + ) + + eq_( + sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(), + [User(name='jack')] + ) + + def test_clause_onclause(self): + sess = create_session() + + eq_( + sess.query(User).join( + (Order, User.id==Order.user_id), + (order_items, Order.id==order_items.c.order_id), + (Item, order_items.c.item_id==Item.id) + ).filter(Item.description == 'item 4').all(), + [User(name='jack')] + ) + + eq_( + sess.query(User.name).join( + (Order, User.id==Order.user_id), + (order_items, Order.id==order_items.c.order_id), + (Item, order_items.c.item_id==Item.id) + ).filter(Item.description == 'item 4').all(), + [('jack',)] + ) + + ualias = aliased(User) + eq_( + sess.query(ualias.name).join( + (Order, ualias.id==Order.user_id), + (order_items, 
Order.id==order_items.c.order_id), + (Item, order_items.c.item_id==Item.id) + ).filter(Item.description == 'item 4').all(), + [('jack',)] + ) + + # explicit onclause with from_self(), means + # the onclause must be aliased against the query's custom + # FROM object + eq_( + sess.query(User).order_by(User.id).offset(2).from_self().join( + (Order, User.id==Order.user_id) + ).all(), + [User(name='fred')] + ) + + # same with an explicit select_from() + eq_( + sess.query(User).select_from(select([users]).order_by(User.id).offset(2).alias()).join( + (Order, User.id==Order.user_id) + ).all(), + [User(name='fred')] + ) + + + def test_aliased_classes(self): + sess = create_session() + + (user7, user8, user9, user10) = sess.query(User).all() + (address1, address2, address3, address4, address5) = sess.query(Address).all() + expected = [(user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None)] + + q = sess.query(User) + AdAlias = aliased(Address) + q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias)) + l = q.order_by(User.id, AdAlias.id).all() + eq_(l, expected) + + sess.expunge_all() + + q = sess.query(User).add_entity(AdAlias) + l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all() + eq_(l, [(user8, address3)]) + + l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all() + eq_(l, [(user8, address3)]) + + l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all() + eq_(l, [(user8, address3)]) + + # this is the first test where we are joining "backwards" - from AdAlias to User even though + # the query is against User + q = sess.query(User, AdAlias) + l = q.join(AdAlias.user).filter(User.name=='ed') + eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) + + q = sess.query(User, AdAlias).select_from(join(AdAlias, User, 
AdAlias.user)).filter(User.name=='ed') + eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) + + def test_implicit_joins_from_aliases(self): + sess = create_session() + OrderAlias = aliased(Order) + + eq_( + sess.query(OrderAlias).join('items').filter_by(description='item 3').\ + order_by(OrderAlias.id).all(), + [ + Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), + Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), + Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3) + ] + ) + + eq_( + sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').\ + order_by(User.id, OrderAlias.id).all(), + [ + (User(name=u'jack',id=7), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), u'item 3'), + (User(name=u'jack',id=7), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), u'item 3'), + (User(name=u'fred',id=9), Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), u'item 3') + ] + ) + + def test_aliased_classes_m2m(self): + sess = create_session() + + (order1, order2, order3, order4, order5) = sess.query(Order).all() + (item1, item2, item3, item4, item5) = sess.query(Item).all() + expected = [ + (order1, item1), + (order1, item2), + (order1, item3), + (order2, item1), + (order2, item2), + (order2, item3), + (order3, item3), + (order3, item4), + (order3, item5), + (order4, item1), + (order4, item5), + (order5, item5), + ] + + q = sess.query(Order) + q = q.add_entity(Item).select_from(join(Order, Item, 'items')).order_by(Order.id, Item.id) + l = q.all() + eq_(l, expected) + + IAlias = aliased(Item) + q = sess.query(Order, IAlias).select_from(join(Order, IAlias, 'items')).filter(IAlias.description=='item 3') + l = q.all() + eq_(l, + [ + (order1, item3), + (order2, item3), + (order3, item3), + ] + ) + + def test_reset_joinpoint(self): + for aliased in (True, False): + # load a user who has an order that 
contains item id 3 and address id 1 (order 3, owned by jack) + result = create_session().query(User).join(['orders', 'items'], aliased=aliased).filter_by(id=3).reset_joinpoint().join(['orders','address'], aliased=aliased).filter_by(id=1).all() + assert [User(id=7, name='jack')] == result + + result = create_session().query(User).outerjoin(['orders', 'items'], aliased=aliased).filter_by(id=3).reset_joinpoint().outerjoin(['orders','address'], aliased=aliased).filter_by(id=1).all() + assert [User(id=7, name='jack')] == result + + def test_overlap_with_aliases(self): + oalias = orders.alias('oalias') + + result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_(["order 1", "order 2", "order 3"])).join(['orders', 'items']).order_by(User.id).all() + assert [User(id=7, name='jack'), User(id=9, name='fred')] == result + + result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_(["order 1", "order 2", "order 3"])).join(['orders', 'items']).filter_by(id=4).all() + assert [User(id=7, name='jack')] == result + + def test_aliased(self): + """test automatic generation of aliased joins.""" + + sess = create_session() + + # test a basic aliasized path + q = sess.query(User).join('addresses', aliased=True).filter_by(email_address='jack@bean.com') + assert [User(id=7)] == q.all() + + q = sess.query(User).join('addresses', aliased=True).filter(Address.email_address=='jack@bean.com') + assert [User(id=7)] == q.all() + + q = sess.query(User).join('addresses', aliased=True).filter(or_(Address.email_address=='jack@bean.com', Address.email_address=='fred@fred.com')) + assert [User(id=7), User(id=9)] == q.all() + + # test two aliasized paths, one to 'orders' and the other to 'orders','items'. + # one row is returned because user 7 has order 3 and also has order 1 which has item 1 + # this tests a o2m join and a m2m join. 
+ q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join(['orders', 'items'], aliased=True).filter(Item.description=="item 1") + assert q.count() == 1 + assert [User(id=7)] == q.all() + + # test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1 + q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Item.description=="item 1") + assert [] == q.all() + assert q.count() == 0 + + # the left half of the join condition of the any() is aliased. + q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4')) + assert [User(id=7)] == q.all() + + # test that aliasing gets reset when join() is called + q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=="order 5") + assert q.count() == 1 + assert [User(id=7)] == q.all() + + def test_aliased_order_by(self): + sess = create_session() + + ualias = aliased(User) + eq_( + sess.query(User, ualias).filter(User.id > ualias.id).order_by(desc(ualias.id), User.name).all(), + [ + (User(id=10,name=u'chuck'), User(id=9,name=u'fred')), + (User(id=10,name=u'chuck'), User(id=8,name=u'ed')), + (User(id=9,name=u'fred'), User(id=8,name=u'ed')), + (User(id=10,name=u'chuck'), User(id=7,name=u'jack')), + (User(id=8,name=u'ed'), User(id=7,name=u'jack')), + (User(id=9,name=u'fred'), User(id=7,name=u'jack')) + ] + ) + + def test_plain_table(self): + + sess = create_session() + + eq_( + sess.query(User.name).join((addresses, User.id==addresses.c.user_id)).order_by(User.id).all(), + [(u'jack',), (u'ed',), (u'ed',), (u'ed',), (u'fred',)] + ) + + +class MultiplePathTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global t1, t2, t1t2_1, t1t2_2 + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)) + ) + t2 = 
Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)) + ) + + t1t2_1 = Table('t1t2_1', metadata, + Column('t1id', Integer, ForeignKey('t1.id')), + Column('t2id', Integer, ForeignKey('t2.id')) + ) + + t1t2_2 = Table('t1t2_2', metadata, + Column('t1id', Integer, ForeignKey('t1.id')), + Column('t2id', Integer, ForeignKey('t2.id')) + ) + + def test_basic(self): + class T1(object):pass + class T2(object):pass + + mapper(T1, t1, properties={ + 't2s_1':relation(T2, secondary=t1t2_1), + 't2s_2':relation(T2, secondary=t1t2_2), + }) + mapper(T2, t2) + + q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint() + assert_raises_message(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.", + q.join, 't2s_2' + ) + + create_session().query(T1).join('t2s_1', aliased=True).filter(t2.c.id==5).reset_joinpoint().join('t2s_2').all() + create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2', aliased=True).all() + +class SynonymTest(QueryTest): + + @classmethod + def setup_mappers(cls): + mapper(User, users, properties={ + 'name_syn':synonym('name'), + 'addresses':relation(Address), + 'orders':relation(Order, backref='user'), # o2m, m2o + 'orders_syn':synonym('orders') + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items), #m2m + 'address':relation(Address), # m2o + 'items_syn':synonym('items') + }) + mapper(Item, items, properties={ + 'keywords':relation(Keyword, secondary=item_keywords) #m2m + }) + mapper(Keyword, keywords) + + def test_joins(self): + for j in ( + ['orders', 'items'], + ['orders_syn', 'items'], + ['orders', 'items_syn'], + ['orders_syn', 'items_syn'], + ): + result = create_session().query(User).join(j).filter_by(id=3).all() + assert [User(id=7, name='jack'), User(id=9, name='fred')] == result + + def test_with_parent(self): + for nameprop, orderprop in ( + 
('name', 'orders'), + ('name_syn', 'orders'), + ('name', 'orders_syn'), + ('name_syn', 'orders_syn'), + ): + sess = create_session() + q = sess.query(User) + + u1 = q.filter_by(**{nameprop:'jack'}).one() + + o = sess.query(Order).with_parent(u1, property=orderprop).all() + assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o + +class InstancesTest(QueryTest, AssertsCompiledSQL): + + def test_from_alias(self): + + query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.id', addresses.c.id]) + sess =create_session() + q = sess.query(User) + + def go(): + l = list(q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())) + assert self.static.user_address_result == l + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + + def go(): + l = q.options(contains_alias('ulist'), contains_eager('addresses')).from_statement(query).all() + assert self.static.user_address_result == l + self.assert_sql_count(testing.db, go, 1) + + # better way. 
use select_from() + def go(): + l = sess.query(User).select_from(query).options(contains_eager('addresses')).all() + assert self.static.user_address_result == l + self.assert_sql_count(testing.db, go, 1) + + # same thing, but alias addresses, so that the adapter generated by select_from() is wrapped within + # the adapter created by contains_eager() + adalias = addresses.alias() + query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(adalias).select(use_labels=True,order_by=['ulist.id', adalias.c.id]) + def go(): + l = sess.query(User).select_from(query).options(contains_eager('addresses', alias=adalias)).all() + assert self.static.user_address_result == l + self.assert_sql_count(testing.db, go, 1) + + def test_contains_eager(self): + sess = create_session() + + # test that contains_eager suppresses the normal outer join rendering + q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses)).order_by(User.id) + self.assert_compile(q.with_labels().statement, + "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\ + "addresses.email_address AS addresses_email_address, users.id AS users_id, "\ + "users.name AS users_name FROM users LEFT OUTER JOIN addresses "\ + "ON users.id = addresses.user_id ORDER BY users.id" + , dialect=default.DefaultDialect()) + + def go(): + assert self.static.user_address_result == q.all() + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + adalias = addresses.alias() + q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias)) + def go(): + eq_(self.static.user_address_result, q.order_by(User.id).all()) + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id]) + q = sess.query(User) + + def go(): + l = 
list(q.options(contains_eager('addresses')).instances(selectquery.execute())) + assert self.static.user_address_result[0:3] == l + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + + def go(): + l = list(q.options(contains_eager(User.addresses)).instances(selectquery.execute())) + assert self.static.user_address_result[0:3] == l + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + def go(): + l = q.options(contains_eager('addresses')).from_statement(selectquery).all() + assert self.static.user_address_result[0:3] == l + self.assert_sql_count(testing.db, go, 1) + + def test_contains_eager_alias(self): + adalias = addresses.alias('adalias') + selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id]) + sess = create_session() + q = sess.query(User) + + # string alias name + def go(): + l = list(q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())) + assert self.static.user_address_result == l + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + # expression.Alias object + def go(): + l = list(q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())) + assert self.static.user_address_result == l + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + + # Aliased object + adalias = aliased(Address) + def go(): + l = q.options(contains_eager('addresses', alias=adalias)).outerjoin((adalias, User.addresses)).order_by(User.id, adalias.id) + assert self.static.user_address_result == l.all() + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + oalias = orders.alias('o1') + ialias = items.alias('i1') + query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id, oalias.c.id, ialias.c.id) + q = create_session().query(User) + # test using string alias with more than one level deep + def go(): + l = list(q.options(contains_eager('orders', 
alias='o1'), contains_eager('orders.items', alias='i1')).instances(query.execute())) + assert self.static.user_order_result == l + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + + # test using Alias with more than one level deep + def go(): + l = list(q.options(contains_eager('orders', alias=oalias), contains_eager('orders.items', alias=ialias)).instances(query.execute())) + assert self.static.user_order_result == l + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + # test using Aliased with more than one level deep + oalias = aliased(Order) + ialias = aliased(Item) + def go(): + l = q.options(contains_eager(User.orders, alias=oalias), contains_eager(User.orders, Order.items, alias=ialias)).\ + outerjoin((oalias, User.orders), (ialias, oalias.items)).order_by(User.id, oalias.id, ialias.id) + assert self.static.user_order_result == l.all() + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + def test_mixed_eager_contains_with_limit(self): + sess = create_session() + + q = sess.query(User) + def go(): + # outerjoin to User.orders, offset 1/limit 2 so we get user 7 + second two orders. + # then eagerload the addresses. User + Order columns go into the subquery, address + # left outer joins to the subquery, eagerloader for User.orders applies context.adapter + # to result rows. This was [ticket:1180]. 
+ l = q.outerjoin(User.orders).options(eagerload(User.addresses), contains_eager(User.orders)).order_by(User.id, Order.id).offset(1).limit(2).all() + eq_(l, [User(id=7, + addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)], + name=u'jack', + orders=[ + Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3), + Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5) + ])]) + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + def go(): + # same as above, except Order is aliased, so two adapters are applied by the + # eager loader + oalias = aliased(Order) + l = q.outerjoin((User.orders, oalias)).options(eagerload(User.addresses), contains_eager(User.orders, alias=oalias)).order_by(User.id, oalias.id).offset(1).limit(2).all() + eq_(l, [User(id=7, + addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)], + name=u'jack', + orders=[ + Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3), + Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5) + ])]) + self.assert_sql_count(testing.db, go, 1) + + +class MixedEntitiesTest(QueryTest): + + def test_values(self): + sess = create_session() + + assert list(sess.query(User).values()) == list() + + sel = users.select(User.id.in_([7, 8])).alias() + q = sess.query(User) + q2 = q.select_from(sel).values(User.name) + eq_(list(q2), [(u'jack',), (u'ed',)]) + + q = sess.query(User) + q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String)) + eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) + + q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + + q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 
3).values(User.name, Address.email_address) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) + + adalias = aliased(Address) + q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + + q2 = q.values(func.count(User.name)) + assert q2.next() == (4,) + + q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name) + eq_(list(q2), [(u'ed', u'ed', u'ed')]) + + # using User.xxx is alised against "sel", so this query returns nothing + q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name) + eq_(list(q2), []) + + # whereas this uses users.c.xxx, is not aliased and creates a new join + q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name) + eq_(list(q2), [(u'ed', u'jack', u'jack')]) + + @testing.fails_on('mssql', 'FIXME: unknown') + def test_values_specific_order_by(self): + sess = create_session() + + assert list(sess.query(User).values()) == list() + + sel = users.select(User.id.in_([7, 8])).alias() + q = sess.query(User) + u2 = aliased(User) + q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name) + eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) + + @testing.fails_on('mssql', 'FIXME: unknown') + def test_values_with_boolean_selects(self): + """Tests a values clause that works with select boolean evaluations""" + sess = create_session() + + q = sess.query(User) + q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), 
func.count(User.name.like('%j%'))) + eq_(list(q2), [(True, 1), (False, 3)]) + + def test_correlated_subquery(self): + """test that a subquery constructed from ORM attributes doesn't leak out + those entities to the outermost query. + + """ + sess = create_session() + + subq = select([func.count()]).\ + where(User.id==Address.user_id).\ + correlate(users).\ + label('count') + + # we don't want Address to be outside of the subquery here + eq_( + list(sess.query(User, subq)[0:3]), + [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] + ) + + # same thing without the correlate, as it should + # not be needed + subq = select([func.count()]).\ + where(User.id==Address.user_id).\ + label('count') + + # we don't want Address to be outside of the subquery here + eq_( + list(sess.query(User, subq)[0:3]), + [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] + ) + + def test_tuple_labeling(self): + sess = create_session() + for row in sess.query(User, Address).join(User.addresses).all(): + eq_(set(row.keys()), set(['User', 'Address'])) + eq_(row.User, row[0]) + eq_(row.Address, row[1]) + + for row in sess.query(User.name, User.id.label('foobar')): + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) + + for row in sess.query(User).values(User.name, User.id.label('foobar')): + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) + + oalias = aliased(Order) + for row in sess.query(User, oalias).join(User.orders).all(): + eq_(set(row.keys()), set(['User'])) + eq_(row.User, row[0]) + + oalias = aliased(Order, name='orders') + for row in sess.query(User, oalias).join(User.orders).all(): + eq_(set(row.keys()), set(['User', 'orders'])) + eq_(row.User, row[0]) + eq_(row.orders, row[1]) + + + def test_column_queries(self): + sess = create_session() + + eq_(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), 
(u'chuck',)]) + + sel = users.select(User.id.in_([7, 8])).alias() + q = sess.query(User.name) + q2 = q.select_from(sel).all() + eq_(list(q2), [(u'jack',), (u'ed',)]) + + eq_(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [ + (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'), + (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), + (u'fred', u'fred@fred.com') + ]) + + eq_(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), + [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)] + ) + + eq_(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), + [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] + ) + + eq_(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), + [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))] + ) + + adalias = aliased(Address) + eq_(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), + [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] + ) + + eq_(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(), + [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))] + ) + + # select from aliasing + explicit aliasing + eq_( + sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(), + [ + (User(name=u'jack',id=7), u'jack@bean.com'), + (User(name=u'ed',id=8), u'ed@wood.com'), + 
(User(name=u'ed',id=8), u'ed@bettyboop.com'), + (User(name=u'ed',id=8), u'ed@lala.com'), + (User(name=u'fred',id=9), u'fred@fred.com'), + (User(name=u'chuck',id=10), None) + ] + ) + + # anon + select from aliasing + eq_( + sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(), + [ + User(name=u'ed',id=8), + User(name=u'fred',id=9), + ] + ) + + # test eager aliasing, with/without select_from aliasing + for q in [ + sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), + sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), + ]: + eq_( + + q.all(), + [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'), + (User(addresses=[ + Address(user_id=8,email_address=u'ed@wood.com',id=2), + Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), + Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@wood.com'), + (User(addresses=[ + Address(user_id=8,email_address=u'ed@wood.com',id=2), + Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), + Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@bettyboop.com'), + (User(addresses=[ + Address(user_id=8,email_address=u'ed@wood.com',id=2), + Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), + Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@lala.com'), + (User(addresses=[Address(user_id=9,email_address=u'fred@fred.com',id=5)],name=u'fred',id=9), u'fred@fred.com'), + + (User(addresses=[],name=u'chuck',id=10), None)] + ) + + def test_column_from_limited_eagerload(self): + sess = create_session() + + def go(): + results = 
sess.query(User).limit(1).options(eagerload('addresses')).add_column(User.name).all() + eq_(results, [(User(name='jack'), 'jack')]) + self.assert_sql_count(testing.db, go, 1) + + def test_self_referential(self): + + sess = create_session() + oalias = aliased(Order) + + for q in [ + sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id), + sess.query(Order, oalias)._from_self().filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id), + + # same thing, but reversed. + sess.query(oalias, Order)._from_self().filter(oalias.user_id==Order.user_id).filter(oalias.user_id==7).filter(Order.id<oalias.id).order_by(oalias.id, Order.id), + + # here we go....two layers of aliasing + sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)), + + # gratuitous four layers + sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self()._from_self()._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)), + + ]: + + eq_( + q.all(), + [ + (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), + (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), + (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3)) + ] + ) + + def test_multi_mappers(self): + + test_session = create_session() + + (user7, user8, user9, user10) = test_session.query(User).all() + (address1, address2, address3, address4, address5) = test_session.query(Address).all() + + expected = [(user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None)] + + sess = create_session() + + selectquery = 
users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id]) + eq_(list(sess.query(User, Address).instances(selectquery.execute())), expected) + sess.expunge_all() + + for address_entity in (Address, aliased(Address)): + q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id) + eq_(q.all(), expected) + sess.expunge_all() + + q = sess.query(User).add_entity(address_entity) + q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') + eq_(q.all(), [(user8, address3)]) + sess.expunge_all() + + q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') + eq_(q.all(), [(user8, address3)]) + sess.expunge_all() + + q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com') + eq_(list(util.OrderedSet(q.all())), [(user8, address3)]) + sess.expunge_all() + + def test_aliased_multi_mappers(self): + sess = create_session() + + (user7, user8, user9, user10) = sess.query(User).all() + (address1, address2, address3, address4, address5) = sess.query(Address).all() + + expected = [(user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None)] + + q = sess.query(User) + adalias = addresses.alias('adalias') + q = q.add_entity(Address, alias=adalias).select_from(users.outerjoin(adalias)) + l = q.order_by(User.id, adalias.c.id).all() + assert l == expected + + sess.expunge_all() + + q = sess.query(User).add_entity(Address, alias=adalias) + l = q.select_from(users.outerjoin(adalias)).filter(adalias.c.email_address=='ed@bettyboop.com').all() + assert l == [(user8, address3)] + + def test_multi_columns(self): + sess = create_session() + + expected = [(u, u.name) for u in sess.query(User).all()] + + for add_col in (User.name, users.c.name): + assert 
sess.query(User).add_column(add_col).all() == expected + sess.expunge_all() + + assert_raises(sa_exc.InvalidRequestError, sess.query(User).add_column, object()) + + def test_add_multi_columns(self): + """test that add_column accepts a FROM clause.""" + + sess = create_session() + + eq_( + sess.query(User.id).add_column(users).all(), + [(7, 7, u'jack'), (8, 8, u'ed'), (9, 9, u'fred'), (10, 10, u'chuck')] + ) + + def test_multi_columns_2(self): + """test aliased/nonalised joins with the usage of add_column()""" + sess = create_session() + + (user7, user8, user9, user10) = sess.query(User).all() + expected = [(user7, 1), + (user8, 3), + (user9, 1), + (user10, 0) + ] + + q = sess.query(User) + q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count')) + eq_(q.all(), expected) + sess.expunge_all() + + adalias = aliased(Address) + q = sess.query(User) + q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count')) + eq_(q.all(), expected) + sess.expunge_all() + + s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id) + q = sess.query(User) + l = q.add_column("count").from_statement(s).all() + assert l == expected + + + def test_raw_columns(self): + sess = create_session() + (user7, user8, user9, user10) = sess.query(User).all() + expected = [ + (user7, 1, "Name:jack"), + (user8, 3, "Name:ed"), + (user9, 1, "Name:fred"), + (user10, 0, "Name:chuck")] + + adalias = addresses.alias() + q = create_session().query(User).add_column(func.count(adalias.c.id))\ + .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\ + .group_by([c for c in users.c]).order_by(users.c.id) + + assert q.all() == expected + + # test with a straight statement + s = select([users, func.count(addresses.c.id).label('count'), ("Name:" + 
users.c.name).label('concat')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.id]) + q = create_session().query(User) + l = q.add_column("count").add_column("concat").from_statement(s).all() + assert l == expected + + sess.expunge_all() + + # test with select_from() + q = create_session().query(User).add_column(func.count(addresses.c.id))\ + .add_column(("Name:" + users.c.name)).select_from(users.outerjoin(addresses))\ + .group_by([c for c in users.c]).order_by(users.c.id) + + assert q.all() == expected + sess.expunge_all() + + q = create_session().query(User).add_column(func.count(addresses.c.id))\ + .add_column(("Name:" + users.c.name)).outerjoin('addresses')\ + .group_by([c for c in users.c]).order_by(users.c.id) + + assert q.all() == expected + sess.expunge_all() + + q = create_session().query(User).add_column(func.count(adalias.c.id))\ + .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\ + .group_by([c for c in users.c]).order_by(users.c.id) + + assert q.all() == expected + sess.expunge_all() + +class ImmediateTest(_fixtures.FixtureTest): + run_inserts = 'once' + run_deletes = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Address, addresses) + + mapper(User, users, properties=dict( + addresses=relation(Address))) + + @testing.resolve_artifact_names + def test_one(self): + sess = create_session() + + assert_raises(sa.orm.exc.NoResultFound, + sess.query(User).filter(User.id == 99).one) + + eq_(sess.query(User).filter(User.id == 7).one().id, 7) + + assert_raises(sa.orm.exc.MultipleResultsFound, + sess.query(User).one) + + assert_raises( + sa.orm.exc.NoResultFound, + sess.query(User.id, User.name).filter(User.id == 99).one) + + eq_(sess.query(User.id, User.name).filter(User.id == 7).one(), + (7, 'jack')) + + assert_raises(sa.orm.exc.MultipleResultsFound, + sess.query(User.id, User.name).one) + + assert_raises(sa.orm.exc.NoResultFound, + (sess.query(User, 
Address). + join(User.addresses). + filter(Address.id == 99)).one) + + eq_((sess.query(User, Address). + join(User.addresses). + filter(Address.id == 4)).one(), + (User(id=8), Address(id=4))) + + assert_raises(sa.orm.exc.MultipleResultsFound, + sess.query(User, Address).join(User.addresses).one) + + @testing.future + def test_getslice(self): + assert False + + @testing.resolve_artifact_names + def test_scalar(self): + sess = create_session() + + eq_(sess.query(User.id).filter_by(id=7).scalar(), 7) + eq_(sess.query(User.id, User.name).filter_by(id=7).scalar(), 7) + eq_(sess.query(User.id).filter_by(id=0).scalar(), None) + eq_(sess.query(User).filter_by(id=7).scalar(), + sess.query(User).filter_by(id=7).one()) + + @testing.resolve_artifact_names + def test_value(self): + sess = create_session() + + eq_(sess.query(User).filter_by(id=7).value(User.id), 7) + eq_(sess.query(User.id, User.name).filter_by(id=7).value(User.id), 7) + eq_(sess.query(User).filter_by(id=0).value(User.id), None) + + sess.bind = testing.db + eq_(sess.query().value(sa.literal_column('1').label('x')), 1) + + +class SelectFromTest(QueryTest): + run_setup_mappers = None + + def test_replace_with_select(self): + mapper(User, users, properties = { + 'addresses':relation(Address) + }) + mapper(Address, addresses) + + sel = users.select(users.c.id.in_([7, 8])).alias() + sess = create_session() + + eq_(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)]) + + eq_(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)]) + + eq_(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [ + User(name='jack',id=7), User(name='ed',id=8) + ]) + + eq_(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [ + User(name='ed',id=8), User(name='jack',id=7) + ]) + + eq_(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), + User(name='jack', addresses=[Address(id=1)]) + ) + + def test_join_mapper_order_by(self): + """test that mapper-level 
order_by is adapted to a selectable.""" + + mapper(User, users, order_by=users.c.id) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + eq_(sess.query(User).select_from(sel).all(), + [ + User(name='jack',id=7), User(name='ed',id=8) + ] + ) + + def test_join_no_order_by(self): + mapper(User, users) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + eq_(sess.query(User).select_from(sel).all(), + [ + User(name='jack',id=7), User(name='ed',id=8) + ] + ) + + def test_join(self): + mapper(User, users, properties = { + 'addresses':relation(Address) + }) + mapper(Address, addresses) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + eq_(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(), + [ + (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@bettyboop.com',id=3)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@lala.com',id=4)) + ] + ) + + adalias = aliased(Address) + eq_(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(), + [ + (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@bettyboop.com',id=3)), + (User(name='ed',id=8), Address(user_id=8,email_address='ed@lala.com',id=4)) + ] + ) + + + def test_more_joins(self): + mapper(User, users, properties={ + 'orders':relation(Order, backref='user'), # o2m, m2o + }) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, order_by=items.c.id), #m2m + }) + mapper(Item, items, properties={ + 'keywords':relation(Keyword, 
secondary=item_keywords, order_by=keywords.c.id) #m2m + }) + mapper(Keyword, keywords) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + # TODO: remove + sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all() + + eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + User(name=u'jack',id=7) + ]) + + eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + User(name=u'jack',id=7) + ]) + + def go(): + eq_(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + User(name=u'jack',orders=[ + Order(description=u'order 1',items=[ + Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]), + Item(description=u'item 2',keywords=[Keyword(name=u'red',id=2), Keyword(name=u'small',id=5), Keyword(name=u'square')]), + Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]) + ]), + Order(description=u'order 3',items=[ + Item(description=u'item 3',keywords=[Keyword(name=u'green',id=3), Keyword(name=u'big',id=4), Keyword(name=u'round',id=6)]), + Item(description=u'item 4',keywords=[],id=4), + Item(description=u'item 5',keywords=[],id=5) + ]), + Order(description=u'order 5',items=[Item(description=u'item 5',keywords=[])])]) + ]) + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + sel2 = orders.select(orders.c.id.in_([1,2,3])) + eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [ + Order(description=u'order 1',id=1), + 
Order(description=u'order 2',id=2), + ]) + eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [ + Order(description=u'order 1',id=1), + Order(description=u'order 2',id=2), + ]) + + + def test_replace_with_eager(self): + mapper(User, users, properties = { + 'addresses':relation(Address, order_by=addresses.c.id) + }) + mapper(Address, addresses) + + sel = users.select(users.c.id.in_([7, 8])) + sess = create_session() + + def go(): + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(), + [ + User(id=7, addresses=[Address(id=1)]), + User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + def go(): + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(), + [User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])] + ) + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + def go(): + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) + self.assert_sql_count(testing.db, go, 1) + +class CustomJoinTest(QueryTest): + run_setup_mappers = None + + def test_double_same_mappers(self): + """test aliasing of joins with a custom join condition""" + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=True, order_by=items.c.id), + }) + mapper(Item, items) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=True), + open_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 1, users.c.id==orders.c.user_id), lazy=True), + closed_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 0, users.c.id==orders.c.user_id), lazy=True) + )) + q = create_session().query(User) + + assert 
[User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all() + +class SelfReferentialTest(_base.MappedTest): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + global nodes + nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('parent_id', Integer, ForeignKey('nodes.id')), + Column('data', String(30))) + + @classmethod + def insert_data(cls): + global Node + + class Node(Base): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True, join_depth=3, + backref=backref('parent', remote_side=[nodes.c.id]) + ) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.add(n1) + sess.flush() + sess.close() + + def test_join(self): + sess = create_session() + + node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first() + assert node.data=='n12' + + ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all() + assert ret == [('n12',)] + + + node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first() + assert node.data=='n1' + + node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\ + join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first() + assert node.data == 'n122' + + def test_explicit_join(self): + sess = create_session() + + n1 = aliased(Node) + n2 = aliased(Node) + + node = sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data=='n122').first() + assert node.data=='n12' 
+ + node = sess.query(Node).select_from(join(Node, n1, 'children').join(n2, 'children')).\ + filter(n2.data=='n122').first() + assert node.data=='n1' + + # mix explicit and named onclauses + node = sess.query(Node).select_from(join(Node, n1, Node.id==n1.parent_id).join(n2, 'children')).\ + filter(n2.data=='n122').first() + assert node.data=='n1' + + node = sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ + filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first() + assert node.data == 'n122' + + eq_( + list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ + filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)), + [('n122', 'n12', 'n1')]) + + def test_join_to_nonaliased(self): + sess = create_session() + + n1 = aliased(Node) + + # using 'n1.parent' implicitly joins to unaliased Node + eq_( + sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(), + [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] + ) + + # explicit (new syntax) + eq_( + sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(), + [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] + ) + + def test_multiple_explicit_entities(self): + sess = create_session() + + parent = aliased(Node) + grandparent = aliased(Node) + eq_( + sess.query(Node, parent, grandparent).\ + join((Node.parent, parent), (parent.parent, grandparent)).\ + filter(Node.data=='n122').filter(parent.data=='n12').\ + filter(grandparent.data=='n1').first(), + (Node(data='n122'), Node(data='n12'), Node(data='n1')) + ) + + eq_( + sess.query(Node, parent, grandparent).\ + join((Node.parent, parent), (parent.parent, grandparent)).\ + filter(Node.data=='n122').filter(parent.data=='n12').\ + filter(grandparent.data=='n1')._from_self().first(), + (Node(data='n122'), Node(data='n12'), Node(data='n1')) + 
) + + # same, change order around + eq_( + sess.query(parent, grandparent, Node).\ + join((Node.parent, parent), (parent.parent, grandparent)).\ + filter(Node.data=='n122').filter(parent.data=='n12').\ + filter(grandparent.data=='n1')._from_self().first(), + (Node(data='n12'), Node(data='n1'), Node(data='n122')) + ) + + eq_( + sess.query(Node, parent, grandparent).\ + join((Node.parent, parent), (parent.parent, grandparent)).\ + filter(Node.data=='n122').filter(parent.data=='n12').\ + filter(grandparent.data=='n1').\ + options(eagerload(Node.children)).first(), + (Node(data='n122'), Node(data='n12'), Node(data='n1')) + ) + + eq_( + sess.query(Node, parent, grandparent).\ + join((Node.parent, parent), (parent.parent, grandparent)).\ + filter(Node.data=='n122').filter(parent.data=='n12').\ + filter(grandparent.data=='n1')._from_self().\ + options(eagerload(Node.children)).first(), + (Node(data='n122'), Node(data='n12'), Node(data='n1')) + ) + + + def test_any(self): + sess = create_session() + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) + eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) + + def test_has(self): + sess = create_session() + + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) + eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) + + def test_contains(self): + sess = create_session() + + n122 = sess.query(Node).filter(Node.data=='n122').one() + eq_(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) + + n13 = sess.query(Node).filter(Node.data=='n13').one() + eq_(sess.query(Node).filter(Node.children.contains(n13)).all(), 
[Node(data='n1')]) + + def test_eq_ne(self): + sess = create_session() + + n12 = sess.query(Node).filter(Node.data=='n12').one() + eq_(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + + eq_(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) + +class SelfReferentialM2MTest(_base.MappedTest): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + global nodes, node_to_nodes + nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(30))) + + node_to_nodes =Table('node_to_nodes', metadata, + Column('left_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), + Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), + ) + + @classmethod + def insert_data(cls): + global Node + + class Node(Base): + pass + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True, secondary=node_to_nodes, + primaryjoin=nodes.c.id==node_to_nodes.c.left_node_id, + secondaryjoin=nodes.c.id==node_to_nodes.c.right_node_id, + ) + }) + sess = create_session() + n1 = Node(data='n1') + n2 = Node(data='n2') + n3 = Node(data='n3') + n4 = Node(data='n4') + n5 = Node(data='n5') + n6 = Node(data='n6') + n7 = Node(data='n7') + + n1.children = [n2, n3, n4] + n2.children = [n3, n6, n7] + n3.children = [n5, n4] + + sess.add(n1) + sess.add(n2) + sess.add(n3) + sess.add(n4) + sess.flush() + sess.close() + + def test_any(self): + sess = create_session() + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) + + def test_contains(self): + sess = create_session() + n4 = sess.query(Node).filter_by(data='n4').one() + + eq_(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) + 
eq_(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) + + def test_explicit_join(self): + sess = create_session() + + n1 = aliased(Node) + eq_( + sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data.in_(['n3', 'n7'])).order_by(Node.id).all(), + [Node(data='n1'), Node(data='n2')] + ) + +class ExternalColumnsTest(QueryTest): + """test mappers with SQL-expressions added as column properties.""" + + run_setup_mappers = None + + def test_external_columns_bad(self): + + assert_raises_message(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={ + 'concat': (users.c.id * 2), + }) + clear_mappers() + + def test_external_columns(self): + """test querying mappings that reference external columns or selectables.""" + + mapper(User, users, properties={ + 'concat': column_property((users.c.id * 2)), + 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).as_scalar()) + }) + + mapper(Address, addresses, properties={ + 'user':relation(User) + }) + + sess = create_session() + + sess.query(Address).options(eagerload('user')).all() + + eq_(sess.query(User).all(), + [ + User(id=7, concat=14, count=1), + User(id=8, concat=16, count=3), + User(id=9, concat=18, count=1), + User(id=10, concat=20, count=0), + ] + ) + + address_result = [ + Address(id=1, user=User(id=7, concat=14, count=1)), + Address(id=2, user=User(id=8, concat=16, count=3)), + Address(id=3, user=User(id=8, concat=16, count=3)), + Address(id=4, user=User(id=8, concat=16, count=3)), + Address(id=5, user=User(id=9, concat=18, count=1)) + ] + eq_(sess.query(Address).all(), address_result) + + # run the eager version twice to test caching of aliased clauses + for x in range(2): + sess.expunge_all() + def go(): + eq_(sess.query(Address).options(eagerload('user')).all(), address_result) + 
self.assert_sql_count(testing.db, go, 1) + + ualias = aliased(User) + eq_( + sess.query(Address, ualias).join(('user', ualias)).all(), + [(address, address.user) for address in address_result] + ) + + eq_( + sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), + [ + (Address(id=1), 1), + (Address(id=2), 3), + (Address(id=3), 3), + (Address(id=4), 3), + (Address(id=5), 1) + ] + ) + + eq_(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), + [ + (Address(id=1), 14, 1), + (Address(id=2), 16, 3), + (Address(id=3), 16, 3), + (Address(id=4), 16, 3), + (Address(id=5), 18, 1) + ] + ) + + ua = aliased(User) + eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), + [ + (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1), + (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3), + (Address(id=3, user=User(id=8, concat=16, count=3)), 16, 3), + (Address(id=4, user=User(id=8, concat=16, count=3)), 16, 3), + (Address(id=5, user=User(id=9, concat=18, count=1)), 18, 1) + ] + ) + + eq_(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), + [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] + ) + + eq_(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), + [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] + ) + + def test_external_columns_eagerload(self): + # in this test, we have a subquery on User that accesses "addresses", underneath + # an eagerload for "addresses". So the "addresses" alias adapter needs to *not* hit + # the "addresses" table within the "user" subquery, but "user" still needs to be adapted. 
+ # therefore the long standing practice of eager adapters being "chained" has been removed + # since its unnecessary and breaks this exact condition. + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', order_by=addresses.c.id), + 'concat': column_property((users.c.id * 2)), + 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users)) + }) + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'address':relation(Address), # m2o + }) + + sess = create_session() + def go(): + o1 = sess.query(Order).options(eagerload_all('address.user')).get(1) + eq_(o1.address.user.count, 1) + self.assert_sql_count(testing.db, go, 1) + + sess = create_session() + def go(): + o1 = sess.query(Order).options(eagerload_all('address.user')).first() + eq_(o1.address.user.count, 1) + self.assert_sql_count(testing.db, go, 1) + +class TestOverlyEagerEquivalentCols(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + global base, sub1, sub2 + base = Table('base', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + sub1 = Table('sub1', metadata, + Column('id', Integer, ForeignKey('base.id'), primary_key=True), + Column('data', String(50)) + ) + + sub2 = Table('sub2', metadata, + Column('id', Integer, ForeignKey('base.id'), ForeignKey('sub1.id'), primary_key=True), + Column('data', String(50)) + ) + + def test_equivs(self): + class Base(_base.ComparableEntity): + pass + class Sub1(_base.ComparableEntity): + pass + class Sub2(_base.ComparableEntity): + pass + + mapper(Base, base, properties={ + 'sub1':relation(Sub1), + 'sub2':relation(Sub2) + }) + + mapper(Sub1, sub1) + mapper(Sub2, sub2) + sess = create_session() + + s11 = Sub1(data='s11') + s12 = Sub1(data='s12') + s2 = Sub2(data='s2') + b1 = Base(data='b1', sub1=[s11], sub2=[]) + b2 = Base(data='b1', sub1=[s12], sub2=[]) + sess.add(b1) + sess.add(b2) + sess.flush() + + # theres an 
overlapping ForeignKey here, so not much option except + # to artifically control the flush order + b2.sub2 = [s2] + sess.flush() + + q = sess.query(Base).outerjoin('sub2', aliased=True) + assert sub1.c.id not in q._filter_aliases.equivalents + + eq_( + sess.query(Base).join('sub1').outerjoin('sub2', aliased=True).\ + filter(Sub1.id==1).one(), + b1 + ) + +class UpdateDeleteTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(32)), + Column('age', Integer)) + + Table('documents', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', None, ForeignKey('users.id')), + Column('title', String(32))) + + @classmethod + def setup_classes(cls): + class User(_base.ComparableEntity): + pass + + class Document(_base.ComparableEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + users.insert().execute([ + dict(id=1, name='john', age=25), + dict(id=2, name='jack', age=47), + dict(id=3, name='jill', age=29), + dict(id=4, name='jane', age=37), + ]) + + @testing.resolve_artifact_names + def insert_documents(self): + documents.insert().execute([ + dict(id=1, user_id=1, title='foo'), + dict(id=2, user_id=1, title='bar'), + dict(id=3, user_id=2, title='baz'), + ]) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users) + mapper(Document, documents, properties={ + 'user': relation(User, lazy=False, backref=backref('documents', lazy=True)) + }) + + @testing.resolve_artifact_names + def test_delete(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete() + + assert john not in sess and jill not in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) + + @testing.resolve_artifact_names + def 
test_delete_with_bindparams(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter('name = :name').params(name='john').delete() + assert john not in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) + + @testing.resolve_artifact_names + def test_delete_rollback(self): + sess = sessionmaker()() + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='evaluate') + assert john not in sess and jill not in sess + sess.rollback() + assert john in sess and jill in sess + + @testing.resolve_artifact_names + def test_delete_rollback_with_fetch(self): + sess = sessionmaker()() + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch') + assert john not in sess and jill not in sess + sess.rollback() + assert john in sess and jill in sess + + @testing.resolve_artifact_names + def test_delete_without_session_sync(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session=False) + + assert john in sess and jill in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) + + @testing.resolve_artifact_names + def test_delete_with_fetch_strategy(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch') + + assert john not in sess and jill not in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jane]) + + @testing.fails_on('mysql', 'FIXME: unknown') + 
@testing.resolve_artifact_names + def test_delete_fallback(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='evaluate') + + assert john not in sess + + eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) + + @testing.resolve_artifact_names + def test_update(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='evaluate') + + eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + + sess.query(User).filter(User.age > 29).update({User.age: User.age - 10}, synchronize_session='evaluate') + eq_([john.age, jack.age, jill.age, jane.age], [25,27,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,29,27])) + + sess.query(User).filter(User.age > 27).update({users.c.age: User.age - 10}, synchronize_session='evaluate') + eq_([john.age, jack.age, jill.age, jane.age], [25,27,19,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,19,27])) + + + @testing.resolve_artifact_names + def test_update_with_bindparams(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + + sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='evaluate') + + eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + + @testing.resolve_artifact_names + def test_update_changes_resets_dirty(self): + sess = create_session(bind=testing.db, autocommit=False, autoflush=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() 
+ + john.age = 50 + jack.age = 37 + + # autoflush is false. therefore our '50' and '37' are getting blown away by this operation. + + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='evaluate') + + for x in (john, jack, jill, jane): + assert not sess.is_modified(x) + + eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) + + john.age = 25 + assert john in sess.dirty + assert jack in sess.dirty + assert jill not in sess.dirty + assert not sess.is_modified(john) + assert not sess.is_modified(jack) + + @testing.resolve_artifact_names + def test_update_changes_with_autoflush(self): + sess = create_session(bind=testing.db, autocommit=False, autoflush=True) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + + john.age = 50 + jack.age = 37 + + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='evaluate') + + for x in (john, jack, jill, jane): + assert not sess.is_modified(x) + + eq_([john.age, jack.age, jill.age, jane.age], [40, 27, 29, 27]) + + john.age = 25 + assert john in sess.dirty + assert jack not in sess.dirty + assert jill not in sess.dirty + assert sess.is_modified(john) + assert not sess.is_modified(jack) + + + + @testing.resolve_artifact_names + def test_update_with_expire_strategy(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') + + eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + + @testing.resolve_artifact_names + def test_update_returns_rowcount(self): + sess = create_session(bind=testing.db, autocommit=False) + + rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0}) + eq_(rowcount, 2) + + rowcount = sess.query(User).filter(User.age > 29).update({'age': 
User.age - 10}) + eq_(rowcount, 2) + + @testing.resolve_artifact_names + def test_delete_returns_rowcount(self): + sess = create_session(bind=testing.db, autocommit=False) + + rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False) + eq_(rowcount, 3) + + @testing.resolve_artifact_names + def test_update_with_eager_relations(self): + self.insert_documents() + + sess = create_session(bind=testing.db, autocommit=False) + + foo,bar,baz = sess.query(Document).order_by(Document.id).all() + sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='expire') + + eq_([foo.title, bar.title, baz.title], ['foofoo','barbar', 'baz']) + eq_(sess.query(Document.title).order_by(Document.id).all(), zip(['foofoo','barbar', 'baz'])) + + @testing.resolve_artifact_names + def test_update_with_explicit_eagerload(self): + sess = create_session(bind=testing.db, autocommit=False) + + john,jack,jill,jane = sess.query(User).order_by(User.id).all() + sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') + + eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + + @testing.resolve_artifact_names + def test_delete_with_eager_relations(self): + self.insert_documents() + + sess = create_session(bind=testing.db, autocommit=False) + + sess.query(Document).filter(Document.user_id == 1).delete(synchronize_session=False) + + eq_(sess.query(Document.title).all(), zip(['baz'])) + diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py new file mode 100644 index 000000000..1bc074c31 --- /dev/null +++ b/test/orm/test_relationships.py @@ -0,0 +1,1907 @@ +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import datetime +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, 
String, ForeignKey, MetaData, and_ +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers +from sqlalchemy.test.testing import eq_, startswith_ +from test.orm import _base, _fixtures + + +class RelationTest(_base.MappedTest): + """An extended topological sort test + + This is essentially an extension of the "dependency.py" topological sort + test. In this test, a table is dependent on two other tables that are + otherwise unrelated to each other. The dependency sort must ensure that + this childmost table is below both parent tables in the outcome (a bug + existed where this was not always the case). + + While the straight topological sort tests should expose this, since the + sorting can be different due to subtle differences in program execution, + this test case was exposing the bug whereas the simpler tests were not. + + """ + + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table("tbl_a", metadata, + Column("id", Integer, primary_key=True), + Column("name", String(128))) + Table("tbl_b", metadata, + Column("id", Integer, primary_key=True), + Column("name", String(128))) + Table("tbl_c", metadata, + Column("id", Integer, primary_key=True), + Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False), + Column("name", String(128))) + Table("tbl_d", metadata, + Column("id", Integer, primary_key=True), + Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False), + Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), + Column("name", String(128))) + + @classmethod + def setup_classes(cls): + class A(_base.Entity): + pass + class B(_base.Entity): + pass + class C(_base.Entity): + pass + class D(_base.Entity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(A, tbl_a, properties=dict( + 
c_rows=relation(C, cascade="all, delete-orphan", backref="a_row"))) + mapper(B, tbl_b) + mapper(C, tbl_c, properties=dict( + d_rows=relation(D, cascade="all, delete-orphan", backref="c_row"))) + mapper(D, tbl_d, properties=dict( + b_row=relation(B))) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + session = create_session() + a = A(name='a1') + b = B(name='b1') + c = C(name='c1', a_row=a) + + d1 = D(name='d1', b_row=b, c_row=c) + d2 = D(name='d2', b_row=b, c_row=c) + d3 = D(name='d3', b_row=b, c_row=c) + session.add(a) + session.add(b) + session.flush() + + @testing.resolve_artifact_names + def testDeleteRootTable(self): + session = create_session() + a = session.query(A).filter_by(name='a1').one() + + session.delete(a) + session.flush() + + @testing.resolve_artifact_names + def testDeleteMiddleTable(self): + session = create_session() + c = session.query(C).filter_by(name='c1').one() + + session.delete(c) + session.flush() + + +class RelationTest2(_base.MappedTest): + """Tests a relationship on a column included in multiple foreign keys. + + This test tests a relationship on a column that is included in multiple + foreign keys, as well as a self-referential relationship on a composite + key where one column in the foreign key is 'joined to itself'. 
+ + """ + @classmethod + def define_tables(cls, metadata): + Table('company_t', metadata, + Column('company_id', Integer, primary_key=True), + Column('name', sa.Unicode(30))) + + Table('employee_t', metadata, + Column('company_id', Integer, primary_key=True), + Column('emp_id', Integer, primary_key=True), + Column('name', sa.Unicode(30)), + Column('reports_to_id', Integer), + sa.ForeignKeyConstraint( + ['company_id'], + ['company_t.company_id']), + sa.ForeignKeyConstraint( + ['company_id', 'reports_to_id'], + ['employee_t.company_id', 'employee_t.emp_id'])) + + @testing.resolve_artifact_names + def test_explicit(self): + """test with mappers that have fairly explicit join conditions""" + + class Company(_base.Entity): + pass + + class Employee(_base.Entity): + def __init__(self, name, company, emp_id, reports_to=None): + self.name = name + self.company = company + self.emp_id = emp_id + self.reports_to = reports_to + + mapper(Company, company_t) + mapper(Employee, employee_t, properties= { + 'company':relation(Company, primaryjoin=employee_t.c.company_id==company_t.c.company_id, backref='employees'), + 'reports_to':relation(Employee, primaryjoin= + sa.and_( + employee_t.c.emp_id==employee_t.c.reports_to_id, + employee_t.c.company_id==employee_t.c.company_id + ), + remote_side=[employee_t.c.emp_id, employee_t.c.company_id], + foreign_keys=[employee_t.c.reports_to_id], + backref='employees') + }) + + sess = create_session() + c1 = Company() + c2 = Company() + + e1 = Employee(u'emp1', c1, 1) + e2 = Employee(u'emp2', c1, 2, e1) + e3 = Employee(u'emp3', c1, 3, e1) + e4 = Employee(u'emp4', c1, 4, e3) + e5 = Employee(u'emp5', c2, 1) + e6 = Employee(u'emp6', c2, 2, e5) + e7 = Employee(u'emp7', c2, 3, e5) + + sess.add_all((c1, c2)) + sess.flush() + sess.expunge_all() + + test_c1 = sess.query(Company).get(c1.company_id) + test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) + assert test_e1.name == 'emp1', test_e1.name + test_e5 = 
sess.query(Employee).get([c2.company_id, e5.emp_id]) + assert test_e5.name == 'emp5', test_e5.name + assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] + eq_(sess.query(Employee).get([c1.company_id, 3]).reports_to.name, 'emp1') + eq_(sess.query(Employee).get([c2.company_id, 3]).reports_to.name, 'emp5') + + @testing.resolve_artifact_names + def test_implicit(self): + """test with mappers that have the most minimal arguments""" + class Company(_base.Entity): + pass + class Employee(_base.Entity): + def __init__(self, name, company, emp_id, reports_to=None): + self.name = name + self.company = company + self.emp_id = emp_id + self.reports_to = reports_to + + mapper(Company, company_t) + mapper(Employee, employee_t, properties= { + 'company':relation(Company, backref='employees'), + 'reports_to':relation(Employee, + remote_side=[employee_t.c.emp_id, employee_t.c.company_id], + foreign_keys=[employee_t.c.reports_to_id], + backref='employees') + }) + + sess = create_session() + c1 = Company() + c2 = Company() + + e1 = Employee(u'emp1', c1, 1) + e2 = Employee(u'emp2', c1, 2, e1) + e3 = Employee(u'emp3', c1, 3, e1) + e4 = Employee(u'emp4', c1, 4, e3) + e5 = Employee(u'emp5', c2, 1) + e6 = Employee(u'emp6', c2, 2, e5) + e7 = Employee(u'emp7', c2, 3, e5) + + sess.add_all((c1, c2)) + sess.flush() + sess.expunge_all() + + test_c1 = sess.query(Company).get(c1.company_id) + test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id]) + assert test_e1.name == 'emp1', test_e1.name + test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id]) + assert test_e5.name == 'emp5', test_e5.name + assert [x.name for x in test_e1.employees] == ['emp2', 'emp3'] + assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' + assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' + +class RelationTest3(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table("jobs", metadata, + Column("jobno", sa.Unicode(15), 
primary_key=True), + Column("created", sa.DateTime, nullable=False, + default=datetime.datetime.now), + Column("deleted", sa.Boolean, nullable=False, default=False)) + + Table("pageversions", metadata, + Column("jobno", sa.Unicode(15), primary_key=True), + Column("pagename", sa.Unicode(30), primary_key=True), + Column("version", Integer, primary_key=True, default=1), + Column("created", sa.DateTime, nullable=False, + default=datetime.datetime.now), + Column("md5sum", String(32)), + Column("width", Integer, nullable=False, default=0), + Column("height", Integer, nullable=False, default=0), + sa.ForeignKeyConstraint( + ["jobno", "pagename"], + ["pages.jobno", "pages.pagename"])) + + Table("pages", metadata, + Column("jobno", sa.Unicode(15), ForeignKey("jobs.jobno"), + primary_key=True), + Column("pagename", sa.Unicode(30), primary_key=True), + Column("created", sa.DateTime, nullable=False, + default=datetime.datetime.now), + Column("deleted", sa.Boolean, nullable=False, default=False), + Column("current_version", Integer)) + + Table("pagecomments", metadata, + Column("jobno", sa.Unicode(15), primary_key=True), + Column("pagename", sa.Unicode(30), primary_key=True), + Column("comment_id", Integer, primary_key=True, + autoincrement=False), + Column("content", sa.UnicodeText), + sa.ForeignKeyConstraint( + ["jobno", "pagename"], + ["pages.jobno", "pages.pagename"])) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class Job(_base.Entity): + def create_page(self, pagename): + return Page(job=self, pagename=pagename) + class PageVersion(_base.Entity): + def __init__(self, page=None, version=None): + self.page = page + self.version = version + class Page(_base.Entity): + def __init__(self, job=None, pagename=None): + self.job = job + self.pagename = pagename + self.currentversion = PageVersion(self, 1) + def add_version(self): + self.currentversion = PageVersion( + page=self, version=self.currentversion.version+1) + comment = self.add_comment() 
+ comment.closeable = False + comment.content = u'some content' + return self.currentversion + def add_comment(self): + nextnum = max([-1] + [c.comment_id for c in self.comments]) + 1 + newcomment = PageComment() + newcomment.comment_id = nextnum + self.comments.append(newcomment) + newcomment.created_version = self.currentversion.version + return newcomment + class PageComment(_base.Entity): + pass + + mapper(Job, jobs) + mapper(PageVersion, pageversions) + mapper(Page, pages, properties={ + 'job': relation( + Job, + backref=backref('pages', + cascade="all, delete-orphan", + order_by=pages.c.pagename)), + 'currentversion': relation( + PageVersion, + foreign_keys=[pages.c.current_version], + primaryjoin=sa.and_( + pages.c.jobno==pageversions.c.jobno, + pages.c.pagename==pageversions.c.pagename, + pages.c.current_version==pageversions.c.version), + post_update=True), + 'versions': relation( + PageVersion, + cascade="all, delete-orphan", + primaryjoin=sa.and_(pages.c.jobno==pageversions.c.jobno, + pages.c.pagename==pageversions.c.pagename), + order_by=pageversions.c.version, + backref=backref('page',lazy=False) + )}) + mapper(PageComment, pagecomments, properties={ + 'page': relation( + Page, + primaryjoin=sa.and_(pages.c.jobno==pagecomments.c.jobno, + pages.c.pagename==pagecomments.c.pagename), + backref=backref("comments", + cascade="all, delete-orphan", + order_by=pagecomments.c.comment_id))}) + + @testing.resolve_artifact_names + def testbasic(self): + """A combination of complicated join conditions with post_update.""" + + j1 = Job(jobno=u'somejob') + j1.create_page(u'page1') + j1.create_page(u'page2') + j1.create_page(u'page3') + + j2 = Job(jobno=u'somejob2') + j2.create_page(u'page1') + j2.create_page(u'page2') + j2.create_page(u'page3') + + j2.pages[0].add_version() + j2.pages[0].add_version() + j2.pages[1].add_version() + + s = create_session() + s.add_all((j1, j2)) + + s.flush() + + s.expunge_all() + j = s.query(Job).filter_by(jobno=u'somejob').one() + oldp 
= list(j.pages) + j.pages = [] + + s.flush() + + s.expunge_all() + j = s.query(Job).filter_by(jobno=u'somejob2').one() + j.pages[1].current_version = 12 + s.delete(j) + s.flush() + +class RelationTest4(_base.MappedTest): + """Syncrules on foreign keys that are also primary""" + + @classmethod + def define_tables(cls, metadata): + Table("tableA", metadata, + Column("id",Integer,primary_key=True), + Column("foo",Integer,), + test_needs_fk=True) + Table("tableB",metadata, + Column("id",Integer,ForeignKey("tableA.id"),primary_key=True), + test_needs_fk=True) + + @classmethod + def setup_classes(cls): + class A(_base.Entity): + pass + + class B(_base.Entity): + pass + + @testing.resolve_artifact_names + def test_no_delete_PK_AtoB(self): + """A cant be deleted without B because B would have no PK value.""" + mapper(A, tableA, properties={ + 'bs':relation(B, cascade="save-update")}) + mapper(B, tableB) + + a1 = A() + a1.bs.append(B()) + sess = create_session() + sess.add(a1) + sess.flush() + + sess.delete(a1) + try: + sess.flush() + assert False + except AssertionError, e: + startswith_(str(e), + "Dependency rule tried to blank-out " + "primary key column 'tableB.id' on instance ") + + @testing.resolve_artifact_names + def test_no_delete_PK_BtoA(self): + mapper(B, tableB, properties={ + 'a':relation(A, cascade="save-update")}) + mapper(A, tableA) + + b1 = B() + a1 = A() + b1.a = a1 + sess = create_session() + sess.add(b1) + sess.flush() + b1.a = None + try: + sess.flush() + assert False + except AssertionError, e: + startswith_(str(e), + "Dependency rule tried to blank-out " + "primary key column 'tableB.id' on instance ") + + @testing.fails_on_everything_except('sqlite', 'mysql') + @testing.resolve_artifact_names + def test_nullPKsOK_BtoA(self): + # postgres cant handle a nullable PK column...? 
+ tableC = Table('tablec', tableA.metadata, + Column('id', Integer, primary_key=True), + Column('a_id', Integer, ForeignKey('tableA.id'), + primary_key=True, autoincrement=False, nullable=True)) + tableC.create() + + class C(_base.Entity): + pass + mapper(C, tableC, properties={ + 'a':relation(A, cascade="save-update") + }, allow_null_pks=True) + mapper(A, tableA) + + c1 = C() + c1.id = 5 + c1.a = None + sess = create_session() + sess.add(c1) + # test that no error is raised. + sess.flush() + + @testing.resolve_artifact_names + def test_delete_cascade_BtoA(self): + """No 'blank the PK' error when the child is to be deleted as part of a cascade""" + + for cascade in ("save-update, delete", + #"save-update, delete-orphan", + "save-update, delete, delete-orphan"): + mapper(B, tableB, properties={ + 'a':relation(A, cascade=cascade, single_parent=True) + }) + mapper(A, tableA) + + b1 = B() + a1 = A() + b1.a = a1 + sess = create_session() + sess.add(b1) + sess.flush() + sess.delete(b1) + sess.flush() + assert a1 not in sess + assert b1 not in sess + sess.expunge_all() + sa.orm.clear_mappers() + + @testing.resolve_artifact_names + def test_delete_cascade_AtoB(self): + """No 'blank the PK' error when the child is to be deleted as part of a cascade""" + for cascade in ("save-update, delete", + #"save-update, delete-orphan", + "save-update, delete, delete-orphan"): + mapper(A, tableA, properties={ + 'bs':relation(B, cascade=cascade) + }) + mapper(B, tableB) + + a1 = A() + b1 = B() + a1.bs.append(b1) + sess = create_session() + sess.add(a1) + sess.flush() + + sess.delete(a1) + sess.flush() + assert a1 not in sess + assert b1 not in sess + sess.expunge_all() + sa.orm.clear_mappers() + + @testing.resolve_artifact_names + def test_delete_manual_AtoB(self): + mapper(A, tableA, properties={ + 'bs':relation(B, cascade="none")}) + mapper(B, tableB) + + a1 = A() + b1 = B() + a1.bs.append(b1) + sess = create_session() + sess.add(a1) + sess.add(b1) + sess.flush() + + sess.delete(a1) + 
sess.delete(b1) + sess.flush() + assert a1 not in sess + assert b1 not in sess + sess.expunge_all() + + @testing.resolve_artifact_names + def test_delete_manual_BtoA(self): + mapper(B, tableB, properties={ + 'a':relation(A, cascade="none")}) + mapper(A, tableA) + + b1 = B() + a1 = A() + b1.a = a1 + sess = create_session() + sess.add(b1) + sess.add(a1) + sess.flush() + sess.delete(b1) + sess.delete(a1) + sess.flush() + assert a1 not in sess + assert b1 not in sess + +class RelationTest5(_base.MappedTest): + """Test a map to a select that relates to a map to the table.""" + + @classmethod + def define_tables(cls, metadata): + Table('items', metadata, + Column('item_policy_num', String(10), primary_key=True, + key='policyNum'), + Column('item_policy_eff_date', sa.Date, primary_key=True, + key='policyEffDate'), + Column('item_type', String(20), primary_key=True, + key='type'), + Column('item_id', Integer, primary_key=True, + key='id', autoincrement=False)) + + @testing.resolve_artifact_names + def test_basic(self): + class Container(_base.Entity): + pass + class LineItem(_base.Entity): + pass + + container_select = sa.select( + [items.c.policyNum, items.c.policyEffDate, items.c.type], + distinct=True, + ).alias('container_select') + + mapper(LineItem, items) + + mapper(Container, + container_select, + order_by=sa.asc(container_select.c.type), + properties=dict( + lineItems=relation(LineItem, + lazy=True, + cascade='all, delete-orphan', + order_by=sa.asc(items.c.type), + primaryjoin=sa.and_( + container_select.c.policyNum==items.c.policyNum, + container_select.c.policyEffDate==items.c.policyEffDate, + container_select.c.type==items.c.type), + foreign_keys=[ + items.c.policyNum, + items.c.policyEffDate, + items.c.type]))) + + session = create_session() + con = Container() + con.policyNum = "99" + con.policyEffDate = datetime.date.today() + con.type = "TESTER" + session.add(con) + for i in range(0, 10): + li = LineItem() + li.id = i + con.lineItems.append(li) + 
session.add(li) + session.flush() + session.expunge_all() + newcon = session.query(Container).first() + assert con.policyNum == newcon.policyNum + assert len(newcon.lineItems) == 10 + for old, new in zip(con.lineItems, newcon.lineItems): + assert old.id == new.id + +class RelationTest6(_base.MappedTest): + """test a relation with a non-column entity in the primary join, + is not viewonly, and also has the non-column's clause mentioned in the + foreign keys list. + + """ + + @classmethod + def define_tables(cls, metadata): + Table('tags', metadata, Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + Table('tag_foo', metadata, + Column("id", Integer, primary_key=True), + Column('tagid', Integer), + Column("data", String(50)), + ) + + @testing.resolve_artifact_names + def test_basic(self): + class Tag(_base.ComparableEntity): + pass + class TagInstance(_base.ComparableEntity): + pass + + mapper(Tag, tags, properties={ + 'foo':relation(TagInstance, + primaryjoin=sa.and_(tag_foo.c.data=='iplc_case', + tag_foo.c.tagid==tags.c.id), + foreign_keys=[tag_foo.c.tagid, tag_foo.c.data], + ), + }) + + mapper(TagInstance, tag_foo) + + sess = create_session() + t1 = Tag(data='some tag') + t1.foo.append(TagInstance(data='iplc_case')) + t1.foo.append(TagInstance(data='not_iplc_case')) + sess.add(t1) + sess.flush() + sess.expunge_all() + + # relation works + eq_(sess.query(Tag).all(), [Tag(data='some tag', foo=[TagInstance(data='iplc_case')])]) + + # both TagInstances were persisted + eq_( + sess.query(TagInstance).order_by(TagInstance.data).all(), + [TagInstance(data='iplc_case'), TagInstance(data='not_iplc_case')] + ) + +class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): + """test ambiguous joins due to FKs on both sides treated as self-referential. + + this mapping is very similar to that of test/orm/inheritance/query.py + SelfReferentialTestJoinedToBase , except that inheritance is not used + here. 
+ + """ + + @classmethod + def define_tables(cls, metadata): + subscriber_table = Table('subscriber', metadata, + Column('id', Integer, primary_key=True), + Column('dummy', String(10)) # to appease older sqlite version + ) + + address_table = Table('address', + metadata, + Column('subscriber_id', Integer, ForeignKey('subscriber.id'), primary_key=True), + Column('type', String(1), primary_key=True), + ) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + subscriber_and_address = subscriber.join(address, + and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C']))) + + class Address(_base.ComparableEntity): + pass + + class Subscriber(_base.ComparableEntity): + pass + + mapper(Address, address) + + mapper(Subscriber, subscriber_and_address, properties={ + 'id':[subscriber.c.id, address.c.subscriber_id], + 'addresses' : relation(Address, + backref=backref("customer")) + }) + + @testing.resolve_artifact_names + def test_mapping(self): + from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE + sess = create_session() + assert Subscriber.addresses.property.direction is ONETOMANY + assert Address.customer.property.direction is MANYTOONE + + s1 = Subscriber(type='A', + addresses = [ + Address(type='D'), + Address(type='E'), + ] + ) + a1 = Address(type='B', customer=Subscriber(type='C')) + + assert s1.addresses[0].customer is s1 + assert a1.customer.addresses[0] is a1 + + sess.add_all([s1, a1]) + + sess.flush() + sess.expunge_all() + + eq_( + sess.query(Subscriber).order_by(Subscriber.type).all(), + [ + Subscriber(id=1, type=u'A'), + Subscriber(id=2, type=u'B'), + Subscriber(id=2, type=u'C') + ] + ) + + +class ManualBackrefTest(_fixtures.FixtureTest): + """Test explicit relations that are backrefs to each other.""" + + run_inserts = None + + @testing.resolve_artifact_names + def test_o2m(self): + mapper(User, users, properties={ + 'addresses':relation(Address, back_populates='user') + }) + + mapper(Address, addresses, 
properties={ + 'user':relation(User, back_populates='addresses') + }) + + sess = create_session() + + u1 = User(name='u1') + a1 = Address(email_address='foo') + u1.addresses.append(a1) + assert a1.user is u1 + + sess.add(u1) + sess.flush() + sess.expire_all() + assert sess.query(Address).one() is a1 + assert a1.user is u1 + assert a1 in u1.addresses + + @testing.resolve_artifact_names + def test_invalid_key(self): + mapper(User, users, properties={ + 'addresses':relation(Address, back_populates='userr') + }) + + mapper(Address, addresses, properties={ + 'user':relation(User, back_populates='addresses') + }) + + assert_raises(sa.exc.InvalidRequestError, compile_mappers) + + @testing.resolve_artifact_names + def test_invalid_target(self): + mapper(User, users, properties={ + 'addresses':relation(Address, back_populates='dingaling'), + }) + + mapper(Dingaling, dingalings) + mapper(Address, addresses, properties={ + 'dingaling':relation(Dingaling) + }) + + assert_raises_message(sa.exc.ArgumentError, + r"reverse_property 'dingaling' on relation User.addresses references " + "relation Address.dingaling, which does not reference mapper Mapper\|User\|users", + compile_mappers) + +class JoinConditionErrorTest(testing.TestBase): + + def test_clauseelement_pj(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + class C1(Base): + __tablename__ = 'c1' + id = Column('id', Integer, primary_key=True) + class C2(Base): + __tablename__ = 'c2' + id = Column('id', Integer, primary_key=True) + c1id = Column('c1id', Integer, ForeignKey('c1.id')) + c2 = relation(C1, primaryjoin=C1.id) + + assert_raises(sa.exc.ArgumentError, compile_mappers) + + def test_clauseelement_pj_false(self): + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() + class C1(Base): + __tablename__ = 'c1' + id = Column('id', Integer, primary_key=True) + class C2(Base): + __tablename__ = 'c2' + id = Column('id', Integer, primary_key=True) + c1id 
= Column('c1id', Integer, ForeignKey('c1.id')) + c2 = relation(C1, primaryjoin="x"=="y") + + assert_raises(sa.exc.ArgumentError, compile_mappers) + + + def test_fk_error_raised(self): + m = MetaData() + t1 = Table('t1', m, + Column('id', Integer, primary_key=True), + Column('foo_id', Integer, ForeignKey('t2.nonexistent_id')), + ) + t2 = Table('t2', m, + Column('id', Integer, primary_key=True), + ) + + t3 = Table('t3', m, + Column('id', Integer, primary_key=True), + Column('t1id', Integer, ForeignKey('t1.id')) + ) + + class C1(object): + pass + class C2(object): + pass + + mapper(C1, t1, properties={'c2':relation(C2)}) + mapper(C2, t3) + + assert_raises(sa.exc.NoReferencedColumnError, compile_mappers) + + def test_join_error_raised(self): + m = MetaData() + t1 = Table('t1', m, + Column('id', Integer, primary_key=True), + ) + t2 = Table('t2', m, + Column('id', Integer, primary_key=True), + ) + + t3 = Table('t3', m, + Column('id', Integer, primary_key=True), + Column('t1id', Integer) + ) + + class C1(object): + pass + class C2(object): + pass + + mapper(C1, t1, properties={'c2':relation(C2)}) + mapper(C2, t3) + + assert_raises(sa.exc.ArgumentError, compile_mappers) + + def teardown(self): + clear_mappers() + +class TypeMatchTest(_base.MappedTest): + """test errors raised when trying to add items whose type is not handled by a relation""" + + @classmethod + def define_tables(cls, metadata): + Table("a", metadata, + Column('aid', Integer, primary_key=True), + Column('data', String(30))) + Table("b", metadata, + Column('bid', Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.aid")), + Column('data', String(30))) + Table("c", metadata, + Column('cid', Integer, primary_key=True), + Column("b_id", Integer, ForeignKey("b.bid")), + Column('data', String(30))) + Table("d", metadata, + Column('did', Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.aid")), + Column('data', String(30))) + + @testing.resolve_artifact_names + def 
test_o2m_oncascade(self): + class A(_base.Entity): pass + class B(_base.Entity): pass + class C(_base.Entity): pass + mapper(A, a, properties={'bs':relation(B)}) + mapper(B, b) + mapper(C, c) + + a1 = A() + b1 = B() + c1 = C() + a1.bs.append(b1) + a1.bs.append(c1) + sess = create_session() + try: + sess.add(a1) + assert False + except AssertionError, err: + eq_(str(err), + "Attribute 'bs' on class '%s' doesn't handle " + "objects of type '%s'" % (A, C)) + + @testing.resolve_artifact_names + def test_o2m_onflush(self): + class A(_base.Entity): pass + class B(_base.Entity): pass + class C(_base.Entity): pass + mapper(A, a, properties={'bs':relation(B, cascade="none")}) + mapper(B, b) + mapper(C, c) + + a1 = A() + b1 = B() + c1 = C() + a1.bs.append(b1) + a1.bs.append(c1) + sess = create_session() + sess.add(a1) + sess.add(b1) + sess.add(c1) + assert_raises_message(sa.orm.exc.FlushError, + "Attempting to flush an item", sess.flush) + + @testing.resolve_artifact_names + def test_o2m_nopoly_onflush(self): + class A(_base.Entity): pass + class B(_base.Entity): pass + class C(B): pass + mapper(A, a, properties={'bs':relation(B, cascade="none")}) + mapper(B, b) + mapper(C, c, inherits=B) + + a1 = A() + b1 = B() + c1 = C() + a1.bs.append(b1) + a1.bs.append(c1) + sess = create_session() + sess.add(a1) + sess.add(b1) + sess.add(c1) + assert_raises_message(sa.orm.exc.FlushError, + "Attempting to flush an item", sess.flush) + + @testing.resolve_artifact_names + def test_m2o_nopoly_onflush(self): + class A(_base.Entity): pass + class B(A): pass + class D(_base.Entity): pass + mapper(A, a) + mapper(B, b, inherits=A) + mapper(D, d, properties={"a":relation(A, cascade="none")}) + b1 = B() + d1 = D() + d1.a = b1 + sess = create_session() + sess.add(b1) + sess.add(d1) + assert_raises_message(sa.orm.exc.FlushError, + "Attempting to flush an item", sess.flush) + + @testing.resolve_artifact_names + def test_m2o_oncascade(self): + class A(_base.Entity): pass + class B(_base.Entity): pass 
+ class D(_base.Entity): pass + mapper(A, a) + mapper(B, b) + mapper(D, d, properties={"a":relation(A)}) + b1 = B() + d1 = D() + d1.a = b1 + sess = create_session() + assert_raises_message(AssertionError, + "doesn't handle objects of type", sess.add, d1) + +class TypedAssociationTable(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + class MySpecialType(sa.types.TypeDecorator): + impl = String + def process_bind_param(self, value, dialect): + return "lala" + value + def process_result_value(self, value, dialect): + return value[4:] + + Table('t1', metadata, + Column('col1', MySpecialType(30), primary_key=True), + Column('col2', String(30))) + Table('t2', metadata, + Column('col1', MySpecialType(30), primary_key=True), + Column('col2', String(30))) + Table('t3', metadata, + Column('t1c1', MySpecialType(30), ForeignKey('t1.col1')), + Column('t2c1', MySpecialType(30), ForeignKey('t2.col1'))) + + @testing.resolve_artifact_names + def testm2m(self): + """Many-to-many tables with special types for candidate keys.""" + + class T1(_base.Entity): pass + class T2(_base.Entity): pass + mapper(T2, t2) + mapper(T1, t1, properties={ + 't2s':relation(T2, secondary=t3, backref='t1s')}) + + a = T1() + a.col1 = "aid" + b = T2() + b.col1 = "bid" + c = T2() + c.col1 = "cid" + a.t2s.append(b) + a.t2s.append(c) + sess = create_session() + sess.add(a) + sess.flush() + + assert t3.count().scalar() == 2 + + a.t2s.remove(c) + sess.flush() + + assert t3.count().scalar() == 1 + + +class ViewOnlyOverlappingNames(_base.MappedTest): + """'viewonly' mappings with overlapping PK column names.""" + + @classmethod + def define_tables(cls, metadata): + Table("t1", metadata, + Column('id', Integer, primary_key=True), + Column('data', String(40))) + Table("t2", metadata, + Column('id', Integer, primary_key=True), + Column('data', String(40)), + Column('t1id', Integer, ForeignKey('t1.id'))) + Table("t3", metadata, + Column('id', Integer, primary_key=True), + Column('data', 
String(40)), + Column('t2id', Integer, ForeignKey('t2.id'))) + + @testing.resolve_artifact_names + def test_three_table_view(self): + """A three table join with overlapping PK names. + + A third table is pulled into the primary join condition using + overlapping PK column names and should not produce 'conflicting column' + error. + + """ + class C1(_base.Entity): pass + class C2(_base.Entity): pass + class C3(_base.Entity): pass + + mapper(C1, t1, properties={ + 't2s':relation(C2), + 't2_view':relation(C2, + viewonly=True, + primaryjoin=sa.and_(t1.c.id==t2.c.t1id, + t3.c.t2id==t2.c.id, + t3.c.data==t1.c.data))}) + mapper(C2, t2) + mapper(C3, t3, properties={ + 't2':relation(C2)}) + + c1 = C1() + c1.data = 'c1data' + c2a = C2() + c1.t2s.append(c2a) + c2b = C2() + c1.t2s.append(c2b) + c3 = C3() + c3.data='c1data' + c3.t2 = c2b + sess = create_session() + sess.add(c1) + sess.add(c3) + sess.flush() + sess.expunge_all() + + c1 = sess.query(C1).get(c1.id) + assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id]) + assert set([x.id for x in c1.t2_view]) == set([c2b.id]) + +class ViewOnlyUniqueNames(_base.MappedTest): + """'viewonly' mappings with unique PK column names.""" + + @classmethod + def define_tables(cls, metadata): + Table("t1", metadata, + Column('t1id', Integer, primary_key=True), + Column('data', String(40))) + Table("t2", metadata, + Column('t2id', Integer, primary_key=True), + Column('data', String(40)), + Column('t1id_ref', Integer, ForeignKey('t1.t1id'))) + Table("t3", metadata, + Column('t3id', Integer, primary_key=True), + Column('data', String(40)), + Column('t2id_ref', Integer, ForeignKey('t2.t2id'))) + + @testing.resolve_artifact_names + def test_three_table_view(self): + """A three table join with overlapping PK names. + + A third table is pulled into the primary join condition using unique + PK column names and should not produce 'mapper has no columnX' error. 
+ + """ + class C1(_base.Entity): pass + class C2(_base.Entity): pass + class C3(_base.Entity): pass + + mapper(C1, t1, properties={ + 't2s':relation(C2), + 't2_view':relation(C2, + viewonly=True, + primaryjoin=sa.and_(t1.c.t1id==t2.c.t1id_ref, + t3.c.t2id_ref==t2.c.t2id, + t3.c.data==t1.c.data))}) + mapper(C2, t2) + mapper(C3, t3, properties={ + 't2':relation(C2)}) + + c1 = C1() + c1.data = 'c1data' + c2a = C2() + c1.t2s.append(c2a) + c2b = C2() + c1.t2s.append(c2b) + c3 = C3() + c3.data='c1data' + c3.t2 = c2b + sess = create_session() + + sess.add_all((c1, c3)) + sess.flush() + sess.expunge_all() + + c1 = sess.query(C1).get(c1.t1id) + assert set([x.t2id for x in c1.t2s]) == set([c2a.t2id, c2b.t2id]) + assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id]) + +class ViewOnlyLocalRemoteM2M(testing.TestBase): + """test that local-remote is correctly determined for m2m""" + + def test_local_remote(self): + meta = MetaData() + + t1 = Table('t1', meta, + Column('id', Integer, primary_key=True), + ) + t2 = Table('t2', meta, + Column('id', Integer, primary_key=True), + ) + t12 = Table('tab', meta, + Column('t1_id', Integer, ForeignKey('t1.id',)), + Column('t2_id', Integer, ForeignKey('t2.id',)), + ) + + class A(object): pass + class B(object): pass + mapper( B, t2, ) + m = mapper( A, t1, properties=dict( + b_view = relation( B, secondary=t12, viewonly=True), + b_plain= relation( B, secondary=t12), + ) + ) + compile_mappers() + assert m.get_property('b_view').local_remote_pairs == \ + m.get_property('b_plain').local_remote_pairs == \ + [(t1.c.id, t12.c.t1_id), (t12.c.t2_id, t2.c.id)] + + + +class ViewOnlyNonEquijoin(_base.MappedTest): + """'viewonly' mappings based on non-equijoins.""" + + @classmethod + def define_tables(cls, metadata): + Table('foos', metadata, + Column('id', Integer, primary_key=True)) + Table('bars', metadata, + Column('id', Integer, primary_key=True), + Column('fid', Integer)) + + @testing.resolve_artifact_names + def test_viewonly_join(self): + 
class Foo(_base.ComparableEntity): + pass + class Bar(_base.ComparableEntity): + pass + + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id > bars.c.fid, + foreign_keys=[bars.c.fid], + viewonly=True)}) + + mapper(Bar, bars) + + sess = create_session() + sess.add_all((Foo(id=4), + Foo(id=9), + Bar(id=1, fid=2), + Bar(id=2, fid=3), + Bar(id=3, fid=6), + Bar(id=4, fid=7))) + sess.flush() + + sess = create_session() + eq_(sess.query(Foo).filter_by(id=4).one(), + Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)])) + eq_(sess.query(Foo).filter_by(id=9).one(), + Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)])) + + +class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): + """'viewonly' mappings that contain the same 'remote' column twice""" + + @classmethod + def define_tables(cls, metadata): + Table('foos', metadata, + Column('id', Integer, primary_key=True), + Column('bid1', Integer,ForeignKey('bars.id')), + Column('bid2', Integer,ForeignKey('bars.id'))) + + Table('bars', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + + @testing.resolve_artifact_names + def test_relation_on_or(self): + class Foo(_base.ComparableEntity): + pass + class Bar(_base.ComparableEntity): + pass + + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=sa.or_(bars.c.id == foos.c.bid1, + bars.c.id == foos.c.bid2), + uselist=True, + viewonly=True)}) + mapper(Bar, bars) + + sess = create_session() + b1 = Bar(id=1, data='b1') + b2 = Bar(id=2, data='b2') + b3 = Bar(id=3, data='b3') + f1 = Foo(bid1=1, bid2=2) + f2 = Foo(bid1=3, bid2=None) + + sess.add_all((b1, b2, b3)) + sess.flush() + + sess.add_all((f1, f2)) + sess.flush() + + sess.expunge_all() + eq_(sess.query(Foo).filter_by(id=f1.id).one(), + Foo(bars=[Bar(data='b1'), Bar(data='b2')])) + eq_(sess.query(Foo).filter_by(id=f2.id).one(), + Foo(bars=[Bar(data='b3')])) + +class ViewOnlyRepeatedLocalColumn(_base.MappedTest): + """'viewonly' mappings that contain the 
same 'local' column twice""" + + @classmethod + def define_tables(cls, metadata): + Table('foos', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + + Table('bars', metadata, Column('id', Integer, primary_key=True), + Column('fid1', Integer, ForeignKey('foos.id')), + Column('fid2', Integer, ForeignKey('foos.id')), + Column('data', String(50))) + + @testing.resolve_artifact_names + def test_relation_on_or(self): + class Foo(_base.ComparableEntity): + pass + class Bar(_base.ComparableEntity): + pass + + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=sa.or_(bars.c.fid1 == foos.c.id, + bars.c.fid2 == foos.c.id), + viewonly=True)}) + mapper(Bar, bars) + + sess = create_session() + f1 = Foo(id=1, data='f1') + f2 = Foo(id=2, data='f2') + b1 = Bar(fid1=1, data='b1') + b2 = Bar(fid2=1, data='b2') + b3 = Bar(fid1=2, data='b3') + b4 = Bar(fid1=1, fid2=2, data='b4') + + sess.add_all((f1, f2)) + sess.flush() + + sess.add_all((b1, b2, b3, b4)) + sess.flush() + + sess.expunge_all() + eq_(sess.query(Foo).filter_by(id=f1.id).one(), + Foo(bars=[Bar(data='b1'), Bar(data='b2'), Bar(data='b4')])) + eq_(sess.query(Foo).filter_by(id=f2.id).one(), + Foo(bars=[Bar(data='b3'), Bar(data='b4')])) + +class ViewOnlyComplexJoin(_base.MappedTest): + """'viewonly' mappings with a complex join condition.""" + + @classmethod + def define_tables(cls, metadata): + Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t1id', Integer, ForeignKey('t1.id'))) + Table('t3', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + Table('t2tot3', metadata, + Column('t2id', Integer, ForeignKey('t2.id')), + Column('t3id', Integer, ForeignKey('t3.id'))) + + @classmethod + def setup_classes(cls): + class T1(_base.ComparableEntity): + pass + class T2(_base.ComparableEntity): + pass + 
class T3(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_basic(self): + mapper(T1, t1, properties={ + 't3s':relation(T3, primaryjoin=sa.and_( + t1.c.id==t2.c.t1id, + t2.c.id==t2tot3.c.t2id, + t3.c.id==t2tot3.c.t3id), + viewonly=True, + foreign_keys=t3.c.id, remote_side=t2.c.t1id) + }) + mapper(T2, t2, properties={ + 't1':relation(T1), + 't3s':relation(T3, secondary=t2tot3) + }) + mapper(T3, t3) + + sess = create_session() + sess.add(T2(data='t2', t1=T1(data='t1'), t3s=[T3(data='t3')])) + sess.flush() + sess.expunge_all() + + a = sess.query(T1).first() + eq_(a.t3s, [T3(data='t3')]) + + + @testing.resolve_artifact_names + def test_remote_side_escalation(self): + mapper(T1, t1, properties={ + 't3s':relation(T3, + primaryjoin=sa.and_(t1.c.id==t2.c.t1id, + t2.c.id==t2tot3.c.t2id, + t3.c.id==t2tot3.c.t3id + ), + viewonly=True, + foreign_keys=t3.c.id)}) + mapper(T2, t2, properties={ + 't1':relation(T1), + 't3s':relation(T3, secondary=t2tot3)}) + mapper(T3, t3) + assert_raises_message(sa.exc.ArgumentError, + "Specify remote_side argument", + sa.orm.compile_mappers) + + +class ExplicitLocalRemoteTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('t1', metadata, + Column('id', String(50), primary_key=True), + Column('data', String(50))) + Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t1id', String(50))) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class T1(_base.ComparableEntity): + pass + class T2(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_onetomany_funcfk(self): + # use a function within join condition. but specifying + # local_remote_pairs overrides all parsing of the join condition. 
+ mapper(T1, t1, properties={ + 't2s':relation(T2, + primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id])}) + mapper(T2, t2) + + sess = create_session() + a1 = T1(id='number1', data='a1') + a2 = T1(id='number2', data='a2') + b1 = T2(data='b1', t1id='NuMbEr1') + b2 = T2(data='b2', t1id='Number1') + b3 = T2(data='b3', t1id='Number2') + sess.add_all((a1, a2, b1, b2, b3)) + sess.flush() + sess.expunge_all() + + eq_(sess.query(T1).first(), + T1(id='number1', data='a1', t2s=[ + T2(data='b1', t1id='NuMbEr1'), + T2(data='b2', t1id='Number1')])) + + @testing.resolve_artifact_names + def test_manytoone_funcfk(self): + mapper(T1, t1) + mapper(T2, t2, properties={ + 't1':relation(T1, + primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id], + uselist=True)}) + + sess = create_session() + a1 = T1(id='number1', data='a1') + a2 = T1(id='number2', data='a2') + b1 = T2(data='b1', t1id='NuMbEr1') + b2 = T2(data='b2', t1id='Number1') + b3 = T2(data='b3', t1id='Number2') + sess.add_all((a1, a2, b1, b2, b3)) + sess.flush() + sess.expunge_all() + + eq_(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), + [T2(data='b1', t1=[T1(id='number1', data='a1')]), + T2(data='b2', t1=[T1(id='number1', data='a1')])]) + + @testing.resolve_artifact_names + def test_onetomany_func_referent(self): + mapper(T1, t1, properties={ + 't2s':relation(T2, + primaryjoin=sa.func.lower(t1.c.id)==t2.c.t1id, + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id])}) + mapper(T2, t2) + + sess = create_session() + a1 = T1(id='NuMbeR1', data='a1') + a2 = T1(id='NuMbeR2', data='a2') + b1 = T2(data='b1', t1id='number1') + b2 = T2(data='b2', t1id='number1') + b3 = T2(data='b2', t1id='number2') + sess.add_all((a1, a2, b1, b2, b3)) + sess.flush() + sess.expunge_all() + + eq_(sess.query(T1).first(), + T1(id='NuMbeR1', data='a1', t2s=[ + T2(data='b1', t1id='number1'), + 
T2(data='b2', t1id='number1')])) + + @testing.resolve_artifact_names + def test_manytoone_func_referent(self): + mapper(T1, t1) + mapper(T2, t2, properties={ + 't1':relation(T1, + primaryjoin=sa.func.lower(t1.c.id)==t2.c.t1id, + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id], uselist=True)}) + + sess = create_session() + a1 = T1(id='NuMbeR1', data='a1') + a2 = T1(id='NuMbeR2', data='a2') + b1 = T2(data='b1', t1id='number1') + b2 = T2(data='b2', t1id='number1') + b3 = T2(data='b3', t1id='number2') + sess.add_all((a1, a2, b1, b2, b3)) + sess.flush() + sess.expunge_all() + + eq_(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), + [T2(data='b1', t1=[T1(id='NuMbeR1', data='a1')]), + T2(data='b2', t1=[T1(id='NuMbeR1', data='a1')])]) + + @testing.resolve_artifact_names + def test_escalation_1(self): + mapper(T1, t1, properties={ + 't2s':relation(T2, + primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id], + remote_side=[t2.c.t1id])}) + mapper(T2, t2) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_escalation_2(self): + mapper(T1, t1, properties={ + 't2s':relation(T2, + primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), + _local_remote_pairs=[(t1.c.id, t2.c.t1id)])}) + mapper(T2, t2) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) + +class InvalidRemoteSideTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column('t_id', Integer, ForeignKey('t1.id')) + ) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class T1(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_o2m_backref(self): + mapper(T1, t1, properties={ + 't1s':relation(T1, backref='parent') + }) + + assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference 
T1.parent are " + "both of the same direction . Did you " + "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_m2o_backref(self): + mapper(T1, t1, properties={ + 't1s':relation(T1, backref=backref('parent', remote_side=t1.c.id), remote_side=t1.c.id) + }) + + assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " + "both of the same direction . Did you " + "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_o2m_explicit(self): + mapper(T1, t1, properties={ + 't1s':relation(T1, back_populates='parent'), + 'parent':relation(T1, back_populates='t1s'), + }) + + # can't be sure of ordering here + assert_raises_message(sa.exc.ArgumentError, + "both of the same direction . Did you " + "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_m2o_explicit(self): + mapper(T1, t1, properties={ + 't1s':relation(T1, back_populates='parent', remote_side=t1.c.id), + 'parent':relation(T1, back_populates='t1s', remote_side=t1.c.id) + }) + + # can't be sure of ordering here + assert_raises_message(sa.exc.ArgumentError, + "both of the same direction . 
Did you " + "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) + + +class InvalidRelationEscalationTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('foos', metadata, + Column('id', Integer, primary_key=True), + Column('fid', Integer)) + Table('bars', metadata, + Column('id', Integer, primary_key=True), + Column('fid', Integer)) + + @classmethod + def setup_classes(cls): + class Foo(_base.Entity): + pass + class Bar(_base.Entity): + pass + + @testing.resolve_artifact_names + def test_no_join(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine join condition between parent/child " + "tables on relation", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_join_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine join condition between parent/child " + "tables on relation", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_equated(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id>bars.c.fid)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for primaryjoin condition", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_equated_fks(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id>bars.c.fid, + foreign_keys=bars.c.fid)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not locate any equated, locally mapped column pairs " + "for primaryjoin condition", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_ambiguous_fks(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id==bars.c.fid, + 
foreign_keys=[foos.c.id, bars.c.fid])}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Do the columns in 'foreign_keys' represent only the " + "'foreign' columns in this join condition ?", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_ambiguous_remoteside_o2m(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id==bars.c.fid, + foreign_keys=[bars.c.fid], + remote_side=[foos.c.id, bars.c.fid], + viewonly=True + )}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "could not determine any local/remote column pairs", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_ambiguous_remoteside_m2o(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id==bars.c.fid, + foreign_keys=[foos.c.id], + remote_side=[foos.c.id, bars.c.fid], + viewonly=True + )}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "could not determine any local/remote column pairs", + sa.orm.compile_mappers) + + + @testing.resolve_artifact_names + def test_no_equated_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, + primaryjoin=foos.c.id>foos.c.fid)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for primaryjoin condition", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_equated_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, + primaryjoin=foos.c.id>foos.c.fid, + foreign_keys=[foos.c.fid])}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not locate any equated, locally mapped column pairs " + "for primaryjoin condition", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_equated_viewonly(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id>bars.c.fid, + viewonly=True)}) + mapper(Bar, bars) + 
+ assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for primaryjoin condition", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_equated_self_ref_viewonly(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, + primaryjoin=foos.c.id>foos.c.fid, + viewonly=True)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Specify the 'foreign_keys' argument to indicate which columns " + "on the relation are foreign.", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_equated_self_ref_viewonly_fks(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, + primaryjoin=foos.c.id>foos.c.fid, + viewonly=True, + foreign_keys=[foos.c.fid])}) + + sa.orm.compile_mappers() + eq_(Foo.foos.property.local_remote_pairs, [(foos.c.id, foos.c.fid)]) + + @testing.resolve_artifact_names + def test_equated(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + primaryjoin=foos.c.id==bars.c.fid)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for primaryjoin condition", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_equated_self_ref(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, + primaryjoin=foos.c.id==foos.c.fid)}) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for primaryjoin condition", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_equated_self_ref_wrong_fks(self): + mapper(Foo, foos, properties={ + 'foos':relation(Foo, + primaryjoin=foos.c.id==foos.c.fid, + foreign_keys=[bars.c.id])}) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for primaryjoin condition", + sa.orm.compile_mappers) + + +class InvalidRelationEscalationTestM2M(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('foos', metadata, + 
Column('id', Integer, primary_key=True)) + Table('foobars', metadata, + Column('fid', Integer), Column('bid', Integer)) + Table('bars', metadata, + Column('id', Integer, primary_key=True)) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class Foo(_base.Entity): + pass + class Bar(_base.Entity): + pass + + @testing.resolve_artifact_names + def test_no_join(self): + mapper(Foo, foos, properties={ + 'bars': relation(Bar, secondary=foobars)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine join condition between parent/child tables " + "on relation", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_secondaryjoin(self): + mapper(Foo, foos, properties={ + 'bars': relation(Bar, + secondary=foobars, + primaryjoin=foos.c.id > foobars.c.fid)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine join condition between parent/child tables " + "on relation", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_bad_primaryjoin(self): + mapper(Foo, foos, properties={ + 'bars': relation(Bar, + secondary=foobars, + primaryjoin=foos.c.id > foobars.c.fid, + secondaryjoin=foobars.c.bid<=bars.c.id)}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for primaryjoin condition", + sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_bad_secondaryjoin(self): + mapper(Foo, foos, properties={ + 'bars':relation(Bar, + secondary=foobars, + primaryjoin=foos.c.id == foobars.c.fid, + secondaryjoin=foobars.c.bid <= bars.c.id, + foreign_keys=[foobars.c.fid])}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not determine relation direction for secondaryjoin " + "condition", sa.orm.compile_mappers) + + @testing.resolve_artifact_names + def test_no_equated_secondaryjoin(self): + mapper(Foo, foos, properties={ + 
'bars':relation(Bar, + secondary=foobars, + primaryjoin=foos.c.id == foobars.c.fid, + secondaryjoin=foobars.c.bid <= bars.c.id, + foreign_keys=[foobars.c.fid, foobars.c.bid])}) + mapper(Bar, bars) + + assert_raises_message( + sa.exc.ArgumentError, + "Could not locate any equated, locally mapped column pairs for " + "secondaryjoin condition", sa.orm.compile_mappers) + + diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py new file mode 100644 index 000000000..2117e8dcc --- /dev/null +++ b/test/orm/test_scoping.py @@ -0,0 +1,249 @@ +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy.orm import scoped_session +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, query +from sqlalchemy.test.testing import eq_ +from test.orm import _base + + +class _ScopedTest(_base.MappedTest): + """Adds another lookup bucket to emulate Session globals.""" + + run_setup_mappers = 'once' + + _artifact_registries = ( + _base.MappedTest._artifact_registries + ('scoping',)) + + @classmethod + def setup_class(cls): + cls.scoping = _base.adict() + super(_ScopedTest, cls).setup_class() + + @classmethod + def teardown_class(cls): + cls.scoping.clear() + super(_ScopedTest, cls).teardown_class() + + +class ScopedSessionTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('table1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + Table('table2', metadata, + Column('id', Integer, primary_key=True), + Column('someid', None, ForeignKey('table1.id'))) + + @testing.resolve_artifact_names + def test_basic(self): + Session = scoped_session(sa.orm.sessionmaker()) + + class CustomQuery(query.Query): + pass + + class SomeObject(_base.ComparableEntity): + query = Session.query_property() + class 
SomeOtherObject(_base.ComparableEntity): + query = Session.query_property() + custom_query = Session.query_property(query_cls=CustomQuery) + + mapper(SomeObject, table1, properties={ + 'options':relation(SomeOtherObject)}) + mapper(SomeOtherObject, table2) + + s = SomeObject(id=1, data="hello") + sso = SomeOtherObject() + s.options.append(sso) + Session.add(s) + Session.commit() + Session.refresh(sso) + Session.remove() + + eq_(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), + Session.query(SomeObject).one()) + eq_(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), + SomeObject.query.one()) + eq_(SomeOtherObject(someid=1), + SomeOtherObject.query.filter( + SomeOtherObject.someid == sso.someid).one()) + assert isinstance(SomeOtherObject.query, query.Query) + assert not isinstance(SomeOtherObject.query, CustomQuery) + assert isinstance(SomeOtherObject.custom_query, query.Query) + + +class ScopedMapperTest(_ScopedTest): + + @classmethod + def define_tables(cls, metadata): + Table('table1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + Table('table2', metadata, + Column('id', Integer, primary_key=True), + Column('someid', None, ForeignKey('table1.id'))) + + @classmethod + def setup_classes(cls): + class SomeObject(_base.ComparableEntity): + pass + class SomeOtherObject(_base.ComparableEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + Session = scoped_session(sa.orm.create_session) + Session.mapper(SomeObject, table1, properties={ + 'options':relation(SomeOtherObject) + }) + Session.mapper(SomeOtherObject, table2) + + cls.scoping['Session'] = Session + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + s = SomeObject() + s.id = 1 + s.data = 'hello' + sso = SomeOtherObject() + s.options.append(sso) + Session.flush() + Session.expunge_all() + + @testing.resolve_artifact_names + def test_query(self): + sso = 
SomeOtherObject.query().first() + assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id + + @testing.resolve_artifact_names + def test_query_compiles(self): + class Foo(object): + pass + Session.mapper(Foo, table2) + assert hasattr(Foo, 'query') + + ext = sa.orm.MapperExtension() + + class Bar(object): + pass + Session.mapper(Bar, table2, extension=[ext]) + assert hasattr(Bar, 'query') + + class Baz(object): + pass + Session.mapper(Baz, table2, extension=ext) + assert hasattr(Baz, 'query') + + @testing.resolve_artifact_names + def test_default_constructor_state_not_shared(self): + scope = scoped_session(sa.orm.sessionmaker()) + + class A(object): + pass + class B(object): + def __init__(self): + pass + + scope.mapper(A, table1) + scope.mapper(B, table2) + + A(foo='bar') + assert_raises(TypeError, B, foo='bar') + + scope = scoped_session(sa.orm.sessionmaker()) + + class C(object): + def __init__(self): + pass + class D(object): + pass + + scope.mapper(C, table1) + scope.mapper(D, table2) + + assert_raises(TypeError, C, foo='bar') + D(foo='bar') + + @testing.resolve_artifact_names + def test_validating_constructor(self): + s2 = SomeObject(someid=12) + s3 = SomeOtherObject(someid=123, bogus=345) + + class ValidatedOtherObject(object): pass + Session.mapper(ValidatedOtherObject, table2, validate=True) + + v1 = ValidatedOtherObject(someid=12) + assert_raises(sa.exc.ArgumentError, ValidatedOtherObject, + someid=12, bogus=345) + + @testing.resolve_artifact_names + def test_dont_clobber_methods(self): + class MyClass(object): + def expunge(self): + return "an expunge !" + + Session.mapper(MyClass, table2) + + assert MyClass().expunge() == "an expunge !" 
+ + +class ScopedMapperTest2(_ScopedTest): + + @classmethod + def define_tables(cls, metadata): + Table('table1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('type', String(30))) + Table('table2', metadata, + Column('id', Integer, primary_key=True), + Column('someid', None, ForeignKey('table1.id')), + Column('somedata', String(30))) + + @classmethod + def setup_classes(cls): + class BaseClass(_base.ComparableEntity): + pass + class SubClass(BaseClass): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + Session = scoped_session(sa.orm.sessionmaker()) + + Session.mapper(BaseClass, table1, + polymorphic_identity='base', + polymorphic_on=table1.c.type) + Session.mapper(SubClass, table2, + polymorphic_identity='sub', + inherits=BaseClass) + + cls.scoping['Session'] = Session + + @testing.resolve_artifact_names + def test_inheritance(self): + def expunge_list(l): + for x in l: + Session.expunge(x) + return l + + b = BaseClass(data='b1') + s = SubClass(data='s1', somedata='somedata') + Session.commit() + Session.expunge_all() + + eq_(expunge_list([BaseClass(data='b1'), + SubClass(data='s1', somedata='somedata')]), + BaseClass.query.all()) + eq_(expunge_list([SubClass(data='s1', somedata='somedata')]), + SubClass.query.all()) + + diff --git a/test/orm/test_selectable.py b/test/orm/test_selectable.py new file mode 100644 index 000000000..0a2025360 --- /dev/null +++ b/test/orm/test_selectable.py @@ -0,0 +1,55 @@ +"""Generic mapping to Select statements""" +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import String, Integer, select +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base + + +# TODO: more tests mapping to selects + +class 
SelectableNoFromsTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('common', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + Column('extra', String(45))) + + @classmethod + def setup_classes(cls): + class Subset(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_no_tables(self): + + selectable = select(["x", "y", "z"]) + assert_raises_message(sa.exc.InvalidRequestError, + "Could not find any Table objects", + mapper, Subset, selectable) + + @testing.emits_warning('.*creating an Alias.*') + @testing.resolve_artifact_names + def test_basic(self): + subset_select = select([common.c.id, common.c.data]) + subset_mapper = mapper(Subset, subset_select) + + sess = create_session(bind=testing.db) + sess.add(Subset(data=1)) + sess.flush() + sess.expunge_all() + + eq_(sess.query(Subset).all(), [Subset(data=1)]) + eq_(sess.query(Subset).filter(Subset.data==1).one(), Subset(data=1)) + eq_(sess.query(Subset).filter(Subset.data!=1).first(), None) + + subset_select = sa.orm.class_mapper(Subset).mapped_table + eq_(sess.query(Subset).filter(subset_select.c.data==1).one(), + Subset(data=1)) + + diff --git a/test/orm/test_session.py b/test/orm/test_session.py new file mode 100644 index 000000000..3020d66e9 --- /dev/null +++ b/test/orm/test_session.py @@ -0,0 +1,1434 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +import gc +import inspect +import pickle +from sqlalchemy.orm import create_session, sessionmaker, attributes +import sqlalchemy as sa +from sqlalchemy.test import engines, testing, config +from sqlalchemy import Integer, String, Sequence +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, eagerload +from sqlalchemy.test.testing import eq_ +from test.engine import _base as engine_base +from test.orm import _base, _fixtures + + +class 
SessionTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_no_close_on_flush(self): + """Flush() doesn't close a connection the session didn't open""" + c = testing.db.connect() + c.execute("select * from users") + + mapper(User, users) + s = create_session(bind=c) + s.add(User(name='first')) + s.flush() + c.execute("select * from users") + + @testing.resolve_artifact_names + def test_close(self): + """close() doesn't close a connection the session didn't open""" + c = testing.db.connect() + c.execute("select * from users") + + mapper(User, users) + s = create_session(bind=c) + s.add(User(name='first')) + s.flush() + c.execute("select * from users") + s.close() + c.execute("select * from users") + + @testing.resolve_artifact_names + def test_no_close_transaction_on_flulsh(self): + c = testing.db.connect() + try: + mapper(User, users) + s = create_session(bind=c) + s.begin() + tran = s.transaction + s.add(User(name='first')) + s.flush() + c.execute("select * from users") + u = User(name='two') + s.add(u) + s.flush() + u = User(name='third') + s.add(u) + s.flush() + assert s.transaction is tran + tran.close() + finally: + c.close() + + @testing.requires.sequences + def test_sequence_execute(self): + seq = Sequence("some_sequence") + seq.create(testing.db) + try: + sess = create_session(bind=testing.db) + eq_(sess.execute(seq), 1) + finally: + seq.drop(testing.db) + + + @testing.resolve_artifact_names + def test_expunge_cascade(self): + mapper(Address, addresses) + mapper(User, users, properties={ + 'addresses':relation(Address, + backref=backref("user", cascade="all"), + cascade="all")}) + + _fixtures.run_inserts_for(users) + _fixtures.run_inserts_for(addresses) + + session = create_session() + u = session.query(User).filter_by(id=7).one() + + # get everything to load in both directions + print [a.user for a in u.addresses] + + # then see if expunge fails + session.expunge(u) + + assert sa.orm.object_session(u) is None + 
assert sa.orm.attributes.instance_state(u).session_id is None + for a in u.addresses: + assert sa.orm.object_session(a) is None + assert sa.orm.attributes.instance_state(a).session_id is None + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_table_binds_from_expression(self): + """Session can extract Table objects from ClauseElements and match them to tables.""" + + mapper(Address, addresses) + mapper(User, users, properties={ + 'addresses':relation(Address, + backref=backref("user", cascade="all"), + cascade="all")}) + + Session = sessionmaker(binds={users: self.metadata.bind, + addresses: self.metadata.bind}) + sess = Session() + + sess.execute(users.insert(), params=dict(id=1, name='ed')) + eq_(sess.execute(users.select(users.c.id == 1)).fetchall(), + [(1, 'ed')]) + + eq_(sess.execute(users.select(User.id == 1)).fetchall(), + [(1, 'ed')]) + + sess.close() + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_mapped_binds_from_expression(self): + """Session can extract Table objects from ClauseElements and match them to tables.""" + + mapper(Address, addresses) + mapper(User, users, properties={ + 'addresses':relation(Address, + backref=backref("user", cascade="all"), + cascade="all")}) + + Session = sessionmaker(binds={User: self.metadata.bind, + Address: self.metadata.bind}) + sess = Session() + + sess.execute(users.insert(), params=dict(id=1, name='ed')) + eq_(sess.execute(users.select(users.c.id == 1)).fetchall(), + [(1, 'ed')]) + + eq_(sess.execute(users.select(User.id == 1)).fetchall(), + [(1, 'ed')]) + + sess.close() + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_bind_from_metadata(self): + mapper(User, users) + + session = create_session() + session.execute(users.insert(), dict(name='Johnny')) + + assert len(session.query(User).filter_by(name='Johnny').all()) == 1 + + session.execute(users.delete()) + + assert len(session.query(User).filter_by(name='Johnny').all()) 
== 0 + session.close() + + @testing.requires.independent_connections + @engines.close_open_connections + @testing.resolve_artifact_names + def test_transaction(self): + mapper(User, users) + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + sess = create_session(autocommit=False, bind=conn1) + u = User(name='x') + sess.add(u) + sess.flush() + assert conn1.execute("select count(1) from users").scalar() == 1 + assert conn2.execute("select count(1) from users").scalar() == 0 + sess.commit() + assert conn1.execute("select count(1) from users").scalar() == 1 + assert testing.db.connect().execute("select count(1) from users").scalar() == 1 + sess.close() + + @testing.requires.independent_connections + @engines.close_open_connections + @testing.resolve_artifact_names + def test_autoflush(self): + bind = self.metadata.bind + mapper(User, users) + conn1 = bind.connect() + conn2 = bind.connect() + + sess = create_session(bind=conn1, autocommit=False, autoflush=True) + u = User() + u.name='ed' + sess.add(u) + u2 = sess.query(User).filter_by(name='ed').one() + assert u2 is u + eq_(conn1.execute("select count(1) from users").scalar(), 1) + eq_(conn2.execute("select count(1) from users").scalar(), 0) + sess.commit() + eq_(conn1.execute("select count(1) from users").scalar(), 1) + eq_(bind.connect().execute("select count(1) from users").scalar(), 1) + sess.close() + + @testing.resolve_artifact_names + def test_autoflush_expressions(self): + """test that an expression which is dependent on object state is + evaluated after the session autoflushes. This is the lambda + inside of strategies.py lazy_clause. 
+ + """ + mapper(User, users, properties={ + 'addresses':relation(Address, backref="user")}) + mapper(Address, addresses) + + sess = create_session(autoflush=True, autocommit=False) + u = User(name='ed', addresses=[Address(email_address='foo')]) + sess.add(u) + eq_(sess.query(Address).filter(Address.user==u).one(), + Address(email_address='foo')) + + # still works after "u" is garbage collected + sess.commit() + sess.close() + u = sess.query(User).get(u.id) + q = sess.query(Address).filter(Address.user==u) + del u + gc.collect() + eq_(q.one(), Address(email_address='foo')) + + + @testing.requires.independent_connections + @engines.close_open_connections + @testing.resolve_artifact_names + def test_autoflush_unbound(self): + mapper(User, users) + + try: + sess = create_session(autocommit=False, autoflush=True) + u = User() + u.name='ed' + sess.add(u) + u2 = sess.query(User).filter_by(name='ed').one() + assert u2 is u + assert sess.execute("select count(1) from users", mapper=User).scalar() == 1 + assert testing.db.connect().execute("select count(1) from users").scalar() == 0 + sess.commit() + assert sess.execute("select count(1) from users", mapper=User).scalar() == 1 + assert testing.db.connect().execute("select count(1) from users").scalar() == 1 + sess.close() + except: + sess.rollback() + raise + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_autoflush_2(self): + mapper(User, users) + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + sess = create_session(bind=conn1, autocommit=False, autoflush=True) + u = User() + u.name='ed' + sess.add(u) + sess.commit() + assert conn1.execute("select count(1) from users").scalar() == 1 + assert testing.db.connect().execute("select count(1) from users").scalar() == 1 + sess.commit() + + @testing.resolve_artifact_names + def test_autoflush_rollback(self): + mapper(Address, addresses) + mapper(User, users, properties={ + 'addresses':relation(Address)}) + + 
_fixtures.run_inserts_for(users) + _fixtures.run_inserts_for(addresses) + + sess = create_session(autocommit=False, autoflush=True) + u = sess.query(User).get(8) + newad = Address(email_address='a new address') + u.addresses.append(newad) + u.name = 'some new name' + assert u.name == 'some new name' + assert len(u.addresses) == 4 + assert newad in u.addresses + sess.rollback() + assert u.name == 'ed' + assert len(u.addresses) == 3 + + assert newad not in u.addresses + # pending objects dont get expired + assert newad.email_address == 'a new address' + + @testing.resolve_artifact_names + def test_autocommit_doesnt_raise_on_pending(self): + mapper(User, users) + session = create_session(autocommit=True) + + session.add(User(name='ed')) + + session.begin() + session.flush() + session.commit() + + def test_active_flag(self): + sess = create_session(bind=config.db, autocommit=True) + assert not sess.is_active + sess.begin() + assert sess.is_active + sess.rollback() + assert not sess.is_active + + @testing.resolve_artifact_names + def test_textual_execute(self): + """test that Session.execute() converts to text()""" + + sess = create_session(bind=self.metadata.bind) + users.insert().execute(id=7, name='jack') + + # use :bindparam style + eq_(sess.execute("select * from users where id=:id", + {'id':7}).fetchall(), + [(7, u'jack')]) + + + # use :bindparam style + eq_(sess.scalar("select id from users where id=:id", + {'id':7}), + 7) + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_subtransaction_on_external(self): + mapper(User, users) + conn = testing.db.connect() + trans = conn.begin() + sess = create_session(bind=conn, autocommit=False, autoflush=True) + sess.begin(subtransactions=True) + u = User(name='ed') + sess.add(u) + sess.flush() + sess.commit() # commit does nothing + trans.rollback() # rolls back + assert len(sess.query(User).all()) == 0 + sess.close() + + @testing.requires.savepoints + @engines.close_open_connections + 
@testing.resolve_artifact_names + def test_external_nested_transaction(self): + mapper(User, users) + try: + conn = testing.db.connect() + trans = conn.begin() + sess = create_session(bind=conn, autocommit=False, autoflush=True) + u1 = User(name='u1') + sess.add(u1) + sess.flush() + + sess.begin_nested() + u2 = User(name='u2') + sess.add(u2) + sess.flush() + sess.rollback() + + trans.commit() + assert len(sess.query(User).all()) == 1 + except: + conn.close() + raise + + @testing.requires.savepoints + @testing.resolve_artifact_names + def test_heavy_nesting(self): + session = create_session(bind=testing.db) + + session.begin() + session.connection().execute("insert into users (name) values ('user1')") + + session.begin(subtransactions=True) + + session.begin_nested() + + session.connection().execute("insert into users (name) values ('user2')") + assert session.connection().execute("select count(1) from users").scalar() == 2 + + session.rollback() + assert session.connection().execute("select count(1) from users").scalar() == 1 + session.connection().execute("insert into users (name) values ('user3')") + + session.commit() + assert session.connection().execute("select count(1) from users").scalar() == 2 + + @testing.fails_on('sqlite', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_transactions_isolated(self): + mapper(User, users) + users.delete().execute() + + s1 = create_session(bind=testing.db, autocommit=False) + s2 = create_session(bind=testing.db, autocommit=False) + u1 = User(name='u1') + s1.add(u1) + s1.flush() + + assert s2.query(User).all() == [] + + @testing.requires.two_phase_transactions + @testing.resolve_artifact_names + def test_twophase(self): + # TODO: mock up a failure condition here + # to ensure a rollback succeeds + mapper(User, users) + mapper(Address, addresses) + + engine2 = engines.testing_engine() + sess = create_session(autocommit=True, autoflush=False, twophase=True) + sess.bind_mapper(User, testing.db) + 
sess.bind_mapper(Address, engine2) + sess.begin() + u1 = User(name='u1') + a1 = Address(email_address='u1@e') + sess.add_all((u1, a1)) + sess.commit() + sess.close() + engine2.dispose() + assert users.count().scalar() == 1 + assert addresses.count().scalar() == 1 + + @testing.resolve_artifact_names + def test_subtransaction_on_noautocommit(self): + mapper(User, users) + sess = create_session(autocommit=False, autoflush=True) + sess.begin(subtransactions=True) + u = User(name='u1') + sess.add(u) + sess.flush() + sess.commit() # commit does nothing + sess.rollback() # rolls back + assert len(sess.query(User).all()) == 0 + sess.close() + + @testing.requires.savepoints + @testing.resolve_artifact_names + def test_nested_transaction(self): + mapper(User, users) + sess = create_session() + sess.begin() + + u = User(name='u1') + sess.add(u) + sess.flush() + + sess.begin_nested() # nested transaction + + u2 = User(name='u2') + sess.add(u2) + sess.flush() + + sess.rollback() + + sess.commit() + assert len(sess.query(User).all()) == 1 + sess.close() + + @testing.requires.savepoints + @testing.resolve_artifact_names + def test_nested_autotrans(self): + mapper(User, users) + sess = create_session(autocommit=False) + u = User(name='u1') + sess.add(u) + sess.flush() + + sess.begin_nested() # nested transaction + + u2 = User(name='u2') + sess.add(u2) + sess.flush() + + sess.rollback() + + sess.commit() + assert len(sess.query(User).all()) == 1 + sess.close() + + @testing.requires.savepoints + @testing.resolve_artifact_names + def test_nested_transaction_connection_add(self): + mapper(User, users) + + sess = create_session(autocommit=True) + + sess.begin() + sess.begin_nested() + + u1 = User(name='u1') + sess.add(u1) + sess.flush() + + sess.rollback() + + u2 = User(name='u2') + sess.add(u2) + + sess.commit() + + eq_(set(sess.query(User).all()), set([u2])) + + sess.begin() + sess.begin_nested() + + u3 = User(name='u3') + sess.add(u3) + sess.commit() # commit the nested transaction 
+ sess.rollback() + + eq_(set(sess.query(User).all()), set([u2])) + + sess.close() + + @testing.requires.savepoints + @testing.resolve_artifact_names + def test_mixed_transaction_control(self): + mapper(User, users) + + sess = create_session(autocommit=True) + + sess.begin() + sess.begin_nested() + transaction = sess.begin(subtransactions=True) + + sess.add(User(name='u1')) + + transaction.commit() + sess.commit() + sess.commit() + + sess.close() + + eq_(len(sess.query(User).all()), 1) + + t1 = sess.begin() + t2 = sess.begin_nested() + + sess.add(User(name='u2')) + + t2.commit() + assert sess.transaction is t1 + + sess.close() + + @testing.requires.savepoints + @testing.resolve_artifact_names + def test_mixed_transaction_close(self): + mapper(User, users) + + sess = create_session(autocommit=False) + + sess.begin_nested() + + sess.add(User(name='u1')) + sess.flush() + + sess.close() + + sess.add(User(name='u2')) + sess.commit() + + sess.close() + + eq_(len(sess.query(User).all()), 1) + + @testing.resolve_artifact_names + def test_error_on_using_inactive_session(self): + mapper(User, users) + + sess = create_session(autocommit=True) + + sess.begin() + sess.begin(subtransactions=True) + + sess.add(User(name='u1')) + sess.flush() + + sess.rollback() + assert_raises_message(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True) + sess.close() + + @testing.resolve_artifact_names + def test_no_autocommit_with_explicit_commit(self): + mapper(User, users) + session = create_session(autocommit=False) + + session.add(User(name='ed')) + session.transaction.commit() + assert session.transaction is not None, "autocommit=False should start a new transaction" + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_bound_connection(self): + mapper(User, users) + c = testing.db.connect() + sess = create_session(bind=c) + sess.begin() + transaction = sess.transaction + u = User(name='u1') + sess.add(u) 
+ sess.flush() + assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c + + assert_raises_message(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect()) + + transaction.rollback() + assert len(sess.query(User).all()) == 0 + sess.close() + + @testing.resolve_artifact_names + def test_bound_connection_transactional(self): + mapper(User, users) + c = testing.db.connect() + + sess = create_session(bind=c, autocommit=False) + u = User(name='u1') + sess.add(u) + sess.flush() + sess.close() + assert not c.in_transaction() + assert c.scalar("select count(1) from users") == 0 + + sess = create_session(bind=c, autocommit=False) + u = User(name='u2') + sess.add(u) + sess.flush() + sess.commit() + assert not c.in_transaction() + assert c.scalar("select count(1) from users") == 1 + c.execute("delete from users") + assert c.scalar("select count(1) from users") == 0 + + c = testing.db.connect() + + trans = c.begin() + sess = create_session(bind=c, autocommit=True) + u = User(name='u3') + sess.add(u) + sess.flush() + assert c.in_transaction() + trans.commit() + assert not c.in_transaction() + assert c.scalar("select count(1) from users") == 1 + + + @testing.uses_deprecated() + @engines.close_open_connections + @testing.resolve_artifact_names + def test_save_update_delete(self): + + s = create_session() + mapper(User, users, properties={ + 'addresses':relation(Address, cascade="all, delete") + }) + mapper(Address, addresses) + + user = User(name='u1') + + assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.update, user) + assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.delete, user) + + s.add(user) + s.flush() + user = s.query(User).one() + s.expunge(user) + assert user not in s + + # modify outside of session, assert changes remain/get saved + user.name = "fred" + s.add(user) + assert user in s + assert user in s.dirty + 
s.flush() + s.expunge_all() + assert s.query(User).count() == 1 + user = s.query(User).one() + assert user.name == 'fred' + + # ensure its not dirty if no changes occur + s.expunge_all() + assert user not in s + s.add(user) + assert user in s + assert user not in s.dirty + + assert_raises_message(sa.exc.InvalidRequestError, "is already persistent", s.save, user) + + s2 = create_session() + assert_raises_message(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user) + + u2 = s2.query(User).get(user.id) + assert_raises_message(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2) + + s.expire(user) + s.expunge(user) + assert user not in s + s.delete(user) + assert user in s + + s.flush() + assert user not in s + assert s.query(User).count() == 0 + + @testing.resolve_artifact_names + def test_is_modified(self): + s = create_session() + + mapper(User, users, properties={'addresses':relation(Address)}) + mapper(Address, addresses) + + # save user + u = User(name='fred') + s.add(u) + s.flush() + s.expunge_all() + + user = s.query(User).one() + assert user not in s.dirty + assert not s.is_modified(user) + user.name = 'fred' + assert user in s.dirty + assert not s.is_modified(user) + user.name = 'ed' + assert user in s.dirty + assert s.is_modified(user) + s.flush() + assert user not in s.dirty + assert not s.is_modified(user) + + a = Address() + user.addresses.append(a) + assert user in s.dirty + assert s.is_modified(user) + assert not s.is_modified(user, include_collections=False) + + + @testing.resolve_artifact_names + def test_weak_ref(self): + """test the weak-referencing identity map, which strongly-references modified items.""" + + s = create_session() + mapper(User, users) + + s.add(User(name='ed')) + s.flush() + assert not s.dirty + + user = s.query(User).one() + del user + gc.collect() + assert len(s.identity_map) == 0 + + user = s.query(User).one() + user.name = 'fred' + del user + gc.collect() + assert len(s.identity_map) 
== 1 + assert len(s.dirty) == 1 + assert None not in s.dirty + s.flush() + gc.collect() + assert not s.dirty + assert not s.identity_map + + user = s.query(User).one() + assert user.name == 'fred' + assert s.identity_map + + @testing.resolve_artifact_names + def test_weakref_with_cycles_o2m(self): + s = sessionmaker()() + mapper(User, users, properties={ + "addresses":relation(Address, backref="user") + }) + mapper(Address, addresses) + s.add(User(name="ed", addresses=[Address(email_address="ed1")])) + s.commit() + + user = s.query(User).options(eagerload(User.addresses)).one() + user.addresses[0].user # lazyload + eq_(user, User(name="ed", addresses=[Address(email_address="ed1")])) + + del user + gc.collect() + assert len(s.identity_map) == 0 + + user = s.query(User).options(eagerload(User.addresses)).one() + user.addresses[0].email_address='ed2' + user.addresses[0].user # lazyload + del user + gc.collect() + assert len(s.identity_map) == 2 + + s.commit() + user = s.query(User).options(eagerload(User.addresses)).one() + eq_(user, User(name="ed", addresses=[Address(email_address="ed2")])) + + @testing.resolve_artifact_names + def test_weakref_with_cycles_o2o(self): + s = sessionmaker()() + mapper(User, users, properties={ + "address":relation(Address, backref="user", uselist=False) + }) + mapper(Address, addresses) + s.add(User(name="ed", address=Address(email_address="ed1"))) + s.commit() + + user = s.query(User).options(eagerload(User.address)).one() + user.address.user + eq_(user, User(name="ed", address=Address(email_address="ed1"))) + + del user + gc.collect() + assert len(s.identity_map) == 0 + + user = s.query(User).options(eagerload(User.address)).one() + user.address.email_address='ed2' + user.address.user # lazyload + + del user + gc.collect() + assert len(s.identity_map) == 2 + + s.commit() + user = s.query(User).options(eagerload(User.address)).one() + eq_(user, User(name="ed", address=Address(email_address="ed2"))) + + @testing.resolve_artifact_names + 
def test_strong_ref(self): + s = create_session(weak_identity_map=False) + mapper(User, users) + + # save user + s.add(User(name='u1')) + s.flush() + user = s.query(User).one() + user = None + print s.identity_map + import gc + gc.collect() + assert len(s.identity_map) == 1 + + user = s.query(User).one() + assert not s.identity_map._modified + user.name = 'u2' + assert s.identity_map._modified + s.flush() + eq_(users.select().execute().fetchall(), [(user.id, 'u2')]) + + + @testing.resolve_artifact_names + def test_prune(self): + s = create_session(weak_identity_map=False) + mapper(User, users) + + for o in [User(name='u%s' % x) for x in xrange(10)]: + s.add(o) + # o is still live after this loop... + + self.assert_(len(s.identity_map) == 0) + self.assert_(s.prune() == 0) + s.flush() + import gc + gc.collect() + self.assert_(s.prune() == 9) + self.assert_(len(s.identity_map) == 1) + + id = o.id + del o + self.assert_(s.prune() == 1) + self.assert_(len(s.identity_map) == 0) + + u = s.query(User).get(id) + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 1) + u.name = 'squiznart' + del u + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 1) + s.flush() + self.assert_(s.prune() == 1) + self.assert_(len(s.identity_map) == 0) + + s.add(User(name='x')) + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 0) + s.flush() + self.assert_(len(s.identity_map) == 1) + self.assert_(s.prune() == 1) + self.assert_(len(s.identity_map) == 0) + + u = s.query(User).get(id) + s.delete(u) + del u + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 1) + s.flush() + self.assert_(s.prune() == 0) + self.assert_(len(s.identity_map) == 0) + + @testing.resolve_artifact_names + def test_no_save_cascade_1(self): + mapper(Address, addresses) + mapper(User, users, properties=dict( + addresses=relation(Address, cascade="none", backref="user"))) + s = create_session() + + u = User(name='u1') + s.add(u) + a = 
Address(email_address='u1@e') + u.addresses.append(a) + assert u in s + assert a not in s + s.flush() + print "\n".join([repr(x.__dict__) for x in s]) + s.expunge_all() + assert s.query(User).one().id == u.id + assert s.query(Address).first() is None + + @testing.resolve_artifact_names + def test_no_save_cascade_2(self): + mapper(Address, addresses) + mapper(User, users, properties=dict( + addresses=relation(Address, + cascade="all", + backref=backref("user", cascade="none")))) + + s = create_session() + u = User(name='u1') + a = Address(email_address='u1@e') + a.user = u + s.add(a) + assert u not in s + assert a in s + s.flush() + s.expunge_all() + assert s.query(Address).one().id == a.id + assert s.query(User).first() is None + + @testing.resolve_artifact_names + def test_extension(self): + mapper(User, users) + log = [] + class MyExt(sa.orm.session.SessionExtension): + def before_commit(self, session): + log.append('before_commit') + def after_commit(self, session): + log.append('after_commit') + def after_rollback(self, session): + log.append('after_rollback') + def before_flush(self, session, flush_context, objects): + log.append('before_flush') + def after_flush(self, session, flush_context): + log.append('after_flush') + def after_flush_postexec(self, session, flush_context): + log.append('after_flush_postexec') + def after_begin(self, session, transaction, connection): + log.append('after_begin') + def after_attach(self, session, instance): + log.append('after_attach') + def after_bulk_update(self, session, query, query_context, result): + log.append('after_bulk_update') + def after_bulk_delete(self, session, query, query_context, result): + log.append('after_bulk_delete') + + sess = create_session(extension = MyExt()) + u = User(name='u1') + sess.add(u) + sess.flush() + assert log == ['after_attach', 'before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec'] + + log = [] + sess = create_session(autocommit=False, 
extension=MyExt()) + u = User(name='u1') + sess.add(u) + sess.flush() + assert log == ['after_attach', 'before_flush', 'after_begin', 'after_flush', 'after_flush_postexec'] + + log = [] + u.name = 'ed' + sess.commit() + assert log == ['before_commit', 'before_flush', 'after_flush', 'after_flush_postexec', 'after_commit'] + + log = [] + sess.commit() + assert log == ['before_commit', 'after_commit'] + + log = [] + sess.query(User).delete() + assert log == ['after_begin', 'after_bulk_delete'] + + log = [] + sess.query(User).update({'name': 'foo'}) + assert log == ['after_bulk_update'] + + log = [] + sess = create_session(autocommit=False, extension=MyExt(), bind=testing.db) + conn = sess.connection() + assert log == ['after_begin'] + + @testing.resolve_artifact_names + def test_before_flush(self): + """test that the flush plan can be affected during before_flush()""" + + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(self, session, flush_context, objects): + for obj in list(session.new) + list(session.dirty): + if isinstance(obj, User): + session.add(User(name='another %s' % obj.name)) + for obj in list(session.deleted): + if isinstance(obj, User): + x = session.query(User).filter(User.name=='another %s' % obj.name).one() + session.delete(x) + + sess = create_session(extension = MyExt(), autoflush=True) + u = User(name='u1') + sess.add(u) + sess.flush() + eq_(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='u1') + ] + ) + + sess.flush() + eq_(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='u1') + ] + ) + + u.name='u2' + sess.flush() + eq_(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='another u2'), + User(name='u2') + ] + ) + + sess.delete(u) + sess.flush() + eq_(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + ] + ) + + @testing.resolve_artifact_names + def 
test_before_flush_affects_dirty(self): + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(self, session, flush_context, objects): + for obj in list(session.identity_map.values()): + obj.name += " modified" + + sess = create_session(extension = MyExt(), autoflush=True) + u = User(name='u1') + sess.add(u) + sess.flush() + eq_(sess.query(User).order_by(User.name).all(), + [ + User(name='u1') + ] + ) + + sess.add(User(name='u2')) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).order_by(User.name).all(), + [ + User(name='u1 modified'), + User(name='u2') + ] + ) + + @testing.resolve_artifact_names + def test_reentrant_flush(self): + + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(s, session, flush_context, objects): + session.flush() + + sess = create_session(extension=MyExt()) + sess.add(User(name='foo')) + assert_raises_message(sa.exc.InvalidRequestError, "already flushing", sess.flush) + + @testing.resolve_artifact_names + def test_pickled_update(self): + mapper(User, users) + sess1 = create_session() + sess2 = create_session() + + u1 = User(name='u1') + sess1.add(u1) + + assert_raises_message(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1) + + u2 = pickle.loads(pickle.dumps(u1)) + + sess2.add(u2) + + @testing.resolve_artifact_names + def test_duplicate_update(self): + mapper(User, users) + Session = sessionmaker() + sess = Session() + + u1 = User(name='u1') + sess.add(u1) + sess.flush() + assert u1.id is not None + + sess.expunge(u1) + + assert u1 not in sess + assert Session.object_session(u1) is None + + u2 = sess.query(User).get(u1.id) + assert u2 is not None and u2 is not u1 + assert u2 in sess + + assert_raises(Exception, lambda: sess.add(u1)) + + sess.expunge(u2) + assert u2 not in sess + assert Session.object_session(u2) is None + + u1.name = "John" + u2.name = "Doe" + + sess.add(u1) + assert u1 in sess + assert Session.object_session(u1) is 
sess + + sess.flush() + + sess.expunge_all() + + u3 = sess.query(User).get(u1.id) + assert u3 is not u1 and u3 is not u2 and u3.name == u1.name + + @testing.resolve_artifact_names + def test_no_double_save(self): + sess = create_session() + class Foo(object): + def __init__(self): + sess.add(self) + class Bar(Foo): + def __init__(self): + sess.add(self) + Foo.__init__(self) + mapper(Foo, users) + mapper(Bar, users) + + b = Bar() + assert b in sess + assert len(list(sess)) == 1 + +class DisposedStates(_base.MappedTest): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + global t1 + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + @classmethod + def setup_mappers(cls): + global T + class T(object): + def __init__(self, data): + self.data = data + mapper(T, t1) + + def teardown(self): + from sqlalchemy.orm.session import _sessions + _sessions.clear() + super(DisposedStates, self).teardown() + + def _set_imap_in_disposal(self, sess, *objs): + """remove selected objects from the given session, as though they + were dereferenced and removed from WeakIdentityMap. + + Hardcodes the identity map's "all_states()" method to return the full list + of states. This simulates the all_states() method returning results, afterwhich + some of the states get garbage collected (this normally only happens during + asynchronous gc). The Session now has one or more + InstanceState's which have been removed from the identity map and disposed. + + Will the Session not trip over this ??? Stay tuned. 
+ + """ + all_states = sess.identity_map.all_states() + sess.identity_map.all_states = lambda: all_states + for obj in objs: + state = attributes.instance_state(obj) + sess.identity_map.remove(state) + state.dispose() + + def _test_session(self, **kwargs): + global sess + sess = create_session(**kwargs) + + data = o1, o2, o3, o4, o5 = [T('t1'), T('t2'), T('t3'), T('t4'), T('t5')] + + sess.add_all(data) + + sess.flush() + + o1.data = 't1modified' + o5.data = 't5modified' + + self._set_imap_in_disposal(sess, o2, o4, o5) + return sess + + def test_flush(self): + self._test_session().flush() + + def test_clear(self): + self._test_session().expunge_all() + + def test_close(self): + self._test_session().close() + + def test_expunge_all(self): + self._test_session().expunge_all() + + def test_expire_all(self): + self._test_session().expire_all() + + def test_rollback(self): + sess = self._test_session(autocommit=False, expire_on_commit=True) + sess.commit() + + sess.rollback() + + +class SessionInterface(testing.TestBase): + """Bogus args to Session methods produce actionable exceptions.""" + + # TODO: expand with message body assertions. 
+ + _class_methods = set(( + 'connection', 'execute', 'get_bind', 'scalar')) + + def _public_session_methods(self): + Session = sa.orm.session.Session + + blacklist = set(('begin', 'query')) + + ok = set() + for meth in Session.public_methods: + if meth in blacklist: + continue + spec = inspect.getargspec(getattr(Session, meth)) + if len(spec[0]) > 1 or spec[1]: + ok.add(meth) + return ok + + def _map_it(self, cls): + return mapper(cls, Table('t', sa.MetaData(), + Column('id', Integer, primary_key=True))) + + @testing.uses_deprecated() + def _test_instance_guards(self, user_arg): + watchdog = set() + + def x_raises_(obj, method, *args, **kw): + watchdog.add(method) + callable_ = getattr(obj, method) + assert_raises(sa.orm.exc.UnmappedInstanceError, + callable_, *args, **kw) + + def raises_(method, *args, **kw): + x_raises_(create_session(), method, *args, **kw) + + raises_('__contains__', user_arg) + + raises_('add', user_arg) + + raises_('add_all', (user_arg,)) + + raises_('delete', user_arg) + + raises_('expire', user_arg) + + raises_('expunge', user_arg) + + # flush will no-op without something in the unit of work + def _(): + class OK(object): + pass + self._map_it(OK) + + s = create_session() + s.add(OK()) + x_raises_(s, 'flush', (user_arg,)) + _() + + raises_('is_modified', user_arg) + + raises_('merge', user_arg) + + raises_('refresh', user_arg) + + raises_('save', user_arg) + + raises_('save_or_update', user_arg) + + raises_('update', user_arg) + + instance_methods = self._public_session_methods() - self._class_methods + + eq_(watchdog, instance_methods, + watchdog.symmetric_difference(instance_methods)) + + def _test_class_guards(self, user_arg): + watchdog = set() + + def raises_(method, *args, **kw): + watchdog.add(method) + callable_ = getattr(create_session(), method) + assert_raises(sa.orm.exc.UnmappedClassError, + callable_, *args, **kw) + + raises_('connection', mapper=user_arg) + + raises_('execute', 'SELECT 1', mapper=user_arg) + + 
raises_('get_bind', mapper=user_arg) + + raises_('scalar', 'SELECT 1', mapper=user_arg) + + eq_(watchdog, self._class_methods, + watchdog.symmetric_difference(self._class_methods)) + + def test_unmapped_instance(self): + class Unmapped(object): + pass + + self._test_instance_guards(Unmapped()) + self._test_class_guards(Unmapped) + + def test_unmapped_primitives(self): + for prim in ('doh', 123, ('t', 'u', 'p', 'l', 'e')): + self._test_instance_guards(prim) + self._test_class_guards(prim) + + def test_unmapped_class_for_instance(self): + class Unmapped(object): + pass + + self._test_instance_guards(Unmapped) + self._test_class_guards(Unmapped) + + def test_mapped_class_for_instance(self): + class Mapped(object): + pass + self._map_it(Mapped) + + self._test_instance_guards(Mapped) + # no class guards- it would pass. + + def test_missing_state(self): + class Mapped(object): + pass + early = Mapped() + self._map_it(Mapped) + + self._test_instance_guards(early) + self._test_class_guards(early) + + +class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): + @classmethod + def create_engine(cls): + return engines.testing_engine(options=dict(strategy='threadlocal')) + + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(20)), + test_needs_acid=True) + + @classmethod + def setup_classes(cls): + class User(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users) + + @testing.exclude('mysql', '<', (5, 0, 3), 'FIXME: unknown') + @testing.resolve_artifact_names + def test_session_nesting(self): + sess = create_session(bind=self.engine) + self.engine.begin() + u = User(name='ed') + sess.add(u) + sess.flush() + self.engine.commit() + + diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py new file mode 100644 index 000000000..5aa541cda --- /dev/null +++ b/test/orm/test_transaction.py @@ 
-0,0 +1,498 @@ + +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy import * +from sqlalchemy.orm import attributes +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import * + +from sqlalchemy.test import testing +from test.orm import _base +from test.orm._fixtures import FixtureTest, User, Address, users, addresses + +import gc + +class TransactionTest(FixtureTest): + run_setup_mappers = 'once' + run_inserts = None + session = sessionmaker() + + @classmethod + def setup_mappers(cls): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', + cascade="all, delete-orphan"), + }) + mapper(Address, addresses) + + +class FixtureDataTest(TransactionTest): + run_inserts = 'each' + + def test_attrs_on_rollback(self): + sess = self.session() + u1 = sess.query(User).get(7) + u1.name = 'ed' + sess.rollback() + eq_(u1.name, 'jack') + + def test_commit_persistent(self): + sess = self.session() + u1 = sess.query(User).get(7) + u1.name = 'ed' + sess.flush() + sess.commit() + eq_(u1.name, 'ed') + + def test_concurrent_commit_persistent(self): + s1 = self.session() + u1 = s1.query(User).get(7) + u1.name = 'ed' + s1.commit() + + s2 = self.session() + u2 = s2.query(User).get(7) + assert u2.name == 'ed' + u2.name = 'will' + s2.commit() + + assert u1.name == 'will' + +class AutoExpireTest(TransactionTest): + + def test_expunge_pending_on_rollback(self): + sess = self.session() + u2= User(name='newuser') + sess.add(u2) + assert u2 in sess + sess.rollback() + assert u2 not in sess + + def test_trans_pending_cleared_on_commit(self): + sess = self.session() + u2= User(name='newuser') + sess.add(u2) + assert u2 in sess + sess.commit() + assert u2 in sess + u3 = User(name='anotheruser') + sess.add(u3) + sess.rollback() + assert u3 not in sess + assert u2 in sess + + def test_update_deleted_on_rollback(self): + s = self.session() + u1 = User(name='ed') + s.add(u1) + s.commit() + + # this actually tests that the 
delete() operation, + # when cascaded to the "addresses" collection, does not + # trigger a flush (via lazyload) before the cascade is complete. + s.delete(u1) + assert u1 in s.deleted + s.rollback() + assert u1 in s + assert u1 not in s.deleted + + def test_gced_delete_on_rollback(self): + s = self.session() + u1 = User(name='ed') + s.add(u1) + s.commit() + + s.delete(u1) + u1_state = attributes.instance_state(u1) + assert u1_state in s.identity_map.all_states() + assert u1_state in s._deleted + s.flush() + assert u1_state not in s.identity_map.all_states() + assert u1_state not in s._deleted + del u1 + gc.collect() + assert u1_state.obj() is None + + s.rollback() + assert u1_state in s.identity_map.all_states() + u1 = s.query(User).filter_by(name='ed').one() + assert u1_state not in s.identity_map.all_states() + assert s.scalar(users.count()) == 1 + s.delete(u1) + s.flush() + assert s.scalar(users.count()) == 0 + s.commit() + + def test_trans_deleted_cleared_on_rollback(self): + s = self.session() + u1 = User(name='ed') + s.add(u1) + s.commit() + + s.delete(u1) + s.commit() + assert u1 not in s + s.rollback() + assert u1 not in s + + def test_update_deleted_on_rollback_cascade(self): + s = self.session() + u1 = User(name='ed', addresses=[Address(email_address='foo')]) + s.add(u1) + s.commit() + + s.delete(u1) + assert u1 in s.deleted + assert u1.addresses[0] in s.deleted + s.rollback() + assert u1 in s + assert u1 not in s.deleted + assert u1.addresses[0] not in s.deleted + + def test_update_deleted_on_rollback_orphan(self): + s = self.session() + u1 = User(name='ed', addresses=[Address(email_address='foo')]) + s.add(u1) + s.commit() + + a1 = u1.addresses[0] + u1.addresses.remove(a1) + + s.flush() + eq_(s.query(Address).filter(Address.email_address=='foo').all(), []) + s.rollback() + assert a1 not in s.deleted + assert u1.addresses == [a1] + + def test_commit_pending(self): + sess = self.session() + u1 = User(name='newuser') + sess.add(u1) + sess.flush() + 
sess.commit() + eq_(u1.name, 'newuser') + + + def test_concurrent_commit_pending(self): + s1 = self.session() + u1 = User(name='edward') + s1.add(u1) + s1.commit() + + s2 = self.session() + u2 = s2.query(User).filter(User.name=='edward').one() + u2.name = 'will' + s2.commit() + + assert u1.name == 'will' + +class TwoPhaseTest(TransactionTest): + + @testing.requires.two_phase_transactions + def test_rollback_on_prepare(self): + s = self.session(twophase=True) + + u = User(name='ed') + s.add(u) + s.prepare() + s.rollback() + + assert u not in s + +class RollbackRecoverTest(TransactionTest): + + def test_pk_violation(self): + s = self.session() + a1 = Address(email_address='foo') + u1 = User(id=1, name='ed', addresses=[a1]) + s.add(u1) + s.commit() + + a2 = Address(email_address='bar') + u2 = User(id=1, name='jack', addresses=[a2]) + + u1.name = 'edward' + a1.email_address = 'foober' + s.add(u2) + assert_raises(sa_exc.FlushError, s.commit) + assert_raises(sa_exc.InvalidRequestError, s.commit) + s.rollback() + assert u2 not in s + assert a2 not in s + assert u1 in s + assert a1 in s + assert u1.name == 'ed' + assert a1.email_address == 'foo' + u1.name = 'edward' + a1.email_address = 'foober' + s.commit() + eq_( + s.query(User).all(), + [User(id=1, name='edward', addresses=[Address(email_address='foober')])] + ) + + @testing.requires.savepoints + def test_pk_violation_with_savepoint(self): + s = self.session() + a1 = Address(email_address='foo') + u1 = User(id=1, name='ed', addresses=[a1]) + s.add(u1) + s.commit() + + a2 = Address(email_address='bar') + u2 = User(id=1, name='jack', addresses=[a2]) + + u1.name = 'edward' + a1.email_address = 'foober' + s.begin_nested() + s.add(u2) + assert_raises(sa_exc.FlushError, s.commit) + assert_raises(sa_exc.InvalidRequestError, s.commit) + s.rollback() + assert u2 not in s + assert a2 not in s + assert u1 in s + assert a1 in s + + s.commit() + assert s.query(User).all() == [User(id=1, name='edward', 
addresses=[Address(email_address='foober')])] + + +class SavepointTest(TransactionTest): + + @testing.requires.savepoints + def test_savepoint_rollback(self): + s = self.session() + u1 = User(name='ed') + u2 = User(name='jack') + s.add_all([u1, u2]) + + s.begin_nested() + u3 = User(name='wendy') + u4 = User(name='foo') + u1.name = 'edward' + u2.name = 'jackward' + s.add_all([u3, u4]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + s.rollback() + assert u1.name == 'ed' + assert u2.name == 'jack' + eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) + s.commit() + assert u1.name == 'ed' + assert u2.name == 'jack' + eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) + + @testing.requires.savepoints + def test_savepoint_delete(self): + s = self.session() + u1 = User(name='ed') + s.add(u1) + s.commit() + eq_(s.query(User).filter_by(name='ed').count(), 1) + s.begin_nested() + s.delete(u1) + s.commit() + eq_(s.query(User).filter_by(name='ed').count(), 0) + s.commit() + + @testing.requires.savepoints + def test_savepoint_commit(self): + s = self.session() + u1 = User(name='ed') + u2 = User(name='jack') + s.add_all([u1, u2]) + + s.begin_nested() + u3 = User(name='wendy') + u4 = User(name='foo') + u1.name = 'edward' + u2.name = 'jackward' + s.add_all([u3, u4]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + s.commit() + def go(): + assert u1.name == 'edward' + assert u2.name == 'jackward' + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + self.assert_sql_count(testing.db, go, 1) + + s.commit() + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + + @testing.requires.savepoints + def test_savepoint_rollback_collections(self): + s = self.session() + u1 = User(name='ed', addresses=[Address(email_address='foo')]) + s.add(u1) + 
s.commit() + + u1.name='edward' + u1.addresses.append(Address(email_address='bar')) + s.begin_nested() + u2 = User(name='jack', addresses=[Address(email_address='bat')]) + s.add(u2) + eq_(s.query(User).order_by(User.id).all(), + [ + User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), + User(name='jack', addresses=[Address(email_address='bat')]) + ] + ) + s.rollback() + eq_(s.query(User).order_by(User.id).all(), + [ + User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), + ] + ) + s.commit() + eq_(s.query(User).order_by(User.id).all(), + [ + User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), + ] + ) + + @testing.requires.savepoints + def test_savepoint_commit_collections(self): + s = self.session() + u1 = User(name='ed', addresses=[Address(email_address='foo')]) + s.add(u1) + s.commit() + + u1.name='edward' + u1.addresses.append(Address(email_address='bar')) + s.begin_nested() + u2 = User(name='jack', addresses=[Address(email_address='bat')]) + s.add(u2) + eq_(s.query(User).order_by(User.id).all(), + [ + User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), + User(name='jack', addresses=[Address(email_address='bat')]) + ] + ) + s.commit() + eq_(s.query(User).order_by(User.id).all(), + [ + User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), + User(name='jack', addresses=[Address(email_address='bat')]) + ] + ) + s.commit() + eq_(s.query(User).order_by(User.id).all(), + [ + User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), + User(name='jack', addresses=[Address(email_address='bat')]) + ] + ) + + @testing.requires.savepoints + def test_expunge_pending_on_rollback(self): + sess = self.session() + + sess.begin_nested() + u2= User(name='newuser') + sess.add(u2) + assert u2 in sess + sess.rollback() + assert u2 not in sess + + 
@testing.requires.savepoints + def test_update_deleted_on_rollback(self): + s = self.session() + u1 = User(name='ed') + s.add(u1) + s.commit() + + s.begin_nested() + s.delete(u1) + assert u1 in s.deleted + s.rollback() + assert u1 in s + assert u1 not in s.deleted + + +class AccountingFlagsTest(TransactionTest): + def test_no_expire_on_commit(self): + sess = sessionmaker(expire_on_commit=False)() + u1 = User(name='ed') + sess.add(u1) + sess.commit() + + testing.db.execute(users.update(users.c.name=='ed').values(name='edward')) + + assert u1.name == 'ed' + sess.expire_all() + assert u1.name == 'edward' + + def test_rollback_no_accounting(self): + sess = sessionmaker(_enable_transaction_accounting=False)() + u1 = User(name='ed') + sess.add(u1) + sess.commit() + + u1.name = 'edwardo' + sess.rollback() + + testing.db.execute(users.update(users.c.name=='ed').values(name='edward')) + + assert u1.name == 'edwardo' + sess.expire_all() + assert u1.name == 'edward' + + def test_commit_no_accounting(self): + sess = sessionmaker(_enable_transaction_accounting=False)() + u1 = User(name='ed') + sess.add(u1) + sess.commit() + + u1.name = 'edwardo' + sess.rollback() + + testing.db.execute(users.update(users.c.name=='ed').values(name='edward')) + + assert u1.name == 'edwardo' + sess.commit() + + assert testing.db.execute(select([users.c.name])).fetchall() == [('edwardo',)] + assert u1.name == 'edwardo' + + sess.delete(u1) + sess.commit() + + def test_preflush_no_accounting(self): + sess = sessionmaker(_enable_transaction_accounting=False, autocommit=True)() + u1 = User(name='ed') + sess.add(u1) + sess.flush() + + sess.begin() + u1.name = 'edwardo' + u2 = User(name="some other user") + sess.add(u2) + + sess.rollback() + + sess.begin() + assert testing.db.execute(select([users.c.name])).fetchall() == [('ed',)] + + +class AutoCommitTest(TransactionTest): + def test_begin_nested_requires_trans(self): + sess = create_session(autocommit=True) + assert_raises(sa_exc.InvalidRequestError, 
sess.begin_nested) + + def test_begin_preflush(self): + sess = create_session(autocommit=True) + + u1 = User(name='ed') + sess.add(u1) + + sess.begin() + u2 = User(name='some other user') + sess.add(u2) + sess.rollback() + assert u2 not in sess + assert u1 in sess + assert sess.query(User).filter_by(name='ed').one() is u1 + + + + diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py new file mode 100644 index 000000000..f95346902 --- /dev/null +++ b/test/orm/test_unitofwork.py @@ -0,0 +1,2374 @@ +# coding: utf-8 +"""Tests unitofwork operations.""" + +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +import datetime +import operator +from sqlalchemy.orm import mapper as orm_mapper + +import sqlalchemy as sa +from sqlalchemy.test import engines, testing, pickleable +from sqlalchemy import Integer, String, ForeignKey, literal_column +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, column_property +from sqlalchemy.test.testing import eq_, ne_ +from test.orm import _base, _fixtures +from test.engine import _base as engine_base +from sqlalchemy.test.assertsql import AllOf, CompiledSQL +import gc + +class UnitOfWorkTest(object): + pass + +class HistoryTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + def setup_classes(cls): + class User(_base.ComparableEntity): + pass + class Address(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_backref(self): + am = mapper(Address, addresses) + m = mapper(User, users, properties=dict( + addresses = relation(am, backref='user', lazy=False))) + + session = create_session(autocommit=False) + + u = User(name='u1') + a = Address(email_address='u1@e') + a.user = u + session.add(u) + + self.assert_(u.addresses == [a]) + session.commit() + session.expunge_all() + + u = session.query(m).one() + assert u.addresses[0].user == u + session.close() + + +class 
VersioningTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('version_table', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('version_id', Integer, nullable=False), + Column('value', String(40), nullable=False)) + + @classmethod + def setup_classes(cls): + class Foo(_base.ComparableEntity): + pass + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_basic(self): + mapper(Foo, version_table, version_id_col=version_table.c.version_id) + + s1 = create_session(autocommit=False) + f1 = Foo(value='f1') + f2 = Foo(value='f2') + s1.add_all((f1, f2)) + s1.commit() + + f1.value='f1rev2' + s1.commit() + + s2 = create_session(autocommit=False) + f1_s = s2.query(Foo).get(f1.id) + f1_s.value='f1rev3' + s2.commit() + + f1.value='f1rev3mine' + + # Only dialects with a sane rowcount can detect the + # ConcurrentModificationError + if testing.db.dialect.supports_sane_rowcount: + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit) + s1.rollback() + else: + s1.commit() + + # new in 0.5 ! 
dont need to close the session + f1 = s1.query(Foo).get(f1.id) + f2 = s1.query(Foo).get(f2.id) + + f1_s.value='f1rev4' + s2.commit() + + s1.delete(f1) + s1.delete(f2) + + if testing.db.dialect.supports_sane_multi_rowcount: + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit) + else: + s1.commit() + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_versioncheck(self): + """query.with_lockmode performs a 'version check' on an already loaded instance""" + + s1 = create_session(autocommit=False) + + mapper(Foo, version_table, version_id_col=version_table.c.version_id) + f1s1 = Foo(value='f1 value') + s1.add(f1s1) + s1.commit() + + s2 = create_session(autocommit=False) + f1s2 = s2.query(Foo).get(f1s1.id) + f1s2.value='f1 new value' + s2.commit() + + # load, version is wrong + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id) + + # reload it + s1.query(Foo).populate_existing().get(f1s1.id) + # now assert version OK + s1.query(Foo).with_lockmode('read').get(f1s1.id) + + # assert brand new load is OK too + s1.close() + s1.query(Foo).with_lockmode('read').get(f1s1.id) + + @engines.close_open_connections + @testing.resolve_artifact_names + def test_noversioncheck(self): + """test query.with_lockmode works when the mapper has no version id col""" + s1 = create_session(autocommit=False) + mapper(Foo, version_table) + f1s1 = Foo(value="foo", version_id=0) + s1.add(f1s1) + s1.commit() + + s2 = create_session(autocommit=False) + f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id) + assert f1s2.id == f1s1.id + assert f1s2.value == f1s1.value + +class UnicodeTest(_base.MappedTest): + __requires__ = ('unicode_connections',) + + @classmethod + def define_tables(cls, metadata): + Table('uni_t1', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('txt', sa.Unicode(50), unique=True)) + Table('uni_t2', metadata, + Column('id', Integer, 
primary_key=True, + test_needs_autoincrement=True), + Column('txt', sa.Unicode(50), ForeignKey('uni_t1'))) + + @classmethod + def setup_classes(cls): + class Test(_base.BasicEntity): + pass + class Test2(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_basic(self): + mapper(Test, uni_t1) + + txt = u"\u0160\u0110\u0106\u010c\u017d" + t1 = Test(id=1, txt=txt) + self.assert_(t1.txt == txt) + + session = create_session(autocommit=False) + session.add(t1) + session.commit() + + self.assert_(t1.txt == txt) + + @testing.resolve_artifact_names + def test_relation(self): + mapper(Test, uni_t1, properties={ + 't2s': relation(Test2)}) + mapper(Test2, uni_t2) + + txt = u"\u0160\u0110\u0106\u010c\u017d" + t1 = Test(txt=txt) + t1.t2s.append(Test2()) + t1.t2s.append(Test2()) + session = create_session(autocommit=False) + session.add(t1) + session.commit() + session.close() + + session = create_session() + t1 = session.query(Test).filter_by(id=t1.id).one() + assert len(t1.t2s) == 2 + +class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): + __requires__ = ('unicode_connections', 'unicode_ddl',) + + @classmethod + def create_engine(cls): + return engines.utf8_engine() + + @classmethod + def define_tables(cls, metadata): + t1 = Table('unitable1', metadata, + Column(u'méil', Integer, primary_key=True, key='a', test_needs_autoincrement=True), + Column(u'\u6e2c\u8a66', Integer, key='b'), + Column('type', String(20)), + test_needs_fk=True, + test_needs_autoincrement=True) + t2 = Table(u'Unitéble2', metadata, + Column(u'méil', Integer, primary_key=True, key="cc", test_needs_autoincrement=True), + Column(u'\u6e2c\u8a66', Integer, + ForeignKey(u'unitable1.a'), key="d"), + Column(u'\u6e2c\u8a66_2', Integer, key="e"), + test_needs_fk=True, + test_needs_autoincrement=True) + + cls.tables['t1'] = t1 + cls.tables['t2'] = t2 + + @classmethod + def setup_class(cls): + super(UnicodeSchemaTest, cls).setup_class() + + @classmethod + def teardown_class(cls): + 
super(UnicodeSchemaTest, cls).teardown_class() + + @testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.') + @testing.resolve_artifact_names + def test_mapping(self): + class A(_base.ComparableEntity): + pass + class B(_base.ComparableEntity): + pass + + mapper(A, t1, properties={ + 't2s':relation(B)}) + mapper(B, t2) + + a1 = A() + b1 = B() + a1.t2s.append(b1) + + session = create_session() + session.add(a1) + session.flush() + session.expunge_all() + + new_a1 = session.query(A).filter(t1.c.a == a1.a).one() + assert new_a1.a == a1.a + assert new_a1.t2s[0].d == b1.d + session.expunge_all() + + new_a1 = (session.query(A).options(sa.orm.eagerload('t2s')). + filter(t1.c.a == a1.a)).one() + assert new_a1.a == a1.a + assert new_a1.t2s[0].d == b1.d + session.expunge_all() + + new_a1 = session.query(A).filter(A.a == a1.a).one() + assert new_a1.a == a1.a + assert new_a1.t2s[0].d == b1.d + session.expunge_all() + + @testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.') + @testing.resolve_artifact_names + def test_inheritance_mapping(self): + class A(_base.ComparableEntity): + pass + class B(A): + pass + + mapper(A, t1, + polymorphic_on=t1.c.type, + polymorphic_identity='a') + mapper(B, t2, + inherits=A, + polymorphic_identity='b') + a1 = A(b=5) + b1 = B(e=7) + + session = create_session() + session.add_all((a1, b1)) + session.flush() + session.expunge_all() + + eq_([A(b=5), B(e=7)], session.query(A).all()) + + +class MutableTypesTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('mutable_t', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', sa.PickleType), + Column('val', sa.Unicode(30))) + + @classmethod + def setup_classes(cls): + class Foo(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Foo, mutable_t) + + @testing.resolve_artifact_names + 
def test_basic(self): + """Changes are detected for types marked as MutableType.""" + + f1 = Foo() + f1.data = pickleable.Bar(4,5) + + session = create_session() + session.add(f1) + session.flush() + session.expunge_all() + + f2 = session.query(Foo).filter_by(id=f1.id).one() + assert 'data' in sa.orm.attributes.instance_state(f2).unmodified + eq_(f2.data, f1.data) + + f2.data.y = 19 + assert f2 in session.dirty + assert 'data' not in sa.orm.attributes.instance_state(f2).unmodified + session.flush() + session.expunge_all() + + f3 = session.query(Foo).filter_by(id=f1.id).one() + ne_(f3.data,f1.data) + eq_(f3.data, pickleable.Bar(4, 19)) + + @testing.resolve_artifact_names + def test_mutable_changes(self): + """Mutable changes are detected or not detected correctly""" + + f1 = Foo() + f1.data = pickleable.Bar(4,5) + f1.val = u'hi' + + session = create_session(autocommit=False) + session.add(f1) + session.commit() + + bind = self.metadata.bind + + self.sql_count_(0, session.commit) + f1.val = u'someothervalue' + self.assert_sql(bind, session.commit, [ + ("UPDATE mutable_t SET val=:val " + "WHERE mutable_t.id = :mutable_t_id", + {'mutable_t_id': f1.id, 'val': u'someothervalue'})]) + + f1.val = u'hi' + f1.data.x = 9 + self.assert_sql(bind, session.commit, [ + ("UPDATE mutable_t SET data=:data, val=:val " + "WHERE mutable_t.id = :mutable_t_id", + {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})]) + + + @testing.resolve_artifact_names + def test_resurrect(self): + f1 = Foo() + f1.data = pickleable.Bar(4,5) + f1.val = u'hi' + + session = create_session(autocommit=False) + session.add(f1) + session.commit() + + f1.data.y = 19 + del f1 + + gc.collect() + assert len(session.identity_map) == 1 + + session.commit() + + assert session.query(Foo).one().data == pickleable.Bar(4, 19) + + + @testing.uses_deprecated() + @testing.resolve_artifact_names + def test_nocomparison(self): + """Changes are detected on MutableTypes lacking an __eq__ method.""" + + f1 = Foo() + f1.data = 
pickleable.BarWithoutCompare(4,5) + session = create_session(autocommit=False) + session.add(f1) + session.commit() + + self.sql_count_(0, session.commit) + session.close() + + session = create_session(autocommit=False) + f2 = session.query(Foo).filter_by(id=f1.id).one() + self.sql_count_(0, session.commit) + + f2.data.y = 19 + self.sql_count_(1, session.commit) + session.close() + + session = create_session(autocommit=False) + f3 = session.query(Foo).filter_by(id=f1.id).one() + eq_((f3.data.x, f3.data.y), (4,19)) + self.sql_count_(0, session.commit) + session.close() + + @testing.resolve_artifact_names + def test_unicode(self): + """Equivalent Unicode values are not flagged as changed.""" + + f1 = Foo(val=u'hi') + + session = create_session(autocommit=False) + session.add(f1) + session.commit() + session.expunge_all() + + f1 = session.query(Foo).get(f1.id) + f1.val = u'hi' + self.sql_count_(0, session.commit) + + +class PickledDictsTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('mutable_t', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', sa.PickleType(comparator=operator.eq))) + + @classmethod + def setup_classes(cls): + class Foo(_base.BasicEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Foo, mutable_t) + + @testing.resolve_artifact_names + def test_dicts(self): + """Dictionaries may not pickle the same way twice.""" + + f1 = Foo() + f1.data = [ { + 'personne': {'nom': u'Smith', + 'pers_id': 1, + 'prenom': u'john', + 'civilite': u'Mr', + 'int_3': False, + 'int_2': False, + 'int_1': u'23', + 'VenSoir': True, + 'str_1': u'Test', + 'SamMidi': False, + 'str_2': u'chien', + 'DimMidi': False, + 'SamSoir': True, + 'SamAcc': False} } ] + + session = create_session(autocommit=False) + session.add(f1) + session.commit() + + self.sql_count_(0, session.commit) + + f1.data = [ { + 'personne': {'nom': u'Smith', + 'pers_id': 1, + 
'prenom': u'john', + 'civilite': u'Mr', + 'int_3': False, + 'int_2': False, + 'int_1': u'23', + 'VenSoir': True, + 'str_1': u'Test', + 'SamMidi': False, + 'str_2': u'chien', + 'DimMidi': False, + 'SamSoir': True, + 'SamAcc': False} } ] + + self.sql_count_(0, session.commit) + + f1.data[0]['personne']['VenSoir']= False + self.sql_count_(1, session.commit) + + session.expunge_all() + f = session.query(Foo).get(f1.id) + eq_(f.data, + [ { + 'personne': {'nom': u'Smith', + 'pers_id': 1, + 'prenom': u'john', + 'civilite': u'Mr', + 'int_3': False, + 'int_2': False, + 'int_1': u'23', + 'VenSoir': False, + 'str_1': u'Test', + 'SamMidi': False, + 'str_2': u'chien', + 'DimMidi': False, + 'SamSoir': True, + 'SamAcc': False} } ]) + + +class PKTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('multipk1', metadata, + Column('multi_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('multi_rev', Integer, primary_key=True), + Column('name', String(50), nullable=False), + Column('value', String(100))) + + Table('multipk2', metadata, + Column('pk_col_1', String(30), primary_key=True), + Column('pk_col_2', String(30), primary_key=True), + Column('data', String(30))) + Table('multipk3', metadata, + Column('pri_code', String(30), key='primary', primary_key=True), + Column('sec_code', String(30), key='secondary', primary_key=True), + Column('date_assigned', sa.Date, key='assigned', primary_key=True), + Column('data', String(30))) + + @classmethod + def setup_classes(cls): + class Entry(_base.BasicEntity): + pass + + # not supported on sqlite since sqlite's auto-pk generation only works with + # single column primary keys + @testing.fails_on('sqlite', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_primary_key(self): + mapper(Entry, multipk1) + + e = Entry(name='entry1', value='this is entry 1', multi_rev=2) + + session = create_session() + session.add(e) + session.flush() + session.expunge_all() + + e2 = 
session.query(Entry).get((e.multi_id, 2)) + self.assert_(e is not e2) + state = sa.orm.attributes.instance_state(e) + state2 = sa.orm.attributes.instance_state(e2) + eq_(state.key, state2.key) + + # this one works with sqlite since we are manually setting up pk values + @testing.resolve_artifact_names + def test_manual_pk(self): + mapper(Entry, multipk2) + + e = Entry(pk_col_1='pk1', pk_col_2='pk1_related', data='im the data') + + session = create_session() + session.add(e) + session.flush() + + @testing.resolve_artifact_names + def test_key_pks(self): + mapper(Entry, multipk3) + + e = Entry(primary= 'pk1', secondary='pk2', + assigned=datetime.date.today(), data='some more data') + + session = create_session() + session.add(e) + session.flush() + + +class ForeignPKTest(_base.MappedTest): + """Detection of the relationship direction on PK joins.""" + + @classmethod + def define_tables(cls, metadata): + Table("people", metadata, + Column('person', String(10), primary_key=True), + Column('firstname', String(10)), + Column('lastname', String(10))) + + Table("peoplesites", metadata, + Column('person', String(10), ForeignKey("people.person"), + primary_key=True), + Column('site', String(10))) + + @classmethod + def setup_classes(cls): + class Person(_base.BasicEntity): + pass + class PersonSite(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_basic(self): + m1 = mapper(PersonSite, peoplesites) + m2 = mapper(Person, people, properties={ + 'sites' : relation(PersonSite)}) + + sa.orm.compile_mappers() + eq_(list(m2.get_property('sites').synchronize_pairs), + [(people.c.person, peoplesites.c.person)]) + + p = Person(person='im the key', firstname='asdf') + ps = PersonSite(site='asdf') + p.sites.append(ps) + + session = create_session() + session.add(p) + session.flush() + + p_count = people.count(people.c.person=='im the key').scalar() + eq_(p_count, 1) + eq_(peoplesites.count(peoplesites.c.person=='im the key').scalar(), 1) + + +class 
ClauseAttributesTest(_base.MappedTest): + + @classmethod + def define_tables(cls, metadata): + Table('users_t', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(30)), + Column('counter', Integer, default=1)) + + @classmethod + def setup_classes(cls): + class User(_base.ComparableEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(User, users_t) + + @testing.resolve_artifact_names + def test_update(self): + u = User(name='test') + + session = create_session() + session.add(u) + session.flush() + + eq_(u.counter, 1) + u.counter = User.counter + 1 + session.flush() + + def go(): + assert (u.counter == 2) is True # ensure its not a ClauseElement + self.sql_count_(1, go) + + @testing.resolve_artifact_names + def test_multi_update(self): + u = User(name='test') + + session = create_session() + session.add(u) + session.flush() + + eq_(u.counter, 1) + u.name = 'test2' + u.counter = User.counter + 1 + session.flush() + + def go(): + eq_(u.name, 'test2') + assert (u.counter == 2) is True + self.sql_count_(1, go) + + session.expunge_all() + u = session.query(User).get(u.id) + eq_(u.name, 'test2') + eq_(u.counter, 2) + + @testing.resolve_artifact_names + def test_insert(self): + u = User(name='test', counter=sa.select([5])) + + session = create_session() + session.add(u) + session.flush() + + assert (u.counter == 5) is True + + +class PassiveDeletesTest(_base.MappedTest): + __requires__ = ('foreign_keys',) + + @classmethod + def define_tables(cls, metadata): + Table('mytable', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(30)), + test_needs_fk=True) + + Table('myothertable', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('parent_id', Integer), + Column('data', String(30)), + sa.ForeignKeyConstraint(['parent_id'], + ['mytable.id'], + ondelete="CASCADE"), + 
test_needs_fk=True) + + @classmethod + def setup_classes(cls): + class MyClass(_base.BasicEntity): + pass + class MyOtherClass(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_basic(self): + mapper(MyOtherClass, myothertable) + mapper(MyClass, mytable, properties={ + 'children':relation(MyOtherClass, + passive_deletes=True, + cascade="all")}) + session = create_session() + mc = MyClass() + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + + session.add(mc) + session.flush() + session.expunge_all() + + assert myothertable.count().scalar() == 4 + mc = session.query(MyClass).get(mc.id) + session.delete(mc) + session.flush() + + assert mytable.count().scalar() == 0 + assert myothertable.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_backwards_pd(self): + # the unusual scenario where a trigger or something might be deleting + # a many-to-one on deletion of the parent row + mapper(MyOtherClass, myothertable, properties={ + 'myclass':relation(MyClass, cascade="all, delete", passive_deletes=True) + }) + mapper(MyClass, mytable) + + session = create_session() + mc = MyClass() + mco = MyOtherClass() + mco.myclass = mc + session.add(mco) + session.flush() + + assert mytable.count().scalar() == 1 + assert myothertable.count().scalar() == 1 + + session.expire(mco, ['myclass']) + session.delete(mco) + session.flush() + + assert mytable.count().scalar() == 1 + assert myothertable.count().scalar() == 0 + +class ExtraPassiveDeletesTest(_base.MappedTest): + __requires__ = ('foreign_keys',) + + @classmethod + def define_tables(cls, metadata): + Table('mytable', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(30)), + test_needs_fk=True) + + Table('myothertable', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('parent_id', Integer), + 
Column('data', String(30)), + # no CASCADE, the same as ON DELETE RESTRICT + sa.ForeignKeyConstraint(['parent_id'], + ['mytable.id']), + test_needs_fk=True) + + @classmethod + def setup_classes(cls): + class MyClass(_base.BasicEntity): + pass + class MyOtherClass(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_assertions(self): + mapper(MyOtherClass, myothertable) + try: + mapper(MyClass, mytable, properties={ + 'children':relation(MyOtherClass, + passive_deletes='all', + cascade="all")}) + assert False + except sa.exc.ArgumentError, e: + eq_(str(e), + "Can't set passive_deletes='all' in conjunction with 'delete' " + "or 'delete-orphan' cascade") + + @testing.resolve_artifact_names + def test_extra_passive(self): + mapper(MyOtherClass, myothertable) + mapper(MyClass, mytable, properties={ + 'children': relation(MyOtherClass, + passive_deletes='all', + cascade="save-update")}) + + session = create_session() + mc = MyClass() + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + mc.children.append(MyOtherClass()) + session.add(mc) + session.flush() + session.expunge_all() + + assert myothertable.count().scalar() == 4 + mc = session.query(MyClass).get(mc.id) + session.delete(mc) + assert_raises(sa.exc.DBAPIError, session.flush) + + @testing.resolve_artifact_names + def test_extra_passive_2(self): + mapper(MyOtherClass, myothertable) + mapper(MyClass, mytable, properties={ + 'children': relation(MyOtherClass, + passive_deletes='all', + cascade="save-update")}) + + session = create_session() + mc = MyClass() + mc.children.append(MyOtherClass()) + session.add(mc) + session.flush() + session.expunge_all() + + assert myothertable.count().scalar() == 1 + + mc = session.query(MyClass).get(mc.id) + session.delete(mc) + mc.children[0].data = 'some new data' + assert_raises(sa.exc.DBAPIError, session.flush) + + +class DefaultTest(_base.MappedTest): + """Exercise mappings on columns with 
DefaultGenerators. + + Tests that when saving objects whose table contains DefaultGenerators, + either python-side, preexec or database-side, the newly saved instances + receive all the default values either through a post-fetch or getting the + pre-exec'ed defaults back from the engine. + + """ + + @classmethod + def define_tables(cls, metadata): + use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql') + + if use_string_defaults: + hohotype = String(30) + hohoval = "im hoho" + althohoval = "im different hoho" + else: + hohotype = Integer + hohoval = 9 + althohoval = 15 + + cls.other_artifacts['hohoval'] = hohoval + cls.other_artifacts['althohoval'] = althohoval + + dt = Table('default_t', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('hoho', hohotype, server_default=str(hohoval)), + Column('counter', Integer, default=sa.func.char_length("1234567")), + Column('foober', String(30), default="im foober", + onupdate="im the update")) + + st = Table('secondary_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50))) + + if testing.against('postgres', 'oracle'): + dt.append_column( + Column('secondary_id', Integer, sa.Sequence('sec_id_seq'), + unique=True)) + st.append_column( + Column('fk_val', Integer, + ForeignKey('default_t.secondary_id'))) + elif testing.against('mssql'): + st.append_column( + Column('fk_val', Integer, + ForeignKey('default_t.id'))) + else: + st.append_column( + Column('hoho', hohotype, ForeignKey('default_t.hoho'))) + + @classmethod + def setup_classes(cls): + class Hoho(_base.ComparableEntity): + pass + class Secondary(_base.ComparableEntity): + pass + + @testing.fails_on('firebird', 'Data type unknown on the parameter') + @testing.resolve_artifact_names + def test_insert(self): + mapper(Hoho, default_t) + + h1 = Hoho(hoho=althohoval) + h2 = Hoho(counter=12) + h3 = Hoho(hoho=althohoval, counter=12) + h4 = Hoho() + h5 = Hoho(foober='im the new 
foober') + + session = create_session(autocommit=False) + session.add_all((h1, h2, h3, h4, h5)) + session.commit() + + eq_(h1.hoho, althohoval) + eq_(h3.hoho, althohoval) + + def go(): + # test deferred load of attribues, one select per instance + self.assert_(h2.hoho == h4.hoho == h5.hoho == hohoval) + self.sql_count_(3, go) + + def go(): + self.assert_(h1.counter == h4.counter == h5.counter == 7) + self.sql_count_(1, go) + + def go(): + self.assert_(h3.counter == h2.counter == 12) + self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') + self.assert_(h5.foober == 'im the new foober') + self.sql_count_(0, go) + + session.expunge_all() + + (h1, h2, h3, h4, h5) = session.query(Hoho).order_by(Hoho.id).all() + + eq_(h1.hoho, althohoval) + eq_(h3.hoho, althohoval) + self.assert_(h2.hoho == h4.hoho == h5.hoho == hohoval) + self.assert_(h3.counter == h2.counter == 12) + self.assert_(h1.counter == h4.counter == h5.counter == 7) + self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') + eq_(h5.foober, 'im the new foober') + + @testing.fails_on('firebird', 'Data type unknown on the parameter') + @testing.resolve_artifact_names + def test_eager_defaults(self): + mapper(Hoho, default_t, eager_defaults=True) + + h1 = Hoho() + + session = create_session() + session.add(h1) + session.flush() + + self.sql_count_(0, lambda: eq_(h1.hoho, hohoval)) + + @testing.resolve_artifact_names + def test_insert_nopostfetch(self): + # populates from the FetchValues explicitly so there is no + # "post-update" + mapper(Hoho, default_t) + + h1 = Hoho(hoho="15", counter="15") + session = create_session() + session.add(h1) + session.flush() + + def go(): + eq_(h1.hoho, "15") + eq_(h1.counter, "15") + eq_(h1.foober, "im foober") + self.sql_count_(0, go) + + @testing.fails_on('firebird', 'Data type unknown on the parameter') + @testing.resolve_artifact_names + def test_update(self): + mapper(Hoho, default_t) + + h1 = Hoho() + session = create_session() + session.add(h1) + 
session.flush() + + eq_(h1.foober, 'im foober') + h1.counter = 19 + session.flush() + eq_(h1.foober, 'im the update') + + @testing.fails_on('firebird', 'Data type unknown on the parameter') + @testing.resolve_artifact_names + def test_used_in_relation(self): + """A server-side default can be used as the target of a foreign key""" + + mapper(Hoho, default_t, properties={ + 'secondaries':relation(Secondary)}) + mapper(Secondary, secondary_table) + + h1 = Hoho() + s1 = Secondary(data='s1') + h1.secondaries.append(s1) + + session = create_session() + session.add(h1) + session.flush() + session.expunge_all() + + eq_(session.query(Hoho).get(h1.id), + Hoho(hoho=hohoval, + secondaries=[ + Secondary(data='s1')])) + + h1 = session.query(Hoho).get(h1.id) + h1.secondaries.append(Secondary(data='s2')) + session.flush() + session.expunge_all() + + eq_(session.query(Hoho).get(h1.id), + Hoho(hoho=hohoval, + secondaries=[ + Secondary(data='s1'), + Secondary(data='s2')])) + +class ColumnPropertyTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('data', metadata, + Column('id', Integer, primary_key=True), + Column('a', String(50)), + Column('b', String(50)) + ) + + Table('subdata', metadata, + Column('id', Integer, ForeignKey('data.id'), primary_key=True), + Column('c', String(50)), + ) + + @classmethod + def setup_mappers(cls): + class Data(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_refreshes(self): + mapper(Data, data, properties={ + 'aplusb':column_property(data.c.a + literal_column("' '") + data.c.b) + }) + self._test() + + @testing.resolve_artifact_names + def test_refreshes_post_init(self): + m = mapper(Data, data) + m.add_property('aplusb', column_property(data.c.a + literal_column("' '") + data.c.b)) + self._test() + + @testing.resolve_artifact_names + def test_with_inheritance(self): + class SubData(Data): + pass + mapper(Data, data, properties={ + 'aplusb':column_property(data.c.a + literal_column("' '") + 
data.c.b) + }) + mapper(SubData, subdata, inherits=Data) + + sess = create_session() + sd1 = SubData(a="hello", b="there", c="hi") + sess.add(sd1) + sess.flush() + eq_(sd1.aplusb, "hello there") + + @testing.resolve_artifact_names + def _test(self): + sess = create_session() + + d1 = Data(a="hello", b="there") + sess.add(d1) + sess.flush() + + eq_(d1.aplusb, "hello there") + + d1.b = "bye" + sess.flush() + eq_(d1.aplusb, "hello bye") + + d1.b = 'foobar' + d1.aplusb = 'im setting this explicitly' + sess.flush() + eq_(d1.aplusb, "im setting this explicitly") + +class OneToManyTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_one_to_many_1(self): + """Basic save of one to many.""" + + m = mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=True) + )) + u = User(name= 'one2manytester') + a = Address(email_address='one2many@test.org') + u.addresses.append(a) + + a2 = Address(email_address='lala@test.org') + u.addresses.append(a2) + + session = create_session() + session.add(u) + session.flush() + + user_rows = users.select(users.c.id.in_([u.id])).execute().fetchall() + eq_(user_rows[0].values(), [u.id, 'one2manytester']) + + address_rows = addresses.select( + addresses.c.id.in_([a.id, a2.id]), + order_by=[addresses.c.email_address]).execute().fetchall() + eq_(address_rows[0].values(), [a2.id, u.id, 'lala@test.org']) + eq_(address_rows[1].values(), [a.id, u.id, 'one2many@test.org']) + + userid = u.id + addressid = a2.id + + a2.email_address = 'somethingnew@foo.com' + + session.flush() + + address_rows = addresses.select( + addresses.c.id == addressid).execute().fetchall() + eq_(address_rows[0].values(), + [addressid, userid, 'somethingnew@foo.com']) + self.assert_(u.id == userid and a2.id == addressid) + + @testing.resolve_artifact_names + def test_one_to_many_2(self): + """Modifying the child items of an object.""" + + m = mapper(User, users, properties=dict( + addresses = 
relation(mapper(Address, addresses), lazy=True))) + + u1 = User(name='user1') + u1.addresses = [] + a1 = Address(email_address='emailaddress1') + u1.addresses.append(a1) + + u2 = User(name='user2') + u2.addresses = [] + a2 = Address(email_address='emailaddress2') + u2.addresses.append(a2) + + a3 = Address(email_address='emailaddress3') + + session = create_session() + session.add_all((u1, u2, a3)) + session.flush() + + # modify user2 directly, append an address to user1. + # upon commit, user2 should be updated, user1 should not + # both address1 and address3 should be updated + u2.name = 'user2modified' + u1.addresses.append(a3) + del u1.addresses[0] + + self.assert_sql(testing.db, session.flush, [ + ("UPDATE users SET name=:name " + "WHERE users.id = :users_id", + {'users_id': u2.id, 'name': 'user2modified'}), + + ("UPDATE addresses SET user_id=:user_id " + "WHERE addresses.id = :addresses_id", + {'user_id': None, 'addresses_id': a1.id}), + + ("UPDATE addresses SET user_id=:user_id " + "WHERE addresses.id = :addresses_id", + {'user_id': u1.id, 'addresses_id': a3.id})]) + + @testing.resolve_artifact_names + def test_child_move(self): + """Moving a child from one parent to another, with a delete. + + Tests that deleting the first parent properly updates the child with + the new parent. This tests the 'trackparent' option in the attributes + module. 
+ + """ + m = mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=True))) + + u1 = User(name='user1') + u2 = User(name='user2') + a = Address(email_address='address1') + u1.addresses.append(a) + + session = create_session() + session.add_all((u1, u2)) + session.flush() + + del u1.addresses[0] + u2.addresses.append(a) + session.delete(u1) + + session.flush() + session.expunge_all() + + u2 = session.query(User).get(u2.id) + eq_(len(u2.addresses), 1) + + @testing.resolve_artifact_names + def test_child_move_2(self): + m = mapper(User, users, properties=dict( + addresses = relation(mapper(Address, addresses), lazy=True))) + + u1 = User(name='user1') + u2 = User(name='user2') + a = Address(email_address='address1') + u1.addresses.append(a) + + session = create_session() + session.add_all((u1, u2)) + session.flush() + + del u1.addresses[0] + u2.addresses.append(a) + + session.flush() + session.expunge_all() + + u2 = session.query(User).get(u2.id) + eq_(len(u2.addresses), 1) + + @testing.resolve_artifact_names + def test_o2m_delete_parent(self): + m = mapper(User, users, properties=dict( + address = relation(mapper(Address, addresses), + lazy=True, + uselist=False))) + + u = User(name='one2onetester') + a = Address(email_address='myonlyaddress@foo.com') + u.address = a + + session = create_session() + session.add(u) + session.flush() + + session.delete(u) + session.flush() + + assert a.id is not None + assert a.user_id is None + assert sa.orm.attributes.instance_state(a).key in session.identity_map + assert sa.orm.attributes.instance_state(u).key not in session.identity_map + + @testing.resolve_artifact_names + def test_one_to_one(self): + m = mapper(User, users, properties=dict( + address = relation(mapper(Address, addresses), + lazy=True, + uselist=False))) + + u = User(name='one2onetester') + u.address = Address(email_address='myonlyaddress@foo.com') + + session = create_session() + session.add(u) + session.flush() + + u.name = 
'imnew' + session.flush() + + u.address.email_address = 'imnew@foo.com' + session.flush() + + @testing.resolve_artifact_names + def test_bidirectional(self): + m1 = mapper(User, users) + m2 = mapper(Address, addresses, properties=dict( + user = relation(m1, lazy=False, backref='addresses'))) + + + u = User(name='test') + a = Address(email_address='testaddress', user=u) + + session = create_session() + session.add(u) + session.flush() + session.delete(u) + session.flush() + + @testing.resolve_artifact_names + def test_double_relation(self): + m2 = mapper(Address, addresses) + m = mapper(User, users, properties={ + 'boston_addresses' : relation(m2, primaryjoin= + sa.and_(users.c.id==addresses.c.user_id, + addresses.c.email_address.like('%boston%'))), + 'newyork_addresses' : relation(m2, primaryjoin= + sa.and_(users.c.id==addresses.c.user_id, + addresses.c.email_address.like('%newyork%')))}) + + u = User(name='u1') + a = Address(email_address='foo@boston.com') + b = Address(email_address='bar@newyork.com') + u.boston_addresses.append(a) + u.newyork_addresses.append(b) + + session = create_session() + session.add(u) + session.flush() + +class SaveTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_basic(self): + m = mapper(User, users) + + # save two users + u = User(name='savetester') + u2 = User(name='savetester2') + + session = create_session() + session.add_all((u, u2)) + session.flush() + + # assert the first one retreives the same from the identity map + nu = session.query(m).get(u.id) + assert u is nu + + # clear out the identity map, so next get forces a SELECT + session.expunge_all() + + # check it again, identity should be different but ids the same + nu = session.query(m).get(u.id) + assert u is not nu and u.id == nu.id and nu.name == 'savetester' + + # change first users name and save + session = create_session() + session.add(u) + u.name = 'modifiedname' + assert u in session.dirty + session.flush() + + # select 
both + userlist = session.query(User).filter( + users.c.id.in_([u.id, u2.id])).order_by([users.c.name]).all() + + eq_(u.id, userlist[0].id) + eq_(userlist[0].name, 'modifiedname') + eq_(u2.id, userlist[1].id) + eq_(userlist[1].name, 'savetester2') + + @testing.resolve_artifact_names + def test_synonym(self): + class SUser(_base.BasicEntity): + def _get_name(self): + return "User:" + self.name + def _set_name(self, name): + self.name = name + ":User" + syn_name = property(_get_name, _set_name) + + mapper(SUser, users, properties={ + 'syn_name': sa.orm.synonym('name') + }) + + u = SUser(syn_name="some name") + eq_(u.syn_name, 'User:some name:User') + + session = create_session() + session.add(u) + session.flush() + session.expunge_all() + + u = session.query(SUser).first() + eq_(u.syn_name, 'User:some name:User') + + @testing.resolve_artifact_names + def test_lazyattr_commit(self): + """Lazily loaded relations. + + When a lazy-loaded list is unloaded, and a commit occurs, that the + 'passive' call on that list does not blow away its value + + """ + mapper(User, users, properties = { + 'addresses': relation(mapper(Address, addresses))}) + + u = User(name='u1') + u.addresses.append(Address(email_address='u1@e1')) + u.addresses.append(Address(email_address='u1@e2')) + u.addresses.append(Address(email_address='u1@e3')) + u.addresses.append(Address(email_address='u1@e4')) + + session = create_session() + session.add(u) + session.flush() + session.expunge_all() + + u = session.query(User).one() + u.name = 'newname' + session.flush() + eq_(len(u.addresses), 4) + + @testing.resolve_artifact_names + def test_inherits(self): + m1 = mapper(User, users) + + class AddressUser(User): + """a user object that also has the users mailing address.""" + pass + + # define a mapper for AddressUser that inherits the User.mapper, and + # joins on the id column + mapper(AddressUser, addresses, inherits=m1) + + au = AddressUser(name='u', email_address='u@e') + + session = create_session() + 
session.add(au) + session.flush() + session.expunge_all() + + rt = session.query(AddressUser).one() + eq_(au.user_id, rt.user_id) + eq_(rt.id, rt.id) + + @testing.resolve_artifact_names + def test_deferred(self): + """Deferred column operations""" + + mapper(Order, orders, properties={ + 'description': sa.orm.deferred(orders.c.description)}) + + # dont set deferred attribute, commit session + o = Order(id=42) + session = create_session(autocommit=False) + session.add(o) + session.commit() + + # assert that changes get picked up + o.description = 'foo' + session.commit() + + eq_(list(session.execute(orders.select(), mapper=Order)), + [(42, None, None, 'foo', None)]) + session.expunge_all() + + # assert that a set operation doesn't trigger a load operation + o = session.query(Order).filter(Order.description == 'foo').one() + def go(): + o.description = 'hoho' + self.sql_count_(0, go) + session.flush() + + eq_(list(session.execute(orders.select(), mapper=Order)), + [(42, None, None, 'hoho', None)]) + + session.expunge_all() + + # test assigning None to an unloaded deferred also works + o = session.query(Order).filter(Order.description == 'hoho').one() + o.description = None + session.flush() + eq_(list(session.execute(orders.select(), mapper=Order)), + [(42, None, None, None, None)]) + session.close() + + # why no support on oracle ? because oracle doesn't save + # "blank" strings; it saves a single space character. + @testing.fails_on('oracle', 'FIXME: unknown') + @testing.resolve_artifact_names + def test_dont_update_blanks(self): + mapper(User, users) + + u = User(name='') + session = create_session() + session.add(u) + session.flush() + session.expunge_all() + + u = session.query(User).get(u.id) + u.name = '' + self.sql_count_(0, session.flush) + + @testing.resolve_artifact_names + def test_multi_table_selectable(self): + """Mapped selectables that span tables. + + Also tests redefinition of the keynames for the column properties. 
+ + """ + usersaddresses = sa.join(users, addresses, + users.c.id == addresses.c.user_id) + + m = mapper(User, usersaddresses, + properties=dict( + email = addresses.c.email_address, + foo_id = [users.c.id, addresses.c.user_id])) + + u = User(name='multitester', email='multi@test.org') + session = create_session() + session.add(u) + session.flush() + session.expunge_all() + + id = m.primary_key_from_instance(u) + + u = session.query(User).get(id) + assert u.name == 'multitester' + + user_rows = users.select(users.c.id.in_([u.foo_id])).execute().fetchall() + eq_(user_rows[0].values(), [u.foo_id, 'multitester']) + address_rows = addresses.select(addresses.c.id.in_([u.id])).execute().fetchall() + eq_(address_rows[0].values(), [u.id, u.foo_id, 'multi@test.org']) + + u.email = 'lala@hey.com' + u.name = 'imnew' + session.flush() + + user_rows = users.select(users.c.id.in_([u.foo_id])).execute().fetchall() + eq_(user_rows[0].values(), [u.foo_id, 'imnew']) + address_rows = addresses.select(addresses.c.id.in_([u.id])).execute().fetchall() + eq_(address_rows[0].values(), [u.id, u.foo_id, 'lala@hey.com']) + + session.expunge_all() + u = session.query(User).get(id) + assert u.name == 'imnew' + + @testing.resolve_artifact_names + def test_history_get(self): + """The history lazy-fetches data when it wasn't otherwise loaded.""" + mapper(User, users, properties={ + 'addresses':relation(Address, cascade="all, delete-orphan")}) + mapper(Address, addresses) + + u = User(name='u1') + u.addresses.append(Address(email_address='u1@e1')) + u.addresses.append(Address(email_address='u1@e2')) + session = create_session() + session.add(u) + session.flush() + session.expunge_all() + + u = session.query(User).get(u.id) + session.delete(u) + session.flush() + assert users.count().scalar() == 0 + assert addresses.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_batch_mode(self): + """The 'batch=False' flag on mapper()""" + + names = [] + class 
TestExtension(sa.orm.MapperExtension): + def before_insert(self, mapper, connection, instance): + self.current_instance = instance + names.append(instance.name) + def after_insert(self, mapper, connection, instance): + assert instance is self.current_instance + + mapper(User, users, extension=TestExtension(), batch=False) + u1 = User(name='user1') + u2 = User(name='user2') + + session = create_session() + session.add_all((u1, u2)) + session.flush() + + u3 = User(name='user3') + u4 = User(name='user4') + u5 = User(name='user5') + + session.add_all([u4, u5, u3]) + session.flush() + + # test insert ordering is maintained + assert names == ['user1', 'user2', 'user4', 'user5', 'user3'] + session.expunge_all() + + sa.orm.clear_mappers() + + m = mapper(User, users, extension=TestExtension()) + u1 = User(name='user1') + u2 = User(name='user2') + session.add_all((u1, u2)) + assert_raises(AssertionError, session.flush) + + +class ManyToOneTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_m2o_one_to_one(self): + # TODO: put assertion in here !!! 
+ m = mapper(Address, addresses, properties=dict( + user = relation(mapper(User, users), lazy=True, uselist=False))) + + session = create_session() + + data = [ + {'name': 'thesub' , 'email_address': 'bar@foo.com'}, + {'name': 'assdkfj' , 'email_address': 'thesdf@asdf.com'}, + {'name': 'n4knd' , 'email_address': 'asf3@bar.org'}, + {'name': 'v88f4' , 'email_address': 'adsd5@llala.net'}, + {'name': 'asdf8d' , 'email_address': 'theater@foo.com'} + ] + objects = [] + for elem in data: + a = Address() + a.email_address = elem['email_address'] + a.user = User() + a.user.name = elem['name'] + objects.append(a) + session.add(a) + + session.flush() + + objects[2].email_address = 'imnew@foo.bar' + objects[3].user = User() + objects[3].user.name = 'imnewlyadded' + self.assert_sql_execution(testing.db, + session.flush, + CompiledSQL("INSERT INTO users (name) VALUES (:name)", + {'name': 'imnewlyadded'} ), + + AllOf( + CompiledSQL("UPDATE addresses SET email_address=:email_address " + "WHERE addresses.id = :addresses_id", + lambda ctx: {'email_address': 'imnew@foo.bar', + 'addresses_id': objects[2].id}), + CompiledSQL("UPDATE addresses SET user_id=:user_id " + "WHERE addresses.id = :addresses_id", + lambda ctx: {'user_id': objects[3].user.id, + 'addresses_id': objects[3].id}) + ) + ) + + l = sa.select([users, addresses], + sa.and_(users.c.id==addresses.c.user_id, + addresses.c.id==a.id)).execute() + eq_(l.fetchone().values(), + [a.user.id, 'asdf8d', a.id, a.user_id, 'theater@foo.com']) + + @testing.resolve_artifact_names + def test_many_to_one_1(self): + m = mapper(Address, addresses, properties=dict( + user = relation(mapper(User, users), lazy=True))) + + a1 = Address(email_address='emailaddress1') + u1 = User(name='user1') + a1.user = u1 + + session = create_session() + session.add(a1) + session.flush() + session.expunge_all() + + a1 = session.query(Address).get(a1.id) + u1 = session.query(User).get(u1.id) + assert a1.user is u1 + + a1.user = None + session.flush() + 
session.expunge_all() + a1 = session.query(Address).get(a1.id) + u1 = session.query(User).get(u1.id) + assert a1.user is None + + @testing.resolve_artifact_names + def test_many_to_one_2(self): + m = mapper(Address, addresses, properties=dict( + user = relation(mapper(User, users), lazy=True))) + + a1 = Address(email_address='emailaddress1') + a2 = Address(email_address='emailaddress2') + u1 = User(name='user1') + a1.user = u1 + + session = create_session() + session.add_all((a1, a2)) + session.flush() + session.expunge_all() + + a1 = session.query(Address).get(a1.id) + a2 = session.query(Address).get(a2.id) + u1 = session.query(User).get(u1.id) + assert a1.user is u1 + + a1.user = None + a2.user = u1 + session.flush() + session.expunge_all() + + a1 = session.query(Address).get(a1.id) + a2 = session.query(Address).get(a2.id) + u1 = session.query(User).get(u1.id) + assert a1.user is None + assert a2.user is u1 + + @testing.resolve_artifact_names + def test_many_to_one_3(self): + m = mapper(Address, addresses, properties=dict( + user = relation(mapper(User, users), lazy=True))) + + a1 = Address(email_address='emailaddress1') + u1 = User(name='user1') + u2 = User(name='user2') + a1.user = u1 + + session = create_session() + session.add_all((a1, u1, u2)) + session.flush() + session.expunge_all() + + a1 = session.query(Address).get(a1.id) + u1 = session.query(User).get(u1.id) + u2 = session.query(User).get(u2.id) + assert a1.user is u1 + + a1.user = u2 + session.flush() + session.expunge_all() + a1 = session.query(Address).get(a1.id) + u1 = session.query(User).get(u1.id) + u2 = session.query(User).get(u2.id) + assert a1.user is u2 + + @testing.resolve_artifact_names + def test_bidirectional_no_load(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=None)}) + mapper(Address, addresses) + + # try it on unsaved objects + u1 = User(name='u1') + a1 = Address(email_address='e1') + a1.user = u1 + + session = create_session() + 
session.add(u1) + session.flush() + session.expunge_all() + + a1 = session.query(Address).get(a1.id) + + a1.user = None + session.flush() + session.expunge_all() + assert session.query(Address).get(a1.id).user is None + assert session.query(User).get(u1.id).addresses == [] + + +class ManyToManyTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_many_to_many(self): + mapper(Keyword, keywords) + + m = mapper(Item, items, properties=dict( + keywords=relation(Keyword, + item_keywords, + lazy=False, + order_by=keywords.c.name))) + + data = [Item, + {'description': 'mm_item1', + 'keywords' : (Keyword, [{'name': 'big'}, + {'name': 'green'}, + {'name': 'purple'}, + {'name': 'round'}])}, + {'description': 'mm_item2', + 'keywords' : (Keyword, [{'name':'blue'}, + {'name':'imnew'}, + {'name':'round'}, + {'name':'small'}])}, + {'description': 'mm_item3', + 'keywords' : (Keyword, [])}, + {'description': 'mm_item4', + 'keywords' : (Keyword, [{'name':'big'}, + {'name':'blue'},])}, + {'description': 'mm_item5', + 'keywords' : (Keyword, [{'name':'big'}, + {'name':'exacting'}, + {'name':'green'}])}, + {'description': 'mm_item6', + 'keywords' : (Keyword, [{'name':'red'}, + {'name':'round'}, + {'name':'small'}])}] + + session = create_session() + + objects = [] + _keywords = dict([(k.name, k) for k in session.query(Keyword)]) + + for elem in data[1:]: + item = Item(description=elem['description']) + objects.append(item) + + for spec in elem['keywords'][1]: + keyword_name = spec['name'] + try: + kw = _keywords[keyword_name] + except KeyError: + _keywords[keyword_name] = kw = Keyword(name=keyword_name) + item.keywords.append(kw) + + session.add_all(objects) + session.flush() + + l = (session.query(Item). + filter(Item.description.in_([e['description'] + for e in data[1:]])). 
+ order_by(Item.description).all()) + self.assert_result(l, *data) + + objects[4].description = 'item4updated' + k = Keyword() + k.name = 'yellow' + objects[5].keywords.append(k) + self.assert_sql_execution( + testing.db, + session.flush, + AllOf( + CompiledSQL("UPDATE items SET description=:description " + "WHERE items.id = :items_id", + {'description': 'item4updated', + 'items_id': objects[4].id}, + ), + CompiledSQL("INSERT INTO keywords (name) " + "VALUES (:name)", + {'name': 'yellow'}, + ) + ), + CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) " + "VALUES (:item_id, :keyword_id)", + lambda ctx: [{'item_id': objects[5].id, + 'keyword_id': k.id}]) + ) + + objects[2].keywords.append(k) + dkid = objects[5].keywords[1].id + del objects[5].keywords[1] + self.assert_sql_execution( + testing.db, + session.flush, + CompiledSQL("DELETE FROM item_keywords " + "WHERE item_keywords.item_id = :item_id AND " + "item_keywords.keyword_id = :keyword_id", + [{'item_id': objects[5].id, 'keyword_id': dkid}]), + CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) " + "VALUES (:item_id, :keyword_id)", + lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}] + )) + + session.delete(objects[3]) + session.flush() + + @testing.resolve_artifact_names + def test_many_to_many_remove(self): + """Setting a collection to empty deletes many-to-many rows. 
+ + Tests that setting a list-based attribute to '[]' properly affects the + history and allows the many-to-many rows to be deleted + + """ + mapper(Keyword, keywords) + mapper(Item, items, properties=dict( + keywords = relation(Keyword, item_keywords, lazy=False), + )) + + i = Item(description='i1') + k1 = Keyword(name='k1') + k2 = Keyword(name='k2') + i.keywords.append(k1) + i.keywords.append(k2) + + session = create_session() + session.add(i) + session.flush() + + assert item_keywords.count().scalar() == 2 + i.keywords = [] + session.flush() + assert item_keywords.count().scalar() == 0 + + @testing.resolve_artifact_names + def test_scalar(self): + """sa.dependency won't delete an m2m relation referencing None.""" + + mapper(Keyword, keywords) + + mapper(Item, items, properties=dict( + keyword=relation(Keyword, secondary=item_keywords, uselist=False))) + + i = Item(description='x') + session = create_session() + session.add(i) + session.flush() + session.delete(i) + session.flush() + + @testing.resolve_artifact_names + def test_many_to_many_update(self): + """Assorted history operations on a many to many""" + mapper(Keyword, keywords) + mapper(Item, items, properties=dict( + keywords=relation(Keyword, + secondary=item_keywords, + lazy=False, + order_by=keywords.c.name))) + + k1 = Keyword(name='keyword 1') + k2 = Keyword(name='keyword 2') + k3 = Keyword(name='keyword 3') + + item = Item(description='item 1') + item.keywords.extend([k1, k2, k3]) + + session = create_session() + session.add(item) + session.flush() + + item.keywords = [] + item.keywords.append(k1) + item.keywords.append(k2) + session.flush() + + session.expunge_all() + item = session.query(Item).get(item.id) + assert item.keywords == [k1, k2] + + @testing.resolve_artifact_names + def test_association(self): + """Basic test of an association object""" + + class IKAssociation(_base.ComparableEntity): + pass + + mapper(Keyword, keywords) + + # note that we are breaking a rule here, making a second + # 
mapper(Keyword, keywords) the reorganization of mapper construction + # affected this, but was fixed again + + mapper(IKAssociation, item_keywords, + primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], + properties=dict( + keyword=relation(mapper(Keyword, keywords, non_primary=True), + lazy=False, + uselist=False, + order_by=keywords.c.name # note here is a valid place where order_by can be used + ))) # on a scalar relation(); to determine eager ordering of + # the parent object within its collection. + + mapper(Item, items, properties=dict( + keywords=relation(IKAssociation, lazy=False))) + + session = create_session() + + def fixture(): + _kw = dict([(k.name, k) for k in session.query(Keyword)]) + for n in ('big', 'green', 'purple', 'round', 'huge', + 'violet', 'yellow', 'blue'): + if n not in _kw: + _kw[n] = Keyword(name=n) + + def assocs(*names): + return [IKAssociation(keyword=kw) + for kw in [_kw[n] for n in names]] + + return [ + Item(description='a_item1', + keywords=assocs('big', 'green', 'purple', 'round')), + Item(description='a_item2', + keywords=assocs('huge', 'violet', 'yellow')), + Item(description='a_item3', + keywords=assocs('big', 'blue'))] + + session.add_all(fixture()) + session.flush() + eq_(fixture(), session.query(Item).order_by(Item.description).all()) + + +class SaveTest2(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_m2o_nonmatch(self): + mapper(User, users) + mapper(Address, addresses, properties=dict( + user = relation(User, lazy=True, uselist=False))) + + session = create_session() + + def fixture(): + return [ + Address(email_address='a1', user=User(name='u1')), + Address(email_address='a2', user=User(name='u2'))] + + session.add_all(fixture()) + + self.assert_sql_execution( + testing.db, + session.flush, + CompiledSQL("INSERT INTO users (name) VALUES (:name)", + {'name': 'u1'}), + CompiledSQL("INSERT INTO users (name) VALUES (:name)", + {'name': 'u2'}), + CompiledSQL("INSERT 
INTO addresses (user_id, email_address) " + "VALUES (:user_id, :email_address)", + {'user_id': 1, 'email_address': 'a1'}), + CompiledSQL("INSERT INTO addresses (user_id, email_address) " + "VALUES (:user_id, :email_address)", + {'user_id': 2, 'email_address': 'a2'}), + ) + +class SaveTest3(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('items', metadata, + Column('item_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('item_name', String(50))) + + Table('keywords', metadata, + Column('keyword_id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(50))) + + Table('assoc', metadata, + Column('item_id', Integer, ForeignKey("items")), + Column('keyword_id', Integer, ForeignKey("keywords")), + Column('foo', sa.Boolean, default=True)) + + @classmethod + def setup_classes(cls): + class Keyword(_base.BasicEntity): + pass + class Item(_base.BasicEntity): + pass + + @testing.resolve_artifact_names + def test_manytomany_xtracol_delete(self): + """A many-to-many on a table that has an extra column can properly delete rows from the table without referencing the extra column""" + + mapper(Keyword, keywords) + mapper(Item, items, properties=dict( + keywords = relation(Keyword, secondary=assoc, lazy=False),)) + + i = Item() + k1 = Keyword() + k2 = Keyword() + i.keywords.append(k1) + i.keywords.append(k2) + + session = create_session() + session.add(i) + session.flush() + + assert assoc.count().scalar() == 2 + i.keywords = [] + print i.keywords + session.flush() + assert assoc.count().scalar() == 0 + +class BooleanColTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('t1_t', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('name', String(30)), + Column('value', sa.Boolean)) + + @testing.resolve_artifact_names + def test_boolean(self): + # use the regular mapper + class T(_base.ComparableEntity): + pass + orm_mapper(T, 
t1_t, order_by=t1_t.c.id) + + sess = create_session() + t1 = T(value=True, name="t1") + t2 = T(value=False, name="t2") + t3 = T(value=True, name="t3") + sess.add_all((t1, t2, t3)) + + sess.flush() + + for clear in (False, True): + if clear: + sess.expunge_all() + eq_(sess.query(T).all(), [T(value=True, name="t1"), T(value=False, name="t2"), T(value=True, name="t3")]) + if clear: + sess.expunge_all() + eq_(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"),T(value=True, name="t3")]) + if clear: + sess.expunge_all() + eq_(sess.query(T).filter(T.value==False).all(), [T(value=False, name="t2")]) + + t2 = sess.query(T).get(t2.id) + t2.value = True + sess.flush() + eq_(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"), T(value=True, name="t2"), T(value=True, name="t3")]) + t2.value = False + sess.flush() + eq_(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"),T(value=True, name="t3")]) + + +class RowSwitchTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + # parent + Table('t5', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30), nullable=False)) + + # onetomany + Table('t6', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30), nullable=False), + Column('t5id', Integer, ForeignKey('t5.id'),nullable=False)) + + # associated + Table('t7', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30), nullable=False)) + + #manytomany + Table('t5t7', metadata, + Column('t5id', Integer, ForeignKey('t5.id'),nullable=False), + Column('t7id', Integer, ForeignKey('t7.id'),nullable=False)) + + @classmethod + def setup_classes(cls): + class T5(_base.ComparableEntity): + pass + + class T6(_base.ComparableEntity): + pass + + class T7(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def test_onetomany(self): + mapper(T5, t5, properties={ + 't6s':relation(T6, cascade="all, delete-orphan") + }) + mapper(T6, 
t6) + + sess = create_session() + + o5 = T5(data='some t5', id=1) + o5.t6s.append(T6(data='some t6', id=1)) + o5.t6s.append(T6(data='some other t6', id=2)) + + sess.add(o5) + sess.flush() + + assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] + assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1), (2, 'some other t6', 1)] + + o6 = T5(data='some other t5', id=o5.id, t6s=[ + T6(data='third t6', id=3), + T6(data='fourth t6', id=4), + ]) + sess.delete(o5) + sess.add(o6) + sess.flush() + + assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')] + assert list(sess.execute(t6.select(), mapper=T5)) == [(3, 'third t6', 1), (4, 'fourth t6', 1)] + + @testing.resolve_artifact_names + def test_manytomany(self): + mapper(T5, t5, properties={ + 't7s':relation(T7, secondary=t5t7, cascade="all") + }) + mapper(T7, t7) + + sess = create_session() + + o5 = T5(data='some t5', id=1) + o5.t7s.append(T7(data='some t7', id=1)) + o5.t7s.append(T7(data='some other t7', id=2)) + + sess.add(o5) + sess.flush() + + assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] + assert testing.rowset(sess.execute(t5t7.select(), mapper=T5)) == set([(1,1), (1, 2)]) + assert list(sess.execute(t7.select(), mapper=T5)) == [(1, 'some t7'), (2, 'some other t7')] + + o6 = T5(data='some other t5', id=1, t7s=[ + T7(data='third t7', id=3), + T7(data='fourth t7', id=4), + ]) + sess.delete(o5) + sess.add(o6) + sess.flush() + + assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')] + assert list(sess.execute(t7.select(), mapper=T5)) == [(3, 'third t7'), (4, 'fourth t7')] + + @testing.resolve_artifact_names + def test_manytoone(self): + + mapper(T6, t6, properties={ + 't5':relation(T5) + }) + mapper(T5, t5) + + sess = create_session() + + o5 = T6(data='some t6', id=1) + o5.t5 = T5(data='some t5', id=1) + + sess.add(o5) + sess.flush() + + assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] + assert 
list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1)] + + o6 = T6(data='some other t6', id=1, t5=T5(data='some other t5', id=2)) + sess.delete(o5) + sess.delete(o5.t5) + sess.add(o6) + sess.flush() + + assert list(sess.execute(t5.select(), mapper=T5)) == [(2, 'some other t5')] + assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)] + +class InheritingRowSwitchTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('parent', metadata, + Column('id', Integer, primary_key=True), + Column('pdata', String(30)) + ) + Table('child', metadata, + Column('id', Integer, primary_key=True), + Column('pid', Integer, ForeignKey('parent.id')), + Column('cdata', String(30)) + ) + + @classmethod + def setup_classes(cls): + class P(_base.ComparableEntity): + pass + + class C(P): + pass + + @testing.resolve_artifact_names + def test_row_switch_no_child_table(self): + mapper(P, parent) + mapper(C, child, inherits=P) + + sess = create_session() + c1 = C(id=1, pdata='c1', cdata='c1') + sess.add(c1) + sess.flush() + + # establish a row switch between c1 and c2. + # c2 has no value for the "child" table + c2 = C(id=1, pdata='c2') + sess.add(c2) + sess.delete(c1) + + self.assert_sql_execution(testing.db, sess.flush, + CompiledSQL("UPDATE parent SET pdata=:pdata WHERE parent.id = :parent_id", + {'pdata':'c2', 'parent_id':1} + ) + ) + + + +class TransactionTest(_base.MappedTest): + __requires__ = ('deferrable_constraints',) + + __whitelist__ = ('sqlite',) + # sqlite doesn't have deferrable constraints, but it allows them to + # be specified. it'll raise immediately post-INSERT, instead of at + # COMMIT. either way, this test should pass. 
+ + @classmethod + def define_tables(cls, metadata): + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True)) + + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('t1_id', Integer, + ForeignKey('t1.id', deferrable=True, initially='deferred') + )) + @classmethod + def setup_classes(cls): + class T1(_base.ComparableEntity): + pass + + class T2(_base.ComparableEntity): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + orm_mapper(T1, t1) + orm_mapper(T2, t2) + + @testing.resolve_artifact_names + def test_close_transaction_on_commit_fail(self): + session = create_session(autocommit=True) + + # with a deferred constraint, this fails at COMMIT time instead + # of at INSERT time. + session.add(T2(t1_id=123)) + + try: + session.flush() + assert False + except: + # Flush needs to rollback also when commit fails + assert session.transaction is None + + # todo: on 8.3 at least, the failed commit seems to close the cursor? + # needs investigation. leaving in the DDL above now to help verify + # that the new deferrable support on FK isn't involved in this issue. 
+ if testing.against('postgres'): + t1.bind.engine.dispose() + diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py new file mode 100644 index 000000000..06533a243 --- /dev/null +++ b/test/orm/test_utils.py @@ -0,0 +1,239 @@ +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy.orm import interfaces, util +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table +from sqlalchemy.orm import aliased +from sqlalchemy.orm import mapper, create_session + + +from sqlalchemy.test import TestBase, testing + +from test.orm import _fixtures +from sqlalchemy.test.testing import eq_ + + +class ExtensionCarrierTest(TestBase): + def test_basic(self): + carrier = util.ExtensionCarrier() + + assert 'translate_row' not in carrier + assert carrier.translate_row() is interfaces.EXT_CONTINUE + assert 'translate_row' not in carrier + + assert_raises(AttributeError, lambda: carrier.snickysnack) + + class Partial(object): + def __init__(self, marker): + self.marker = marker + def translate_row(self, row): + return self.marker + + carrier.append(Partial('end')) + assert 'translate_row' in carrier + assert carrier.translate_row(None) == 'end' + + carrier.push(Partial('front')) + assert carrier.translate_row(None) == 'front' + + assert 'populate_instance' not in carrier + carrier.append(interfaces.MapperExtension) + assert 'populate_instance' in carrier + + assert carrier.interface + for m in carrier.interface: + assert getattr(interfaces.MapperExtension, m) + +class AliasedClassTest(TestBase): + def point_map(self, cls): + table = Table('point', MetaData(), + Column('id', Integer(), primary_key=True), + Column('x', Integer), + Column('y', Integer)) + mapper(cls, table) + return table + + def test_simple(self): + class Point(object): + pass + table = self.point_map(Point) + + alias = aliased(Point) + + assert alias.id + assert alias.x + assert alias.y + + assert 
Point.id.__clause_element__().table is table + assert alias.id.__clause_element__().table is not table + + def test_notcallable(self): + class Point(object): + pass + table = self.point_map(Point) + alias = aliased(Point) + + assert_raises(TypeError, alias) + + def test_instancemethods(self): + class Point(object): + def zero(self): + self.x, self.y = 0, 0 + + table = self.point_map(Point) + alias = aliased(Point) + + assert Point.zero + assert not getattr(alias, 'zero') + + def test_classmethods(self): + class Point(object): + @classmethod + def max_x(cls): + return 100 + + table = self.point_map(Point) + alias = aliased(Point) + + assert Point.max_x + assert alias.max_x + assert Point.max_x() == alias.max_x() + + def test_simpleproperties(self): + class Point(object): + @property + def max_x(self): + return 100 + + table = self.point_map(Point) + alias = aliased(Point) + + assert Point.max_x + assert Point.max_x != 100 + assert alias.max_x + assert Point.max_x is alias.max_x + + def test_descriptors(self): + class descriptor(object): + """Tortured...""" + def __init__(self, fn): + self.fn = fn + def __get__(self, obj, owner): + if obj is not None: + return self.fn(obj, obj) + else: + return self + def method(self): + return 'method' + + class Point(object): + center = (0, 0) + @descriptor + def thing(self, arg): + return arg.center + + table = self.point_map(Point) + alias = aliased(Point) + + assert Point.thing != (0, 0) + assert Point().thing == (0, 0) + assert Point.thing.method() == 'method' + + assert alias.thing != (0, 0) + assert alias.thing.method() == 'method' + + def test_hybrid_descriptors(self): + from sqlalchemy import Column # override testlib's override + import types + + class MethodDescriptor(object): + def __init__(self, func): + self.func = func + def __get__(self, instance, owner): + if instance is None: + args = (self.func, owner, owner.__class__) + else: + args = (self.func, instance, owner) + return types.MethodType(*args) + + class 
PropertyDescriptor(object): + def __init__(self, fget, fset, fdel): + self.fget = fget + self.fset = fset + self.fdel = fdel + def __get__(self, instance, owner): + if instance is None: + return self.fget(owner) + else: + return self.fget(instance) + def __set__(self, instance, value): + self.fset(instance, value) + def __delete__(self, instance): + self.fdel(instance) + hybrid = MethodDescriptor + def hybrid_property(fget, fset=None, fdel=None): + return PropertyDescriptor(fget, fset, fdel) + + def assert_table(expr, table): + for child in expr.get_children(): + if isinstance(child, Column): + assert child.table is table + + class Point(object): + def __init__(self, x, y): + self.x, self.y = x, y + @hybrid + def left_of(self, other): + return self.x < other.x + + double_x = hybrid_property(lambda self: self.x * 2) + + table = self.point_map(Point) + alias = aliased(Point) + alias_table = alias.x.__clause_element__().table + assert table is not alias_table + + p1 = Point(-10, -10) + p2 = Point(20, 20) + + assert p1.left_of(p2) + assert p1.double_x == -20 + + assert_table(Point.double_x, table) + assert_table(alias.double_x, alias_table) + + assert_table(Point.left_of(p2), table) + assert_table(alias.left_of(p2), alias_table) + +class IdentityKeyTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_identity_key_1(self): + mapper(User, users) + + key = util.identity_key(User, 1) + eq_(key, (User, (1,))) + key = util.identity_key(User, ident=1) + eq_(key, (User, (1,))) + + @testing.resolve_artifact_names + def test_identity_key_2(self): + mapper(User, users) + s = create_session() + u = User(name='u1') + s.add(u) + s.flush() + key = util.identity_key(instance=u) + eq_(key, (User, (u.id,))) + + @testing.resolve_artifact_names + def test_identity_key_3(self): + mapper(User, users) + + row = {users.c.id: 1, users.c.name: "Frank"} + key = util.identity_key(User, row=row) + eq_(key, (User, (1,))) + + diff --git 
a/test/orm/transaction.py b/test/orm/transaction.py deleted file mode 100644 index 0fcd55df3..000000000 --- a/test/orm/transaction.py +++ /dev/null @@ -1,499 +0,0 @@ -import testenv; testenv.configure_for_tests() - -from sqlalchemy import * -from sqlalchemy.orm import attributes -from sqlalchemy import exc as sa_exc -from sqlalchemy.orm import * - -from testlib import testing -from orm import _base -from orm._fixtures import FixtureTest, User, Address, users, addresses - -import gc - -class TransactionTest(FixtureTest): - run_setup_mappers = 'once' - run_inserts = None - session = sessionmaker() - - def setup_mappers(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref='user', - cascade="all, delete-orphan"), - }) - mapper(Address, addresses) - - -class FixtureDataTest(TransactionTest): - run_inserts = 'each' - - def test_attrs_on_rollback(self): - sess = self.session() - u1 = sess.query(User).get(7) - u1.name = 'ed' - sess.rollback() - self.assertEquals(u1.name, 'jack') - - def test_commit_persistent(self): - sess = self.session() - u1 = sess.query(User).get(7) - u1.name = 'ed' - sess.flush() - sess.commit() - self.assertEquals(u1.name, 'ed') - - def test_concurrent_commit_persistent(self): - s1 = self.session() - u1 = s1.query(User).get(7) - u1.name = 'ed' - s1.commit() - - s2 = self.session() - u2 = s2.query(User).get(7) - assert u2.name == 'ed' - u2.name = 'will' - s2.commit() - - assert u1.name == 'will' - -class AutoExpireTest(TransactionTest): - - def test_expunge_pending_on_rollback(self): - sess = self.session() - u2= User(name='newuser') - sess.add(u2) - assert u2 in sess - sess.rollback() - assert u2 not in sess - - def test_trans_pending_cleared_on_commit(self): - sess = self.session() - u2= User(name='newuser') - sess.add(u2) - assert u2 in sess - sess.commit() - assert u2 in sess - u3 = User(name='anotheruser') - sess.add(u3) - sess.rollback() - assert u3 not in sess - assert u2 in sess - - def 
test_update_deleted_on_rollback(self): - s = self.session() - u1 = User(name='ed') - s.add(u1) - s.commit() - - # this actually tests that the delete() operation, - # when cascaded to the "addresses" collection, does not - # trigger a flush (via lazyload) before the cascade is complete. - s.delete(u1) - assert u1 in s.deleted - s.rollback() - assert u1 in s - assert u1 not in s.deleted - - def test_gced_delete_on_rollback(self): - s = self.session() - u1 = User(name='ed') - s.add(u1) - s.commit() - - s.delete(u1) - u1_state = attributes.instance_state(u1) - assert u1_state in s.identity_map.all_states() - assert u1_state in s._deleted - s.flush() - assert u1_state not in s.identity_map.all_states() - assert u1_state not in s._deleted - del u1 - gc.collect() - assert u1_state.obj() is None - - s.rollback() - assert u1_state in s.identity_map.all_states() - u1 = s.query(User).filter_by(name='ed').one() - assert u1_state not in s.identity_map.all_states() - assert s.scalar(users.count()) == 1 - s.delete(u1) - s.flush() - assert s.scalar(users.count()) == 0 - s.commit() - - def test_trans_deleted_cleared_on_rollback(self): - s = self.session() - u1 = User(name='ed') - s.add(u1) - s.commit() - - s.delete(u1) - s.commit() - assert u1 not in s - s.rollback() - assert u1 not in s - - def test_update_deleted_on_rollback_cascade(self): - s = self.session() - u1 = User(name='ed', addresses=[Address(email_address='foo')]) - s.add(u1) - s.commit() - - s.delete(u1) - assert u1 in s.deleted - assert u1.addresses[0] in s.deleted - s.rollback() - assert u1 in s - assert u1 not in s.deleted - assert u1.addresses[0] not in s.deleted - - def test_update_deleted_on_rollback_orphan(self): - s = self.session() - u1 = User(name='ed', addresses=[Address(email_address='foo')]) - s.add(u1) - s.commit() - - a1 = u1.addresses[0] - u1.addresses.remove(a1) - - s.flush() - self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), []) - s.rollback() - assert a1 not in 
s.deleted - assert u1.addresses == [a1] - - def test_commit_pending(self): - sess = self.session() - u1 = User(name='newuser') - sess.add(u1) - sess.flush() - sess.commit() - self.assertEquals(u1.name, 'newuser') - - - def test_concurrent_commit_pending(self): - s1 = self.session() - u1 = User(name='edward') - s1.add(u1) - s1.commit() - - s2 = self.session() - u2 = s2.query(User).filter(User.name=='edward').one() - u2.name = 'will' - s2.commit() - - assert u1.name == 'will' - -class TwoPhaseTest(TransactionTest): - - @testing.requires.two_phase_transactions - def test_rollback_on_prepare(self): - s = self.session(twophase=True) - - u = User(name='ed') - s.add(u) - s.prepare() - s.rollback() - - assert u not in s - -class RollbackRecoverTest(TransactionTest): - - def test_pk_violation(self): - s = self.session() - a1 = Address(email_address='foo') - u1 = User(id=1, name='ed', addresses=[a1]) - s.add(u1) - s.commit() - - a2 = Address(email_address='bar') - u2 = User(id=1, name='jack', addresses=[a2]) - - u1.name = 'edward' - a1.email_address = 'foober' - s.add(u2) - self.assertRaises(sa_exc.FlushError, s.commit) - self.assertRaises(sa_exc.InvalidRequestError, s.commit) - s.rollback() - assert u2 not in s - assert a2 not in s - assert u1 in s - assert a1 in s - assert u1.name == 'ed' - assert a1.email_address == 'foo' - u1.name = 'edward' - a1.email_address = 'foober' - s.commit() - self.assertEquals( - s.query(User).all(), - [User(id=1, name='edward', addresses=[Address(email_address='foober')])] - ) - - @testing.requires.savepoints - def test_pk_violation_with_savepoint(self): - s = self.session() - a1 = Address(email_address='foo') - u1 = User(id=1, name='ed', addresses=[a1]) - s.add(u1) - s.commit() - - a2 = Address(email_address='bar') - u2 = User(id=1, name='jack', addresses=[a2]) - - u1.name = 'edward' - a1.email_address = 'foober' - s.begin_nested() - s.add(u2) - self.assertRaises(sa_exc.FlushError, s.commit) - self.assertRaises(sa_exc.InvalidRequestError, 
s.commit) - s.rollback() - assert u2 not in s - assert a2 not in s - assert u1 in s - assert a1 in s - - s.commit() - assert s.query(User).all() == [User(id=1, name='edward', addresses=[Address(email_address='foober')])] - - -class SavepointTest(TransactionTest): - - @testing.requires.savepoints - def test_savepoint_rollback(self): - s = self.session() - u1 = User(name='ed') - u2 = User(name='jack') - s.add_all([u1, u2]) - - s.begin_nested() - u3 = User(name='wendy') - u4 = User(name='foo') - u1.name = 'edward' - u2.name = 'jackward' - s.add_all([u3, u4]) - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) - s.rollback() - assert u1.name == 'ed' - assert u2.name == 'jack' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) - s.commit() - assert u1.name == 'ed' - assert u2.name == 'jack' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) - - @testing.requires.savepoints - def test_savepoint_delete(self): - s = self.session() - u1 = User(name='ed') - s.add(u1) - s.commit() - self.assertEquals(s.query(User).filter_by(name='ed').count(), 1) - s.begin_nested() - s.delete(u1) - s.commit() - self.assertEquals(s.query(User).filter_by(name='ed').count(), 0) - s.commit() - - @testing.requires.savepoints - def test_savepoint_commit(self): - s = self.session() - u1 = User(name='ed') - u2 = User(name='jack') - s.add_all([u1, u2]) - - s.begin_nested() - u3 = User(name='wendy') - u4 = User(name='foo') - u1.name = 'edward' - u2.name = 'jackward' - s.add_all([u3, u4]) - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) - s.commit() - def go(): - assert u1.name == 'edward' - assert u2.name == 'jackward' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) - self.assert_sql_count(testing.db, go, 1) - - s.commit() - 
self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) - - @testing.requires.savepoints - def test_savepoint_rollback_collections(self): - s = self.session() - u1 = User(name='ed', addresses=[Address(email_address='foo')]) - s.add(u1) - s.commit() - - u1.name='edward' - u1.addresses.append(Address(email_address='bar')) - s.begin_nested() - u2 = User(name='jack', addresses=[Address(email_address='bat')]) - s.add(u2) - self.assertEquals(s.query(User).order_by(User.id).all(), - [ - User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), - User(name='jack', addresses=[Address(email_address='bat')]) - ] - ) - s.rollback() - self.assertEquals(s.query(User).order_by(User.id).all(), - [ - User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), - ] - ) - s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), - [ - User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), - ] - ) - - @testing.requires.savepoints - def test_savepoint_commit_collections(self): - s = self.session() - u1 = User(name='ed', addresses=[Address(email_address='foo')]) - s.add(u1) - s.commit() - - u1.name='edward' - u1.addresses.append(Address(email_address='bar')) - s.begin_nested() - u2 = User(name='jack', addresses=[Address(email_address='bat')]) - s.add(u2) - self.assertEquals(s.query(User).order_by(User.id).all(), - [ - User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), - User(name='jack', addresses=[Address(email_address='bat')]) - ] - ) - s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), - [ - User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), - User(name='jack', addresses=[Address(email_address='bat')]) - ] - ) - s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), - [ - User(name='edward', 
addresses=[Address(email_address='foo'), Address(email_address='bar')]), - User(name='jack', addresses=[Address(email_address='bat')]) - ] - ) - - @testing.requires.savepoints - def test_expunge_pending_on_rollback(self): - sess = self.session() - - sess.begin_nested() - u2= User(name='newuser') - sess.add(u2) - assert u2 in sess - sess.rollback() - assert u2 not in sess - - @testing.requires.savepoints - def test_update_deleted_on_rollback(self): - s = self.session() - u1 = User(name='ed') - s.add(u1) - s.commit() - - s.begin_nested() - s.delete(u1) - assert u1 in s.deleted - s.rollback() - assert u1 in s - assert u1 not in s.deleted - - -class AccountingFlagsTest(TransactionTest): - def test_no_expire_on_commit(self): - sess = sessionmaker(expire_on_commit=False)() - u1 = User(name='ed') - sess.add(u1) - sess.commit() - - testing.db.execute(users.update(users.c.name=='ed').values(name='edward')) - - assert u1.name == 'ed' - sess.expire_all() - assert u1.name == 'edward' - - def test_rollback_no_accounting(self): - sess = sessionmaker(_enable_transaction_accounting=False)() - u1 = User(name='ed') - sess.add(u1) - sess.commit() - - u1.name = 'edwardo' - sess.rollback() - - testing.db.execute(users.update(users.c.name=='ed').values(name='edward')) - - assert u1.name == 'edwardo' - sess.expire_all() - assert u1.name == 'edward' - - def test_commit_no_accounting(self): - sess = sessionmaker(_enable_transaction_accounting=False)() - u1 = User(name='ed') - sess.add(u1) - sess.commit() - - u1.name = 'edwardo' - sess.rollback() - - testing.db.execute(users.update(users.c.name=='ed').values(name='edward')) - - assert u1.name == 'edwardo' - sess.commit() - - assert testing.db.execute(select([users.c.name])).fetchall() == [('edwardo',)] - assert u1.name == 'edwardo' - - sess.delete(u1) - sess.commit() - - def test_preflush_no_accounting(self): - sess = sessionmaker(_enable_transaction_accounting=False, autocommit=True)() - u1 = User(name='ed') - sess.add(u1) - sess.flush() - 
- sess.begin() - u1.name = 'edwardo' - u2 = User(name="some other user") - sess.add(u2) - - sess.rollback() - - sess.begin() - assert testing.db.execute(select([users.c.name])).fetchall() == [('ed',)] - - -class AutoCommitTest(TransactionTest): - def test_begin_nested_requires_trans(self): - sess = create_session(autocommit=True) - self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested) - - def test_begin_preflush(self): - sess = create_session(autocommit=True) - - u1 = User(name='ed') - sess.add(u1) - - sess.begin() - u2 = User(name='some other user') - sess.add(u2) - sess.rollback() - assert u2 not in sess - assert u1 in sess - assert sess.query(User).filter_by(name='ed').one() is u1 - - - - -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py deleted file mode 100644 index c5e3afd01..000000000 --- a/test/orm/unitofwork.py +++ /dev/null @@ -1,2336 +0,0 @@ -# coding: utf-8 -"""Tests unitofwork operations.""" - -import testenv; testenv.configure_for_tests() -import datetime -import operator -from sqlalchemy.orm import mapper as orm_mapper - -from testlib import engines, sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, literal_column -from testlib.sa.orm import mapper, relation, create_session, column_property -from testlib.testing import eq_, ne_ -from orm import _base, _fixtures -from engine import _base as engine_base -import pickleable -from testlib.assertsql import AllOf, CompiledSQL -import gc - -class UnitOfWorkTest(object): - pass - -class HistoryTest(_fixtures.FixtureTest): - run_inserts = None - - def setup_classes(self): - class User(_base.ComparableEntity): - pass - class Address(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_backref(self): - am = mapper(Address, addresses) - m = mapper(User, users, properties=dict( - addresses = relation(am, backref='user', lazy=False))) - - session = create_session(autocommit=False) - - u = 
User(name='u1') - a = Address(email_address='u1@e') - a.user = u - session.add(u) - - self.assert_(u.addresses == [a]) - session.commit() - session.expunge_all() - - u = session.query(m).one() - assert u.addresses[0].user == u - session.close() - - -class VersioningTest(_base.MappedTest): - def define_tables(self, metadata): - Table('version_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('version_id', Integer, nullable=False), - Column('value', String(40), nullable=False)) - - def setup_classes(self): - class Foo(_base.ComparableEntity): - pass - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_basic(self): - mapper(Foo, version_table, version_id_col=version_table.c.version_id) - - s1 = create_session(autocommit=False) - f1 = Foo(value='f1') - f2 = Foo(value='f2') - s1.add_all((f1, f2)) - s1.commit() - - f1.value='f1rev2' - s1.commit() - - s2 = create_session(autocommit=False) - f1_s = s2.query(Foo).get(f1.id) - f1_s.value='f1rev3' - s2.commit() - - f1.value='f1rev3mine' - - # Only dialects with a sane rowcount can detect the - # ConcurrentModificationError - if testing.db.dialect.supports_sane_rowcount: - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit) - s1.rollback() - else: - s1.commit() - - # new in 0.5 ! 
dont need to close the session - f1 = s1.query(Foo).get(f1.id) - f2 = s1.query(Foo).get(f2.id) - - f1_s.value='f1rev4' - s2.commit() - - s1.delete(f1) - s1.delete(f2) - - if testing.db.dialect.supports_sane_multi_rowcount: - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit) - else: - s1.commit() - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_versioncheck(self): - """query.with_lockmode performs a 'version check' on an already loaded instance""" - - s1 = create_session(autocommit=False) - - mapper(Foo, version_table, version_id_col=version_table.c.version_id) - f1s1 = Foo(value='f1 value') - s1.add(f1s1) - s1.commit() - - s2 = create_session(autocommit=False) - f1s2 = s2.query(Foo).get(f1s1.id) - f1s2.value='f1 new value' - s2.commit() - - # load, version is wrong - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id) - - # reload it - s1.query(Foo).populate_existing().get(f1s1.id) - # now assert version OK - s1.query(Foo).with_lockmode('read').get(f1s1.id) - - # assert brand new load is OK too - s1.close() - s1.query(Foo).with_lockmode('read').get(f1s1.id) - - @engines.close_open_connections - @testing.resolve_artifact_names - def test_noversioncheck(self): - """test query.with_lockmode works when the mapper has no version id col""" - s1 = create_session(autocommit=False) - mapper(Foo, version_table) - f1s1 = Foo(value="foo", version_id=0) - s1.add(f1s1) - s1.commit() - - s2 = create_session(autocommit=False) - f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id) - assert f1s2.id == f1s1.id - assert f1s2.value == f1s1.value - -class UnicodeTest(_base.MappedTest): - __requires__ = ('unicode_connections',) - - def define_tables(self, metadata): - Table('uni_t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('txt', sa.Unicode(50), unique=True)) - Table('uni_t2', metadata, - Column('id', Integer, 
primary_key=True, - test_needs_autoincrement=True), - Column('txt', sa.Unicode(50), ForeignKey('uni_t1'))) - - def setup_classes(self): - class Test(_base.BasicEntity): - pass - class Test2(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_basic(self): - mapper(Test, uni_t1) - - txt = u"\u0160\u0110\u0106\u010c\u017d" - t1 = Test(id=1, txt=txt) - self.assert_(t1.txt == txt) - - session = create_session(autocommit=False) - session.add(t1) - session.commit() - - self.assert_(t1.txt == txt) - - @testing.resolve_artifact_names - def test_relation(self): - mapper(Test, uni_t1, properties={ - 't2s': relation(Test2)}) - mapper(Test2, uni_t2) - - txt = u"\u0160\u0110\u0106\u010c\u017d" - t1 = Test(txt=txt) - t1.t2s.append(Test2()) - t1.t2s.append(Test2()) - session = create_session(autocommit=False) - session.add(t1) - session.commit() - session.close() - - session = create_session() - t1 = session.query(Test).filter_by(id=t1.id).one() - assert len(t1.t2s) == 2 - -class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): - __requires__ = ('unicode_connections', 'unicode_ddl',) - - def create_engine(self): - return engines.utf8_engine() - - def define_tables(self, metadata): - t1 = Table('unitable1', metadata, - Column(u'méil', Integer, primary_key=True, key='a', test_needs_autoincrement=True), - Column(u'\u6e2c\u8a66', Integer, key='b'), - Column('type', String(20)), - test_needs_fk=True, - test_needs_autoincrement=True) - t2 = Table(u'Unitéble2', metadata, - Column(u'méil', Integer, primary_key=True, key="cc", test_needs_autoincrement=True), - Column(u'\u6e2c\u8a66', Integer, - ForeignKey(u'unitable1.a'), key="d"), - Column(u'\u6e2c\u8a66_2', Integer, key="e"), - test_needs_fk=True, - test_needs_autoincrement=True) - - self.tables['t1'] = t1 - self.tables['t2'] = t2 - - def setUpAll(self): - engine_base.AltEngineTest.setUpAll(self) - _base.MappedTest.setUpAll(self) - - def tearDownAll(self): - _base.MappedTest.tearDownAll(self) - 
engine_base.AltEngineTest.tearDownAll(self) - - @testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.') - @testing.resolve_artifact_names - def test_mapping(self): - class A(_base.ComparableEntity): - pass - class B(_base.ComparableEntity): - pass - - mapper(A, t1, properties={ - 't2s':relation(B)}) - mapper(B, t2) - - a1 = A() - b1 = B() - a1.t2s.append(b1) - - session = create_session() - session.add(a1) - session.flush() - session.expunge_all() - - new_a1 = session.query(A).filter(t1.c.a == a1.a).one() - assert new_a1.a == a1.a - assert new_a1.t2s[0].d == b1.d - session.expunge_all() - - new_a1 = (session.query(A).options(sa.orm.eagerload('t2s')). - filter(t1.c.a == a1.a)).one() - assert new_a1.a == a1.a - assert new_a1.t2s[0].d == b1.d - session.expunge_all() - - new_a1 = session.query(A).filter(A.a == a1.a).one() - assert new_a1.a == a1.a - assert new_a1.t2s[0].d == b1.d - session.expunge_all() - - @testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.') - @testing.resolve_artifact_names - def test_inheritance_mapping(self): - class A(_base.ComparableEntity): - pass - class B(A): - pass - - mapper(A, t1, - polymorphic_on=t1.c.type, - polymorphic_identity='a') - mapper(B, t2, - inherits=A, - polymorphic_identity='b') - a1 = A(b=5) - b1 = B(e=7) - - session = create_session() - session.add_all((a1, b1)) - session.flush() - session.expunge_all() - - eq_([A(b=5), B(e=7)], session.query(A).all()) - - -class MutableTypesTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('mutable_t', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', sa.PickleType), - Column('val', sa.Unicode(30))) - - def setup_classes(self): - class Foo(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Foo, mutable_t) - - @testing.resolve_artifact_names - def test_basic(self): - """Changes are detected 
for types marked as MutableType.""" - - f1 = Foo() - f1.data = pickleable.Bar(4,5) - - session = create_session() - session.add(f1) - session.flush() - session.expunge_all() - - f2 = session.query(Foo).filter_by(id=f1.id).one() - assert 'data' in sa.orm.attributes.instance_state(f2).unmodified - eq_(f2.data, f1.data) - - f2.data.y = 19 - assert f2 in session.dirty - assert 'data' not in sa.orm.attributes.instance_state(f2).unmodified - session.flush() - session.expunge_all() - - f3 = session.query(Foo).filter_by(id=f1.id).one() - ne_(f3.data,f1.data) - eq_(f3.data, pickleable.Bar(4, 19)) - - @testing.resolve_artifact_names - def test_mutable_changes(self): - """Mutable changes are detected or not detected correctly""" - - f1 = Foo() - f1.data = pickleable.Bar(4,5) - f1.val = u'hi' - - session = create_session(autocommit=False) - session.add(f1) - session.commit() - - bind = self.metadata.bind - - self.sql_count_(0, session.commit) - f1.val = u'someothervalue' - self.assert_sql(bind, session.commit, [ - ("UPDATE mutable_t SET val=:val " - "WHERE mutable_t.id = :mutable_t_id", - {'mutable_t_id': f1.id, 'val': u'someothervalue'})]) - - f1.val = u'hi' - f1.data.x = 9 - self.assert_sql(bind, session.commit, [ - ("UPDATE mutable_t SET data=:data, val=:val " - "WHERE mutable_t.id = :mutable_t_id", - {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})]) - - - @testing.resolve_artifact_names - def test_resurrect(self): - f1 = Foo() - f1.data = pickleable.Bar(4,5) - f1.val = u'hi' - - session = create_session(autocommit=False) - session.add(f1) - session.commit() - - f1.data.y = 19 - del f1 - - gc.collect() - assert len(session.identity_map) == 1 - - session.commit() - - assert session.query(Foo).one().data == pickleable.Bar(4, 19) - - - @testing.uses_deprecated() - @testing.resolve_artifact_names - def test_nocomparison(self): - """Changes are detected on MutableTypes lacking an __eq__ method.""" - - f1 = Foo() - f1.data = pickleable.BarWithoutCompare(4,5) - session = 
create_session(autocommit=False) - session.add(f1) - session.commit() - - self.sql_count_(0, session.commit) - session.close() - - session = create_session(autocommit=False) - f2 = session.query(Foo).filter_by(id=f1.id).one() - self.sql_count_(0, session.commit) - - f2.data.y = 19 - self.sql_count_(1, session.commit) - session.close() - - session = create_session(autocommit=False) - f3 = session.query(Foo).filter_by(id=f1.id).one() - eq_((f3.data.x, f3.data.y), (4,19)) - self.sql_count_(0, session.commit) - session.close() - - @testing.resolve_artifact_names - def test_unicode(self): - """Equivalent Unicode values are not flagged as changed.""" - - f1 = Foo(val=u'hi') - - session = create_session(autocommit=False) - session.add(f1) - session.commit() - session.expunge_all() - - f1 = session.query(Foo).get(f1.id) - f1.val = u'hi' - self.sql_count_(0, session.commit) - - -class PickledDicts(_base.MappedTest): - - def define_tables(self, metadata): - Table('mutable_t', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', sa.PickleType(comparator=operator.eq))) - - def setup_classes(self): - class Foo(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(Foo, mutable_t) - - @testing.resolve_artifact_names - def test_dicts(self): - """Dictionaries may not pickle the same way twice.""" - - f1 = Foo() - f1.data = [ { - 'personne': {'nom': u'Smith', - 'pers_id': 1, - 'prenom': u'john', - 'civilite': u'Mr', - 'int_3': False, - 'int_2': False, - 'int_1': u'23', - 'VenSoir': True, - 'str_1': u'Test', - 'SamMidi': False, - 'str_2': u'chien', - 'DimMidi': False, - 'SamSoir': True, - 'SamAcc': False} } ] - - session = create_session(autocommit=False) - session.add(f1) - session.commit() - - self.sql_count_(0, session.commit) - - f1.data = [ { - 'personne': {'nom': u'Smith', - 'pers_id': 1, - 'prenom': u'john', - 'civilite': u'Mr', - 'int_3': False, - 'int_2': False, - 'int_1': u'23', - 
'VenSoir': True, - 'str_1': u'Test', - 'SamMidi': False, - 'str_2': u'chien', - 'DimMidi': False, - 'SamSoir': True, - 'SamAcc': False} } ] - - self.sql_count_(0, session.commit) - - f1.data[0]['personne']['VenSoir']= False - self.sql_count_(1, session.commit) - - session.expunge_all() - f = session.query(Foo).get(f1.id) - eq_(f.data, - [ { - 'personne': {'nom': u'Smith', - 'pers_id': 1, - 'prenom': u'john', - 'civilite': u'Mr', - 'int_3': False, - 'int_2': False, - 'int_1': u'23', - 'VenSoir': False, - 'str_1': u'Test', - 'SamMidi': False, - 'str_2': u'chien', - 'DimMidi': False, - 'SamSoir': True, - 'SamAcc': False} } ]) - - -class PKTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('multipk1', metadata, - Column('multi_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('multi_rev', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('value', String(100))) - - Table('multipk2', metadata, - Column('pk_col_1', String(30), primary_key=True), - Column('pk_col_2', String(30), primary_key=True), - Column('data', String(30))) - Table('multipk3', metadata, - Column('pri_code', String(30), key='primary', primary_key=True), - Column('sec_code', String(30), key='secondary', primary_key=True), - Column('date_assigned', sa.Date, key='assigned', primary_key=True), - Column('data', String(30))) - - def setup_classes(self): - class Entry(_base.BasicEntity): - pass - - # not supported on sqlite since sqlite's auto-pk generation only works with - # single column primary keys - @testing.fails_on('sqlite', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_primary_key(self): - mapper(Entry, multipk1) - - e = Entry(name='entry1', value='this is entry 1', multi_rev=2) - - session = create_session() - session.add(e) - session.flush() - session.expunge_all() - - e2 = session.query(Entry).get((e.multi_id, 2)) - self.assert_(e is not e2) - state = sa.orm.attributes.instance_state(e) - state2 = 
sa.orm.attributes.instance_state(e2) - eq_(state.key, state2.key) - - # this one works with sqlite since we are manually setting up pk values - @testing.resolve_artifact_names - def test_manual_pk(self): - mapper(Entry, multipk2) - - e = Entry(pk_col_1='pk1', pk_col_2='pk1_related', data='im the data') - - session = create_session() - session.add(e) - session.flush() - - @testing.resolve_artifact_names - def test_key_pks(self): - mapper(Entry, multipk3) - - e = Entry(primary= 'pk1', secondary='pk2', - assigned=datetime.date.today(), data='some more data') - - session = create_session() - session.add(e) - session.flush() - - -class ForeignPKTest(_base.MappedTest): - """Detection of the relationship direction on PK joins.""" - - def define_tables(self, metadata): - Table("people", metadata, - Column('person', String(10), primary_key=True), - Column('firstname', String(10)), - Column('lastname', String(10))) - - Table("peoplesites", metadata, - Column('person', String(10), ForeignKey("people.person"), - primary_key=True), - Column('site', String(10))) - - def setup_classes(self): - class Person(_base.BasicEntity): - pass - class PersonSite(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_basic(self): - m1 = mapper(PersonSite, peoplesites) - m2 = mapper(Person, people, properties={ - 'sites' : relation(PersonSite)}) - - sa.orm.compile_mappers() - eq_(list(m2.get_property('sites').synchronize_pairs), - [(people.c.person, peoplesites.c.person)]) - - p = Person(person='im the key', firstname='asdf') - ps = PersonSite(site='asdf') - p.sites.append(ps) - - session = create_session() - session.add(p) - session.flush() - - p_count = people.count(people.c.person=='im the key').scalar() - eq_(p_count, 1) - eq_(peoplesites.count(peoplesites.c.person=='im the key').scalar(), 1) - - -class ClauseAttributesTest(_base.MappedTest): - - def define_tables(self, metadata): - Table('users_t', metadata, - Column('id', Integer, primary_key=True, - 
test_needs_autoincrement=True), - Column('name', String(30)), - Column('counter', Integer, default=1)) - - def setup_classes(self): - class User(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - mapper(User, users_t) - - @testing.resolve_artifact_names - def test_update(self): - u = User(name='test') - - session = create_session() - session.add(u) - session.flush() - - eq_(u.counter, 1) - u.counter = User.counter + 1 - session.flush() - - def go(): - assert (u.counter == 2) is True # ensure its not a ClauseElement - self.sql_count_(1, go) - - @testing.resolve_artifact_names - def test_multi_update(self): - u = User(name='test') - - session = create_session() - session.add(u) - session.flush() - - eq_(u.counter, 1) - u.name = 'test2' - u.counter = User.counter + 1 - session.flush() - - def go(): - eq_(u.name, 'test2') - assert (u.counter == 2) is True - self.sql_count_(1, go) - - session.expunge_all() - u = session.query(User).get(u.id) - eq_(u.name, 'test2') - eq_(u.counter, 2) - - @testing.resolve_artifact_names - def test_insert(self): - u = User(name='test', counter=sa.select([5])) - - session = create_session() - session.add(u) - session.flush() - - assert (u.counter == 5) is True - - -class PassiveDeletesTest(_base.MappedTest): - __requires__ = ('foreign_keys',) - - def define_tables(self, metadata): - Table('mytable', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('data', String(30)), - test_needs_fk=True) - - Table('myothertable', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('parent_id', Integer), - Column('data', String(30)), - sa.ForeignKeyConstraint(['parent_id'], - ['mytable.id'], - ondelete="CASCADE"), - test_needs_fk=True) - - def setup_classes(self): - class MyClass(_base.BasicEntity): - pass - class MyOtherClass(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_basic(self): - 
mapper(MyOtherClass, myothertable) - mapper(MyClass, mytable, properties={ - 'children':relation(MyOtherClass, - passive_deletes=True, - cascade="all")}) - session = create_session() - mc = MyClass() - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - - session.add(mc) - session.flush() - session.expunge_all() - - assert myothertable.count().scalar() == 4 - mc = session.query(MyClass).get(mc.id) - session.delete(mc) - session.flush() - - assert mytable.count().scalar() == 0 - assert myothertable.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_backwards_pd(self): - # the unusual scenario where a trigger or something might be deleting - # a many-to-one on deletion of the parent row - mapper(MyOtherClass, myothertable, properties={ - 'myclass':relation(MyClass, cascade="all, delete", passive_deletes=True) - }) - mapper(MyClass, mytable) - - session = create_session() - mc = MyClass() - mco = MyOtherClass() - mco.myclass = mc - session.add(mco) - session.flush() - - assert mytable.count().scalar() == 1 - assert myothertable.count().scalar() == 1 - - session.expire(mco, ['myclass']) - session.delete(mco) - session.flush() - - assert mytable.count().scalar() == 1 - assert myothertable.count().scalar() == 0 - -class ExtraPassiveDeletesTest(_base.MappedTest): - __requires__ = ('foreign_keys',) - - def define_tables(self, metadata): - Table('mytable', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('data', String(30)), - test_needs_fk=True) - - Table('myothertable', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('parent_id', Integer), - Column('data', String(30)), - # no CASCADE, the same as ON DELETE RESTRICT - sa.ForeignKeyConstraint(['parent_id'], - ['mytable.id']), - test_needs_fk=True) - - def setup_classes(self): - class MyClass(_base.BasicEntity): - pass - 
class MyOtherClass(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_assertions(self): - mapper(MyOtherClass, myothertable) - try: - mapper(MyClass, mytable, properties={ - 'children':relation(MyOtherClass, - passive_deletes='all', - cascade="all")}) - assert False - except sa.exc.ArgumentError, e: - eq_(str(e), - "Can't set passive_deletes='all' in conjunction with 'delete' " - "or 'delete-orphan' cascade") - - @testing.resolve_artifact_names - def test_extra_passive(self): - mapper(MyOtherClass, myothertable) - mapper(MyClass, mytable, properties={ - 'children': relation(MyOtherClass, - passive_deletes='all', - cascade="save-update")}) - - session = create_session() - mc = MyClass() - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - mc.children.append(MyOtherClass()) - session.add(mc) - session.flush() - session.expunge_all() - - assert myothertable.count().scalar() == 4 - mc = session.query(MyClass).get(mc.id) - session.delete(mc) - self.assertRaises(sa.exc.DBAPIError, session.flush) - - @testing.resolve_artifact_names - def test_extra_passive_2(self): - mapper(MyOtherClass, myothertable) - mapper(MyClass, mytable, properties={ - 'children': relation(MyOtherClass, - passive_deletes='all', - cascade="save-update")}) - - session = create_session() - mc = MyClass() - mc.children.append(MyOtherClass()) - session.add(mc) - session.flush() - session.expunge_all() - - assert myothertable.count().scalar() == 1 - - mc = session.query(MyClass).get(mc.id) - session.delete(mc) - mc.children[0].data = 'some new data' - self.assertRaises(sa.exc.DBAPIError, session.flush) - - -class DefaultTest(_base.MappedTest): - """Exercise mappings on columns with DefaultGenerators. 
- - Tests that when saving objects whose table contains DefaultGenerators, - either python-side, preexec or database-side, the newly saved instances - receive all the default values either through a post-fetch or getting the - pre-exec'ed defaults back from the engine. - - """ - - def define_tables(self, metadata): - use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql') - - if use_string_defaults: - hohotype = String(30) - hohoval = "im hoho" - althohoval = "im different hoho" - else: - hohotype = Integer - hohoval = 9 - althohoval = 15 - - self.other_artifacts['hohoval'] = hohoval - self.other_artifacts['althohoval'] = althohoval - - dt = Table('default_t', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('hoho', hohotype, server_default=str(hohoval)), - Column('counter', Integer, default=sa.func.char_length("1234567")), - Column('foober', String(30), default="im foober", - onupdate="im the update")) - - st = Table('secondary_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) - - if testing.against('postgres', 'oracle'): - dt.append_column( - Column('secondary_id', Integer, sa.Sequence('sec_id_seq'), - unique=True)) - st.append_column( - Column('fk_val', Integer, - ForeignKey('default_t.secondary_id'))) - elif testing.against('mssql'): - st.append_column( - Column('fk_val', Integer, - ForeignKey('default_t.id'))) - else: - st.append_column( - Column('hoho', hohotype, ForeignKey('default_t.hoho'))) - - def setup_classes(self): - class Hoho(_base.ComparableEntity): - pass - class Secondary(_base.ComparableEntity): - pass - - @testing.fails_on('firebird', 'Data type unknown on the parameter') - @testing.resolve_artifact_names - def test_insert(self): - mapper(Hoho, default_t) - - h1 = Hoho(hoho=althohoval) - h2 = Hoho(counter=12) - h3 = Hoho(hoho=althohoval, counter=12) - h4 = Hoho() - h5 = Hoho(foober='im the new foober') - - session = 
create_session(autocommit=False) - session.add_all((h1, h2, h3, h4, h5)) - session.commit() - - eq_(h1.hoho, althohoval) - eq_(h3.hoho, althohoval) - - def go(): - # test deferred load of attribues, one select per instance - self.assert_(h2.hoho == h4.hoho == h5.hoho == hohoval) - self.sql_count_(3, go) - - def go(): - self.assert_(h1.counter == h4.counter == h5.counter == 7) - self.sql_count_(1, go) - - def go(): - self.assert_(h3.counter == h2.counter == 12) - self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') - self.assert_(h5.foober == 'im the new foober') - self.sql_count_(0, go) - - session.expunge_all() - - (h1, h2, h3, h4, h5) = session.query(Hoho).order_by(Hoho.id).all() - - eq_(h1.hoho, althohoval) - eq_(h3.hoho, althohoval) - self.assert_(h2.hoho == h4.hoho == h5.hoho == hohoval) - self.assert_(h3.counter == h2.counter == 12) - self.assert_(h1.counter == h4.counter == h5.counter == 7) - self.assert_(h2.foober == h3.foober == h4.foober == 'im foober') - eq_(h5.foober, 'im the new foober') - - @testing.fails_on('firebird', 'Data type unknown on the parameter') - @testing.resolve_artifact_names - def test_eager_defaults(self): - mapper(Hoho, default_t, eager_defaults=True) - - h1 = Hoho() - - session = create_session() - session.add(h1) - session.flush() - - self.sql_count_(0, lambda: eq_(h1.hoho, hohoval)) - - @testing.resolve_artifact_names - def test_insert_nopostfetch(self): - # populates from the FetchValues explicitly so there is no - # "post-update" - mapper(Hoho, default_t) - - h1 = Hoho(hoho="15", counter="15") - session = create_session() - session.add(h1) - session.flush() - - def go(): - eq_(h1.hoho, "15") - eq_(h1.counter, "15") - eq_(h1.foober, "im foober") - self.sql_count_(0, go) - - @testing.fails_on('firebird', 'Data type unknown on the parameter') - @testing.resolve_artifact_names - def test_update(self): - mapper(Hoho, default_t) - - h1 = Hoho() - session = create_session() - session.add(h1) - session.flush() - - 
eq_(h1.foober, 'im foober') - h1.counter = 19 - session.flush() - eq_(h1.foober, 'im the update') - - @testing.fails_on('firebird', 'Data type unknown on the parameter') - @testing.resolve_artifact_names - def test_used_in_relation(self): - """A server-side default can be used as the target of a foreign key""" - - mapper(Hoho, default_t, properties={ - 'secondaries':relation(Secondary)}) - mapper(Secondary, secondary_table) - - h1 = Hoho() - s1 = Secondary(data='s1') - h1.secondaries.append(s1) - - session = create_session() - session.add(h1) - session.flush() - session.expunge_all() - - eq_(session.query(Hoho).get(h1.id), - Hoho(hoho=hohoval, - secondaries=[ - Secondary(data='s1')])) - - h1 = session.query(Hoho).get(h1.id) - h1.secondaries.append(Secondary(data='s2')) - session.flush() - session.expunge_all() - - eq_(session.query(Hoho).get(h1.id), - Hoho(hoho=hohoval, - secondaries=[ - Secondary(data='s1'), - Secondary(data='s2')])) - -class ColumnPropertyTest(_base.MappedTest): - def define_tables(self, metadata): - Table('data', metadata, - Column('id', Integer, primary_key=True), - Column('a', String(50)), - Column('b', String(50)) - ) - - Table('subdata', metadata, - Column('id', Integer, ForeignKey('data.id'), primary_key=True), - Column('c', String(50)), - ) - - def setup_mappers(self): - class Data(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_refreshes(self): - mapper(Data, data, properties={ - 'aplusb':column_property(data.c.a + literal_column("' '") + data.c.b) - }) - self._test() - - @testing.resolve_artifact_names - def test_refreshes_post_init(self): - m = mapper(Data, data) - m.add_property('aplusb', column_property(data.c.a + literal_column("' '") + data.c.b)) - self._test() - - @testing.resolve_artifact_names - def test_with_inheritance(self): - class SubData(Data): - pass - mapper(Data, data, properties={ - 'aplusb':column_property(data.c.a + literal_column("' '") + data.c.b) - }) - mapper(SubData, subdata, 
inherits=Data) - - sess = create_session() - sd1 = SubData(a="hello", b="there", c="hi") - sess.add(sd1) - sess.flush() - self.assertEquals(sd1.aplusb, "hello there") - - @testing.resolve_artifact_names - def _test(self): - sess = create_session() - - d1 = Data(a="hello", b="there") - sess.add(d1) - sess.flush() - - self.assertEquals(d1.aplusb, "hello there") - - d1.b = "bye" - sess.flush() - self.assertEquals(d1.aplusb, "hello bye") - - d1.b = 'foobar' - d1.aplusb = 'im setting this explicitly' - sess.flush() - self.assertEquals(d1.aplusb, "im setting this explicitly") - -class OneToManyTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_one_to_many_1(self): - """Basic save of one to many.""" - - m = mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=True) - )) - u = User(name= 'one2manytester') - a = Address(email_address='one2many@test.org') - u.addresses.append(a) - - a2 = Address(email_address='lala@test.org') - u.addresses.append(a2) - - session = create_session() - session.add(u) - session.flush() - - user_rows = users.select(users.c.id.in_([u.id])).execute().fetchall() - eq_(user_rows[0].values(), [u.id, 'one2manytester']) - - address_rows = addresses.select( - addresses.c.id.in_([a.id, a2.id]), - order_by=[addresses.c.email_address]).execute().fetchall() - eq_(address_rows[0].values(), [a2.id, u.id, 'lala@test.org']) - eq_(address_rows[1].values(), [a.id, u.id, 'one2many@test.org']) - - userid = u.id - addressid = a2.id - - a2.email_address = 'somethingnew@foo.com' - - session.flush() - - address_rows = addresses.select( - addresses.c.id == addressid).execute().fetchall() - eq_(address_rows[0].values(), - [addressid, userid, 'somethingnew@foo.com']) - self.assert_(u.id == userid and a2.id == addressid) - - @testing.resolve_artifact_names - def test_one_to_many_2(self): - """Modifying the child items of an object.""" - - m = mapper(User, users, properties=dict( - 
addresses = relation(mapper(Address, addresses), lazy=True))) - - u1 = User(name='user1') - u1.addresses = [] - a1 = Address(email_address='emailaddress1') - u1.addresses.append(a1) - - u2 = User(name='user2') - u2.addresses = [] - a2 = Address(email_address='emailaddress2') - u2.addresses.append(a2) - - a3 = Address(email_address='emailaddress3') - - session = create_session() - session.add_all((u1, u2, a3)) - session.flush() - - # modify user2 directly, append an address to user1. - # upon commit, user2 should be updated, user1 should not - # both address1 and address3 should be updated - u2.name = 'user2modified' - u1.addresses.append(a3) - del u1.addresses[0] - - self.assert_sql(testing.db, session.flush, [ - ("UPDATE users SET name=:name " - "WHERE users.id = :users_id", - {'users_id': u2.id, 'name': 'user2modified'}), - - ("UPDATE addresses SET user_id=:user_id " - "WHERE addresses.id = :addresses_id", - {'user_id': None, 'addresses_id': a1.id}), - - ("UPDATE addresses SET user_id=:user_id " - "WHERE addresses.id = :addresses_id", - {'user_id': u1.id, 'addresses_id': a3.id})]) - - @testing.resolve_artifact_names - def test_child_move(self): - """Moving a child from one parent to another, with a delete. - - Tests that deleting the first parent properly updates the child with - the new parent. This tests the 'trackparent' option in the attributes - module. 
- - """ - m = mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=True))) - - u1 = User(name='user1') - u2 = User(name='user2') - a = Address(email_address='address1') - u1.addresses.append(a) - - session = create_session() - session.add_all((u1, u2)) - session.flush() - - del u1.addresses[0] - u2.addresses.append(a) - session.delete(u1) - - session.flush() - session.expunge_all() - - u2 = session.query(User).get(u2.id) - eq_(len(u2.addresses), 1) - - @testing.resolve_artifact_names - def test_child_move_2(self): - m = mapper(User, users, properties=dict( - addresses = relation(mapper(Address, addresses), lazy=True))) - - u1 = User(name='user1') - u2 = User(name='user2') - a = Address(email_address='address1') - u1.addresses.append(a) - - session = create_session() - session.add_all((u1, u2)) - session.flush() - - del u1.addresses[0] - u2.addresses.append(a) - - session.flush() - session.expunge_all() - - u2 = session.query(User).get(u2.id) - eq_(len(u2.addresses), 1) - - @testing.resolve_artifact_names - def test_o2m_delete_parent(self): - m = mapper(User, users, properties=dict( - address = relation(mapper(Address, addresses), - lazy=True, - uselist=False))) - - u = User(name='one2onetester') - a = Address(email_address='myonlyaddress@foo.com') - u.address = a - - session = create_session() - session.add(u) - session.flush() - - session.delete(u) - session.flush() - - assert a.id is not None - assert a.user_id is None - assert sa.orm.attributes.instance_state(a).key in session.identity_map - assert sa.orm.attributes.instance_state(u).key not in session.identity_map - - @testing.resolve_artifact_names - def test_one_to_one(self): - m = mapper(User, users, properties=dict( - address = relation(mapper(Address, addresses), - lazy=True, - uselist=False))) - - u = User(name='one2onetester') - u.address = Address(email_address='myonlyaddress@foo.com') - - session = create_session() - session.add(u) - session.flush() - - u.name = 
'imnew' - session.flush() - - u.address.email_address = 'imnew@foo.com' - session.flush() - - @testing.resolve_artifact_names - def test_bidirectional(self): - m1 = mapper(User, users) - m2 = mapper(Address, addresses, properties=dict( - user = relation(m1, lazy=False, backref='addresses'))) - - - u = User(name='test') - a = Address(email_address='testaddress', user=u) - - session = create_session() - session.add(u) - session.flush() - session.delete(u) - session.flush() - - @testing.resolve_artifact_names - def test_double_relation(self): - m2 = mapper(Address, addresses) - m = mapper(User, users, properties={ - 'boston_addresses' : relation(m2, primaryjoin= - sa.and_(users.c.id==addresses.c.user_id, - addresses.c.email_address.like('%boston%'))), - 'newyork_addresses' : relation(m2, primaryjoin= - sa.and_(users.c.id==addresses.c.user_id, - addresses.c.email_address.like('%newyork%')))}) - - u = User(name='u1') - a = Address(email_address='foo@boston.com') - b = Address(email_address='bar@newyork.com') - u.boston_addresses.append(a) - u.newyork_addresses.append(b) - - session = create_session() - session.add(u) - session.flush() - -class SaveTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_basic(self): - m = mapper(User, users) - - # save two users - u = User(name='savetester') - u2 = User(name='savetester2') - - session = create_session() - session.add_all((u, u2)) - session.flush() - - # assert the first one retreives the same from the identity map - nu = session.query(m).get(u.id) - assert u is nu - - # clear out the identity map, so next get forces a SELECT - session.expunge_all() - - # check it again, identity should be different but ids the same - nu = session.query(m).get(u.id) - assert u is not nu and u.id == nu.id and nu.name == 'savetester' - - # change first users name and save - session = create_session() - session.add(u) - u.name = 'modifiedname' - assert u in session.dirty - session.flush() - - # select 
both - userlist = session.query(User).filter( - users.c.id.in_([u.id, u2.id])).order_by([users.c.name]).all() - - eq_(u.id, userlist[0].id) - eq_(userlist[0].name, 'modifiedname') - eq_(u2.id, userlist[1].id) - eq_(userlist[1].name, 'savetester2') - - @testing.resolve_artifact_names - def test_synonym(self): - class SUser(_base.BasicEntity): - def _get_name(self): - return "User:" + self.name - def _set_name(self, name): - self.name = name + ":User" - syn_name = property(_get_name, _set_name) - - mapper(SUser, users, properties={ - 'syn_name': sa.orm.synonym('name') - }) - - u = SUser(syn_name="some name") - eq_(u.syn_name, 'User:some name:User') - - session = create_session() - session.add(u) - session.flush() - session.expunge_all() - - u = session.query(SUser).first() - eq_(u.syn_name, 'User:some name:User') - - @testing.resolve_artifact_names - def test_lazyattr_commit(self): - """Lazily loaded relations. - - When a lazy-loaded list is unloaded, and a commit occurs, that the - 'passive' call on that list does not blow away its value - - """ - mapper(User, users, properties = { - 'addresses': relation(mapper(Address, addresses))}) - - u = User(name='u1') - u.addresses.append(Address(email_address='u1@e1')) - u.addresses.append(Address(email_address='u1@e2')) - u.addresses.append(Address(email_address='u1@e3')) - u.addresses.append(Address(email_address='u1@e4')) - - session = create_session() - session.add(u) - session.flush() - session.expunge_all() - - u = session.query(User).one() - u.name = 'newname' - session.flush() - eq_(len(u.addresses), 4) - - @testing.resolve_artifact_names - def test_inherits(self): - m1 = mapper(User, users) - - class AddressUser(User): - """a user object that also has the users mailing address.""" - pass - - # define a mapper for AddressUser that inherits the User.mapper, and - # joins on the id column - mapper(AddressUser, addresses, inherits=m1) - - au = AddressUser(name='u', email_address='u@e') - - session = create_session() - 
session.add(au) - session.flush() - session.expunge_all() - - rt = session.query(AddressUser).one() - eq_(au.user_id, rt.user_id) - eq_(rt.id, rt.id) - - @testing.resolve_artifact_names - def test_deferred(self): - """Deferred column operations""" - - mapper(Order, orders, properties={ - 'description': sa.orm.deferred(orders.c.description)}) - - # dont set deferred attribute, commit session - o = Order(id=42) - session = create_session(autocommit=False) - session.add(o) - session.commit() - - # assert that changes get picked up - o.description = 'foo' - session.commit() - - eq_(list(session.execute(orders.select(), mapper=Order)), - [(42, None, None, 'foo', None)]) - session.expunge_all() - - # assert that a set operation doesn't trigger a load operation - o = session.query(Order).filter(Order.description == 'foo').one() - def go(): - o.description = 'hoho' - self.sql_count_(0, go) - session.flush() - - eq_(list(session.execute(orders.select(), mapper=Order)), - [(42, None, None, 'hoho', None)]) - - session.expunge_all() - - # test assigning None to an unloaded deferred also works - o = session.query(Order).filter(Order.description == 'hoho').one() - o.description = None - session.flush() - eq_(list(session.execute(orders.select(), mapper=Order)), - [(42, None, None, None, None)]) - session.close() - - # why no support on oracle ? because oracle doesn't save - # "blank" strings; it saves a single space character. - @testing.fails_on('oracle', 'FIXME: unknown') - @testing.resolve_artifact_names - def test_dont_update_blanks(self): - mapper(User, users) - - u = User(name='') - session = create_session() - session.add(u) - session.flush() - session.expunge_all() - - u = session.query(User).get(u.id) - u.name = '' - self.sql_count_(0, session.flush) - - @testing.resolve_artifact_names - def test_multi_table_selectable(self): - """Mapped selectables that span tables. - - Also tests redefinition of the keynames for the column properties. 
- - """ - usersaddresses = sa.join(users, addresses, - users.c.id == addresses.c.user_id) - - m = mapper(User, usersaddresses, - properties=dict( - email = addresses.c.email_address, - foo_id = [users.c.id, addresses.c.user_id])) - - u = User(name='multitester', email='multi@test.org') - session = create_session() - session.add(u) - session.flush() - session.expunge_all() - - id = m.primary_key_from_instance(u) - - u = session.query(User).get(id) - assert u.name == 'multitester' - - user_rows = users.select(users.c.id.in_([u.foo_id])).execute().fetchall() - eq_(user_rows[0].values(), [u.foo_id, 'multitester']) - address_rows = addresses.select(addresses.c.id.in_([u.id])).execute().fetchall() - eq_(address_rows[0].values(), [u.id, u.foo_id, 'multi@test.org']) - - u.email = 'lala@hey.com' - u.name = 'imnew' - session.flush() - - user_rows = users.select(users.c.id.in_([u.foo_id])).execute().fetchall() - eq_(user_rows[0].values(), [u.foo_id, 'imnew']) - address_rows = addresses.select(addresses.c.id.in_([u.id])).execute().fetchall() - eq_(address_rows[0].values(), [u.id, u.foo_id, 'lala@hey.com']) - - session.expunge_all() - u = session.query(User).get(id) - assert u.name == 'imnew' - - @testing.resolve_artifact_names - def test_history_get(self): - """The history lazy-fetches data when it wasn't otherwise loaded.""" - mapper(User, users, properties={ - 'addresses':relation(Address, cascade="all, delete-orphan")}) - mapper(Address, addresses) - - u = User(name='u1') - u.addresses.append(Address(email_address='u1@e1')) - u.addresses.append(Address(email_address='u1@e2')) - session = create_session() - session.add(u) - session.flush() - session.expunge_all() - - u = session.query(User).get(u.id) - session.delete(u) - session.flush() - assert users.count().scalar() == 0 - assert addresses.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_batch_mode(self): - """The 'batch=False' flag on mapper()""" - - names = [] - class 
TestExtension(sa.orm.MapperExtension): - def before_insert(self, mapper, connection, instance): - self.current_instance = instance - names.append(instance.name) - def after_insert(self, mapper, connection, instance): - assert instance is self.current_instance - - mapper(User, users, extension=TestExtension(), batch=False) - u1 = User(name='user1') - u2 = User(name='user2') - - session = create_session() - session.add_all((u1, u2)) - session.flush() - - u3 = User(name='user3') - u4 = User(name='user4') - u5 = User(name='user5') - - session.add_all([u4, u5, u3]) - session.flush() - - # test insert ordering is maintained - assert names == ['user1', 'user2', 'user4', 'user5', 'user3'] - session.expunge_all() - - sa.orm.clear_mappers() - - m = mapper(User, users, extension=TestExtension()) - u1 = User(name='user1') - u2 = User(name='user2') - session.add_all((u1, u2)) - self.assertRaises(AssertionError, session.flush) - - -class ManyToOneTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_m2o_one_to_one(self): - # TODO: put assertion in here !!! 
- m = mapper(Address, addresses, properties=dict( - user = relation(mapper(User, users), lazy=True, uselist=False))) - - session = create_session() - - data = [ - {'name': 'thesub' , 'email_address': 'bar@foo.com'}, - {'name': 'assdkfj' , 'email_address': 'thesdf@asdf.com'}, - {'name': 'n4knd' , 'email_address': 'asf3@bar.org'}, - {'name': 'v88f4' , 'email_address': 'adsd5@llala.net'}, - {'name': 'asdf8d' , 'email_address': 'theater@foo.com'} - ] - objects = [] - for elem in data: - a = Address() - a.email_address = elem['email_address'] - a.user = User() - a.user.name = elem['name'] - objects.append(a) - session.add(a) - - session.flush() - - objects[2].email_address = 'imnew@foo.bar' - objects[3].user = User() - objects[3].user.name = 'imnewlyadded' - self.assert_sql_execution(testing.db, - session.flush, - CompiledSQL("INSERT INTO users (name) VALUES (:name)", - {'name': 'imnewlyadded'} ), - - AllOf( - CompiledSQL("UPDATE addresses SET email_address=:email_address " - "WHERE addresses.id = :addresses_id", - lambda ctx: {'email_address': 'imnew@foo.bar', - 'addresses_id': objects[2].id}), - CompiledSQL("UPDATE addresses SET user_id=:user_id " - "WHERE addresses.id = :addresses_id", - lambda ctx: {'user_id': objects[3].user.id, - 'addresses_id': objects[3].id}) - ) - ) - - l = sa.select([users, addresses], - sa.and_(users.c.id==addresses.c.user_id, - addresses.c.id==a.id)).execute() - eq_(l.fetchone().values(), - [a.user.id, 'asdf8d', a.id, a.user_id, 'theater@foo.com']) - - @testing.resolve_artifact_names - def test_many_to_one_1(self): - m = mapper(Address, addresses, properties=dict( - user = relation(mapper(User, users), lazy=True))) - - a1 = Address(email_address='emailaddress1') - u1 = User(name='user1') - a1.user = u1 - - session = create_session() - session.add(a1) - session.flush() - session.expunge_all() - - a1 = session.query(Address).get(a1.id) - u1 = session.query(User).get(u1.id) - assert a1.user is u1 - - a1.user = None - session.flush() - 
session.expunge_all() - a1 = session.query(Address).get(a1.id) - u1 = session.query(User).get(u1.id) - assert a1.user is None - - @testing.resolve_artifact_names - def test_many_to_one_2(self): - m = mapper(Address, addresses, properties=dict( - user = relation(mapper(User, users), lazy=True))) - - a1 = Address(email_address='emailaddress1') - a2 = Address(email_address='emailaddress2') - u1 = User(name='user1') - a1.user = u1 - - session = create_session() - session.add_all((a1, a2)) - session.flush() - session.expunge_all() - - a1 = session.query(Address).get(a1.id) - a2 = session.query(Address).get(a2.id) - u1 = session.query(User).get(u1.id) - assert a1.user is u1 - - a1.user = None - a2.user = u1 - session.flush() - session.expunge_all() - - a1 = session.query(Address).get(a1.id) - a2 = session.query(Address).get(a2.id) - u1 = session.query(User).get(u1.id) - assert a1.user is None - assert a2.user is u1 - - @testing.resolve_artifact_names - def test_many_to_one_3(self): - m = mapper(Address, addresses, properties=dict( - user = relation(mapper(User, users), lazy=True))) - - a1 = Address(email_address='emailaddress1') - u1 = User(name='user1') - u2 = User(name='user2') - a1.user = u1 - - session = create_session() - session.add_all((a1, u1, u2)) - session.flush() - session.expunge_all() - - a1 = session.query(Address).get(a1.id) - u1 = session.query(User).get(u1.id) - u2 = session.query(User).get(u2.id) - assert a1.user is u1 - - a1.user = u2 - session.flush() - session.expunge_all() - a1 = session.query(Address).get(a1.id) - u1 = session.query(User).get(u1.id) - u2 = session.query(User).get(u2.id) - assert a1.user is u2 - - @testing.resolve_artifact_names - def test_bidirectional_no_load(self): - mapper(User, users, properties={ - 'addresses':relation(Address, backref='user', lazy=None)}) - mapper(Address, addresses) - - # try it on unsaved objects - u1 = User(name='u1') - a1 = Address(email_address='e1') - a1.user = u1 - - session = create_session() - 
session.add(u1) - session.flush() - session.expunge_all() - - a1 = session.query(Address).get(a1.id) - - a1.user = None - session.flush() - session.expunge_all() - assert session.query(Address).get(a1.id).user is None - assert session.query(User).get(u1.id).addresses == [] - - -class ManyToManyTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_many_to_many(self): - mapper(Keyword, keywords) - - m = mapper(Item, items, properties=dict( - keywords=relation(Keyword, - item_keywords, - lazy=False, - order_by=keywords.c.name))) - - data = [Item, - {'description': 'mm_item1', - 'keywords' : (Keyword, [{'name': 'big'}, - {'name': 'green'}, - {'name': 'purple'}, - {'name': 'round'}])}, - {'description': 'mm_item2', - 'keywords' : (Keyword, [{'name':'blue'}, - {'name':'imnew'}, - {'name':'round'}, - {'name':'small'}])}, - {'description': 'mm_item3', - 'keywords' : (Keyword, [])}, - {'description': 'mm_item4', - 'keywords' : (Keyword, [{'name':'big'}, - {'name':'blue'},])}, - {'description': 'mm_item5', - 'keywords' : (Keyword, [{'name':'big'}, - {'name':'exacting'}, - {'name':'green'}])}, - {'description': 'mm_item6', - 'keywords' : (Keyword, [{'name':'red'}, - {'name':'round'}, - {'name':'small'}])}] - - session = create_session() - - objects = [] - _keywords = dict([(k.name, k) for k in session.query(Keyword)]) - - for elem in data[1:]: - item = Item(description=elem['description']) - objects.append(item) - - for spec in elem['keywords'][1]: - keyword_name = spec['name'] - try: - kw = _keywords[keyword_name] - except KeyError: - _keywords[keyword_name] = kw = Keyword(name=keyword_name) - item.keywords.append(kw) - - session.add_all(objects) - session.flush() - - l = (session.query(Item). - filter(Item.description.in_([e['description'] - for e in data[1:]])). 
- order_by(Item.description).all()) - self.assert_result(l, *data) - - objects[4].description = 'item4updated' - k = Keyword() - k.name = 'yellow' - objects[5].keywords.append(k) - self.assert_sql_execution( - testing.db, - session.flush, - AllOf( - CompiledSQL("UPDATE items SET description=:description " - "WHERE items.id = :items_id", - {'description': 'item4updated', - 'items_id': objects[4].id}, - ), - CompiledSQL("INSERT INTO keywords (name) " - "VALUES (:name)", - {'name': 'yellow'}, - ) - ), - CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) " - "VALUES (:item_id, :keyword_id)", - lambda ctx: [{'item_id': objects[5].id, - 'keyword_id': k.id}]) - ) - - objects[2].keywords.append(k) - dkid = objects[5].keywords[1].id - del objects[5].keywords[1] - self.assert_sql_execution( - testing.db, - session.flush, - CompiledSQL("DELETE FROM item_keywords " - "WHERE item_keywords.item_id = :item_id AND " - "item_keywords.keyword_id = :keyword_id", - [{'item_id': objects[5].id, 'keyword_id': dkid}]), - CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) " - "VALUES (:item_id, :keyword_id)", - lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}] - )) - - session.delete(objects[3]) - session.flush() - - @testing.resolve_artifact_names - def test_many_to_many_remove(self): - """Setting a collection to empty deletes many-to-many rows. 
- - Tests that setting a list-based attribute to '[]' properly affects the - history and allows the many-to-many rows to be deleted - - """ - mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords = relation(Keyword, item_keywords, lazy=False), - )) - - i = Item(description='i1') - k1 = Keyword(name='k1') - k2 = Keyword(name='k2') - i.keywords.append(k1) - i.keywords.append(k2) - - session = create_session() - session.add(i) - session.flush() - - assert item_keywords.count().scalar() == 2 - i.keywords = [] - session.flush() - assert item_keywords.count().scalar() == 0 - - @testing.resolve_artifact_names - def test_scalar(self): - """sa.dependency won't delete an m2m relation referencing None.""" - - mapper(Keyword, keywords) - - mapper(Item, items, properties=dict( - keyword=relation(Keyword, secondary=item_keywords, uselist=False))) - - i = Item(description='x') - session = create_session() - session.add(i) - session.flush() - session.delete(i) - session.flush() - - @testing.resolve_artifact_names - def test_many_to_many_update(self): - """Assorted history operations on a many to many""" - mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords=relation(Keyword, - secondary=item_keywords, - lazy=False, - order_by=keywords.c.name))) - - k1 = Keyword(name='keyword 1') - k2 = Keyword(name='keyword 2') - k3 = Keyword(name='keyword 3') - - item = Item(description='item 1') - item.keywords.extend([k1, k2, k3]) - - session = create_session() - session.add(item) - session.flush() - - item.keywords = [] - item.keywords.append(k1) - item.keywords.append(k2) - session.flush() - - session.expunge_all() - item = session.query(Item).get(item.id) - assert item.keywords == [k1, k2] - - @testing.resolve_artifact_names - def test_association(self): - """Basic test of an association object""" - - class IKAssociation(_base.ComparableEntity): - pass - - mapper(Keyword, keywords) - - # note that we are breaking a rule here, making a second - # 
mapper(Keyword, keywords) the reorganization of mapper construction - # affected this, but was fixed again - - mapper(IKAssociation, item_keywords, - primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], - properties=dict( - keyword=relation(mapper(Keyword, keywords, non_primary=True), - lazy=False, - uselist=False, - order_by=keywords.c.name # note here is a valid place where order_by can be used - ))) # on a scalar relation(); to determine eager ordering of - # the parent object within its collection. - - mapper(Item, items, properties=dict( - keywords=relation(IKAssociation, lazy=False))) - - session = create_session() - - def fixture(): - _kw = dict([(k.name, k) for k in session.query(Keyword)]) - for n in ('big', 'green', 'purple', 'round', 'huge', - 'violet', 'yellow', 'blue'): - if n not in _kw: - _kw[n] = Keyword(name=n) - - def assocs(*names): - return [IKAssociation(keyword=kw) - for kw in [_kw[n] for n in names]] - - return [ - Item(description='a_item1', - keywords=assocs('big', 'green', 'purple', 'round')), - Item(description='a_item2', - keywords=assocs('huge', 'violet', 'yellow')), - Item(description='a_item3', - keywords=assocs('big', 'blue'))] - - session.add_all(fixture()) - session.flush() - eq_(fixture(), session.query(Item).order_by(Item.description).all()) - - -class SaveTest2(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_m2o_nonmatch(self): - mapper(User, users) - mapper(Address, addresses, properties=dict( - user = relation(User, lazy=True, uselist=False))) - - session = create_session() - - def fixture(): - return [ - Address(email_address='a1', user=User(name='u1')), - Address(email_address='a2', user=User(name='u2'))] - - session.add_all(fixture()) - - self.assert_sql_execution( - testing.db, - session.flush, - CompiledSQL("INSERT INTO users (name) VALUES (:name)", - {'name': 'u1'}), - CompiledSQL("INSERT INTO users (name) VALUES (:name)", - {'name': 'u2'}), - CompiledSQL("INSERT 
INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", - {'user_id': 1, 'email_address': 'a1'}), - CompiledSQL("INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", - {'user_id': 2, 'email_address': 'a2'}), - ) - -class SaveTest3(_base.MappedTest): - def define_tables(self, metadata): - Table('items', metadata, - Column('item_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('item_name', String(50))) - - Table('keywords', metadata, - Column('keyword_id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(50))) - - Table('assoc', metadata, - Column('item_id', Integer, ForeignKey("items")), - Column('keyword_id', Integer, ForeignKey("keywords")), - Column('foo', sa.Boolean, default=True)) - - def setup_classes(self): - class Keyword(_base.BasicEntity): - pass - class Item(_base.BasicEntity): - pass - - @testing.resolve_artifact_names - def test_manytomany_xtracol_delete(self): - """A many-to-many on a table that has an extra column can properly delete rows from the table without referencing the extra column""" - - mapper(Keyword, keywords) - mapper(Item, items, properties=dict( - keywords = relation(Keyword, secondary=assoc, lazy=False),)) - - i = Item() - k1 = Keyword() - k2 = Keyword() - i.keywords.append(k1) - i.keywords.append(k2) - - session = create_session() - session.add(i) - session.flush() - - assert assoc.count().scalar() == 2 - i.keywords = [] - print i.keywords - session.flush() - assert assoc.count().scalar() == 0 - -class BooleanColTest(_base.MappedTest): - def define_tables(self, metadata): - Table('t1_t', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('name', String(30)), - Column('value', sa.Boolean)) - - @testing.resolve_artifact_names - def test_boolean(self): - # use the regular mapper - class T(_base.ComparableEntity): - pass - orm_mapper(T, t1_t, order_by=t1_t.c.id) - - sess = 
create_session() - t1 = T(value=True, name="t1") - t2 = T(value=False, name="t2") - t3 = T(value=True, name="t3") - sess.add_all((t1, t2, t3)) - - sess.flush() - - for clear in (False, True): - if clear: - sess.expunge_all() - eq_(sess.query(T).all(), [T(value=True, name="t1"), T(value=False, name="t2"), T(value=True, name="t3")]) - if clear: - sess.expunge_all() - eq_(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"),T(value=True, name="t3")]) - if clear: - sess.expunge_all() - eq_(sess.query(T).filter(T.value==False).all(), [T(value=False, name="t2")]) - - t2 = sess.query(T).get(t2.id) - t2.value = True - sess.flush() - eq_(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"), T(value=True, name="t2"), T(value=True, name="t3")]) - t2.value = False - sess.flush() - eq_(sess.query(T).filter(T.value==True).all(), [T(value=True, name="t1"),T(value=True, name="t3")]) - - -class RowSwitchTest(_base.MappedTest): - def define_tables(self, metadata): - # parent - Table('t5', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30), nullable=False)) - - # onetomany - Table('t6', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30), nullable=False), - Column('t5id', Integer, ForeignKey('t5.id'),nullable=False)) - - # associated - Table('t7', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(30), nullable=False)) - - #manytomany - Table('t5t7', metadata, - Column('t5id', Integer, ForeignKey('t5.id'),nullable=False), - Column('t7id', Integer, ForeignKey('t7.id'),nullable=False)) - - def setup_classes(self): - class T5(_base.ComparableEntity): - pass - - class T6(_base.ComparableEntity): - pass - - class T7(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_onetomany(self): - mapper(T5, t5, properties={ - 't6s':relation(T6, cascade="all, delete-orphan") - }) - mapper(T6, t6) - - sess = create_session() - - o5 = T5(data='some t5', id=1) - 
o5.t6s.append(T6(data='some t6', id=1)) - o5.t6s.append(T6(data='some other t6', id=2)) - - sess.add(o5) - sess.flush() - - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1), (2, 'some other t6', 1)] - - o6 = T5(data='some other t5', id=o5.id, t6s=[ - T6(data='third t6', id=3), - T6(data='fourth t6', id=4), - ]) - sess.delete(o5) - sess.add(o6) - sess.flush() - - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(3, 'third t6', 1), (4, 'fourth t6', 1)] - - @testing.resolve_artifact_names - def test_manytomany(self): - mapper(T5, t5, properties={ - 't7s':relation(T7, secondary=t5t7, cascade="all") - }) - mapper(T7, t7) - - sess = create_session() - - o5 = T5(data='some t5', id=1) - o5.t7s.append(T7(data='some t7', id=1)) - o5.t7s.append(T7(data='some other t7', id=2)) - - sess.add(o5) - sess.flush() - - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] - assert testing.rowset(sess.execute(t5t7.select(), mapper=T5)) == set([(1,1), (1, 2)]) - assert list(sess.execute(t7.select(), mapper=T5)) == [(1, 'some t7'), (2, 'some other t7')] - - o6 = T5(data='some other t5', id=1, t7s=[ - T7(data='third t7', id=3), - T7(data='fourth t7', id=4), - ]) - sess.delete(o5) - sess.add(o6) - sess.flush() - - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')] - assert list(sess.execute(t7.select(), mapper=T5)) == [(3, 'third t7'), (4, 'fourth t7')] - - @testing.resolve_artifact_names - def test_manytoone(self): - - mapper(T6, t6, properties={ - 't5':relation(T5) - }) - mapper(T5, t5) - - sess = create_session() - - o5 = T6(data='some t6', id=1) - o5.t5 = T5(data='some t5', id=1) - - sess.add(o5) - sess.flush() - - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1)] - - o6 = 
T6(data='some other t6', id=1, t5=T5(data='some other t5', id=2)) - sess.delete(o5) - sess.delete(o5.t5) - sess.add(o6) - sess.flush() - - assert list(sess.execute(t5.select(), mapper=T5)) == [(2, 'some other t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)] - -class InheritingRowSwitchTest(_base.MappedTest): - def define_tables(self, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True), - Column('pdata', String(30)) - ) - Table('child', metadata, - Column('id', Integer, primary_key=True), - Column('pid', Integer, ForeignKey('parent.id')), - Column('cdata', String(30)) - ) - - def setup_classes(self): - class P(_base.ComparableEntity): - pass - - class C(P): - pass - - @testing.resolve_artifact_names - def test_row_switch_no_child_table(self): - mapper(P, parent) - mapper(C, child, inherits=P) - - sess = create_session() - c1 = C(id=1, pdata='c1', cdata='c1') - sess.add(c1) - sess.flush() - - # establish a row switch between c1 and c2. - # c2 has no value for the "child" table - c2 = C(id=1, pdata='c2') - sess.add(c2) - sess.delete(c1) - - self.assert_sql_execution(testing.db, sess.flush, - CompiledSQL("UPDATE parent SET pdata=:pdata WHERE parent.id = :parent_id", - {'pdata':'c2', 'parent_id':1} - ) - ) - - - -class TransactionTest(_base.MappedTest): - __requires__ = ('deferrable_constraints',) - - __whitelist__ = ('sqlite',) - # sqlite doesn't have deferrable constraints, but it allows them to - # be specified. it'll raise immediately post-INSERT, instead of at - # COMMIT. either way, this test should pass. 
- - def define_tables(self, metadata): - t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True)) - - t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True), - Column('t1_id', Integer, - ForeignKey('t1.id', deferrable=True, initially='deferred') - )) - def setup_classes(self): - class T1(_base.ComparableEntity): - pass - - class T2(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def setup_mappers(self): - orm_mapper(T1, t1) - orm_mapper(T2, t2) - - @testing.resolve_artifact_names - def test_close_transaction_on_commit_fail(self): - session = create_session(autocommit=True) - - # with a deferred constraint, this fails at COMMIT time instead - # of at INSERT time. - session.add(T2(t1_id=123)) - - try: - session.flush() - assert False - except: - # Flush needs to rollback also when commit fails - assert session.transaction is None - - # todo: on 8.3 at least, the failed commit seems to close the cursor? - # needs investigation. leaving in the DDL above now to help verify - # that the new deferrable support on FK isn't involved in this issue. 
- if testing.against('postgres'): - t1.bind.engine.dispose() - -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/utils.py b/test/orm/utils.py deleted file mode 100644 index 813121a44..000000000 --- a/test/orm/utils.py +++ /dev/null @@ -1,241 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy.orm import interfaces, util -from sqlalchemy import Column -from sqlalchemy import Integer -from sqlalchemy import MetaData -from sqlalchemy import Table -from sqlalchemy.orm import aliased -from sqlalchemy.orm import mapper, create_session - - -from testlib import TestBase, testing - -from orm import _fixtures -from testlib.testing import eq_ - - -class ExtensionCarrierTest(TestBase): - def test_basic(self): - carrier = util.ExtensionCarrier() - - assert 'translate_row' not in carrier - assert carrier.translate_row() is interfaces.EXT_CONTINUE - assert 'translate_row' not in carrier - - self.assertRaises(AttributeError, lambda: carrier.snickysnack) - - class Partial(object): - def __init__(self, marker): - self.marker = marker - def translate_row(self, row): - return self.marker - - carrier.append(Partial('end')) - assert 'translate_row' in carrier - assert carrier.translate_row(None) == 'end' - - carrier.push(Partial('front')) - assert carrier.translate_row(None) == 'front' - - assert 'populate_instance' not in carrier - carrier.append(interfaces.MapperExtension) - assert 'populate_instance' in carrier - - assert carrier.interface - for m in carrier.interface: - assert getattr(interfaces.MapperExtension, m) - -class AliasedClassTest(TestBase): - def point_map(self, cls): - table = Table('point', MetaData(), - Column('id', Integer(), primary_key=True), - Column('x', Integer), - Column('y', Integer)) - mapper(cls, table) - return table - - def test_simple(self): - class Point(object): - pass - table = self.point_map(Point) - - alias = aliased(Point) - - assert alias.id - assert alias.x - assert alias.y - - assert 
Point.id.__clause_element__().table is table - assert alias.id.__clause_element__().table is not table - - def test_notcallable(self): - class Point(object): - pass - table = self.point_map(Point) - alias = aliased(Point) - - self.assertRaises(TypeError, alias) - - def test_instancemethods(self): - class Point(object): - def zero(self): - self.x, self.y = 0, 0 - - table = self.point_map(Point) - alias = aliased(Point) - - assert Point.zero - assert not getattr(alias, 'zero') - - def test_classmethods(self): - class Point(object): - @classmethod - def max_x(cls): - return 100 - - table = self.point_map(Point) - alias = aliased(Point) - - assert Point.max_x - assert alias.max_x - assert Point.max_x() == alias.max_x() - - def test_simpleproperties(self): - class Point(object): - @property - def max_x(self): - return 100 - - table = self.point_map(Point) - alias = aliased(Point) - - assert Point.max_x - assert Point.max_x != 100 - assert alias.max_x - assert Point.max_x is alias.max_x - - def test_descriptors(self): - class descriptor(object): - """Tortured...""" - def __init__(self, fn): - self.fn = fn - def __get__(self, obj, owner): - if obj is not None: - return self.fn(obj, obj) - else: - return self - def method(self): - return 'method' - - class Point(object): - center = (0, 0) - @descriptor - def thing(self, arg): - return arg.center - - table = self.point_map(Point) - alias = aliased(Point) - - assert Point.thing != (0, 0) - assert Point().thing == (0, 0) - assert Point.thing.method() == 'method' - - assert alias.thing != (0, 0) - assert alias.thing.method() == 'method' - - def test_hybrid_descriptors(self): - from sqlalchemy import Column # override testlib's override - import types - - class MethodDescriptor(object): - def __init__(self, func): - self.func = func - def __get__(self, instance, owner): - if instance is None: - args = (self.func, owner, owner.__class__) - else: - args = (self.func, instance, owner) - return types.MethodType(*args) - - class 
PropertyDescriptor(object): - def __init__(self, fget, fset, fdel): - self.fget = fget - self.fset = fset - self.fdel = fdel - def __get__(self, instance, owner): - if instance is None: - return self.fget(owner) - else: - return self.fget(instance) - def __set__(self, instance, value): - self.fset(instance, value) - def __delete__(self, instance): - self.fdel(instance) - hybrid = MethodDescriptor - def hybrid_property(fget, fset=None, fdel=None): - return PropertyDescriptor(fget, fset, fdel) - - def assert_table(expr, table): - for child in expr.get_children(): - if isinstance(child, Column): - assert child.table is table - - class Point(object): - def __init__(self, x, y): - self.x, self.y = x, y - @hybrid - def left_of(self, other): - return self.x < other.x - - double_x = hybrid_property(lambda self: self.x * 2) - - table = self.point_map(Point) - alias = aliased(Point) - alias_table = alias.x.__clause_element__().table - assert table is not alias_table - - p1 = Point(-10, -10) - p2 = Point(20, 20) - - assert p1.left_of(p2) - assert p1.double_x == -20 - - assert_table(Point.double_x, table) - assert_table(alias.double_x, alias_table) - - assert_table(Point.left_of(p2), table) - assert_table(alias.left_of(p2), alias_table) - -class IdentityKeyTest(_fixtures.FixtureTest): - run_inserts = None - - @testing.resolve_artifact_names - def test_identity_key_1(self): - mapper(User, users) - - key = util.identity_key(User, 1) - eq_(key, (User, (1,))) - key = util.identity_key(User, ident=1) - eq_(key, (User, (1,))) - - @testing.resolve_artifact_names - def test_identity_key_2(self): - mapper(User, users) - s = create_session() - u = User(name='u1') - s.add(u) - s.flush() - key = util.identity_key(instance=u) - eq_(key, (User, (u.id,))) - - @testing.resolve_artifact_names - def test_identity_key_3(self): - mapper(User, users) - - row = {users.c.id: 1, users.c.name: "Frank"} - key = util.identity_key(User, row=row) - eq_(key, (User, (1,))) - -if __name__ == '__main__': - 
testenv.main() - -- cgit v1.2.1