From 45cec095b4904ba71425d2fe18c143982dd08f43 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 10 Jun 2009 21:18:24 +0000 Subject: - unit tests have been migrated from unittest to nose. See README.unittests for information on how to run the tests. [ticket:970] --- test/ext/alltests.py | 36 - test/ext/associationproxy.py | 887 --------------------- test/ext/compiler.py | 126 --- test/ext/declarative.py | 1538 ------------------------------------ test/ext/orderinglist.py | 403 ---------- test/ext/serializer.py | 139 ---- test/ext/test_associationproxy.py | 885 +++++++++++++++++++++ test/ext/test_compiler.py | 123 +++ test/ext/test_declarative.py | 1545 +++++++++++++++++++++++++++++++++++++ test/ext/test_orderinglist.py | 400 ++++++++++ test/ext/test_serializer.py | 144 ++++ 11 files changed, 3097 insertions(+), 3129 deletions(-) delete mode 100644 test/ext/alltests.py delete mode 100644 test/ext/associationproxy.py delete mode 100644 test/ext/compiler.py delete mode 100644 test/ext/declarative.py delete mode 100644 test/ext/orderinglist.py delete mode 100644 test/ext/serializer.py create mode 100644 test/ext/test_associationproxy.py create mode 100644 test/ext/test_compiler.py create mode 100644 test/ext/test_declarative.py create mode 100644 test/ext/test_orderinglist.py create mode 100644 test/ext/test_serializer.py (limited to 'test/ext') diff --git a/test/ext/alltests.py b/test/ext/alltests.py deleted file mode 100644 index 9f5353e04..000000000 --- a/test/ext/alltests.py +++ /dev/null @@ -1,36 +0,0 @@ -import testenv; testenv.configure_for_tests() -import doctest, sys - -from testlib import sa_unittest as unittest - - -def suite(): - unittest_modules = ( - 'ext.declarative', - 'ext.orderinglist', - 'ext.associationproxy', - 'ext.serializer', - 'ext.compiler', - ) - - if sys.version_info < (2, 4): - doctest_modules = () - else: - doctest_modules = ( - ('sqlalchemy.ext.orderinglist', {'optionflags': doctest.ELLIPSIS}), - ('sqlalchemy.ext.sqlsoup', {}) - ) - - alltests = unittest.TestSuite() - for name in unittest_modules: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - for name, opts in doctest_modules: - alltests.addTest(doctest.DocTestSuite(name, **opts)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py deleted file mode 100644 index 821ed9072..000000000 --- a/test/ext/associationproxy.py +++ /dev/null @@ -1,887 +0,0 @@ -import testenv; testenv.configure_for_tests() -import gc -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy.orm.collections import collection -from sqlalchemy.ext.associationproxy import * -from testlib import * - - -class DictCollection(dict): - @collection.appender - def append(self, obj): - self[obj.foo] = obj - @collection.remover - def remove(self, obj): - del self[obj.foo] - -class SetCollection(set): - pass - -class ListCollection(list): - pass - -class ObjectCollection(object): - def __init__(self): - self.values = list() - @collection.appender - def append(self, obj): - self.values.append(obj) - @collection.remover - def remove(self, obj): - self.values.remove(obj) - def __iter__(self): - return iter(self.values) - -class _CollectionOperations(TestBase): - def setUp(self): - collection_class = self.collection_class - - metadata = MetaData(testing.db) - - parents_table = Table('Parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - children_table = Table('Children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, - ForeignKey('Parent.id')), - Column('foo', String(128)), - Column('name', String(128))) - - class Parent(object): - children = association_proxy('_children', 'name') - - def __init__(self, name): - self.name = name - - class Child(object): - if collection_class and issubclass(collection_class, dict): - def __init__(self, foo, name): - self.foo = foo - self.name = name - else: - def __init__(self, name): - self.name = name - - mapper(Parent, parents_table, properties={ - '_children': relation(Child, lazy=False, - collection_class=collection_class)}) - mapper(Child, children_table) - - metadata.create_all() - - self.metadata = metadata - self.session = create_session() - self.Parent, self.Child = Parent, Child - - def tearDown(self): - self.metadata.drop_all() - - def roundtrip(self, obj): - if obj not in self.session: - self.session.add(obj) - self.session.flush() - id, type_ = obj.id, type(obj) - self.session.expunge_all() - return self.session.query(type_).get(id) - - def _test_sequence_ops(self): - Parent, Child = self.Parent, self.Child - - p1 = Parent('P1') - - self.assert_(not p1._children) - self.assert_(not p1.children) - - ch = Child('regular') - p1._children.append(ch) - - self.assert_(ch in p1._children) - self.assert_(len(p1._children) == 1) - - self.assert_(p1.children) - self.assert_(len(p1.children) == 1) - self.assert_(ch not in p1.children) - self.assert_('regular' in p1.children) - - p1.children.append('proxied') - - self.assert_('proxied' in p1.children) - self.assert_('proxied' not in p1._children) - self.assert_(len(p1.children) == 2) - self.assert_(len(p1._children) == 2) - - self.assert_(p1._children[0].name == 'regular') - self.assert_(p1._children[1].name == 'proxied') - - del p1._children[1] - - self.assert_(len(p1._children) == 1) - self.assert_(len(p1.children) == 1) - self.assert_(p1._children[0] == ch) - - del p1.children[0] - - self.assert_(len(p1._children) == 0) - self.assert_(len(p1.children) == 0) - - p1.children = ['a','b','c'] - self.assert_(len(p1._children) == 3) - self.assert_(len(p1.children) == 3) - - del ch - p1 = self.roundtrip(p1) - - self.assert_(len(p1._children) == 3) - self.assert_(len(p1.children) == 3) - - popped = p1.children.pop() - self.assert_(len(p1.children) == 2) - self.assert_(popped not in p1.children) - p1 = self.roundtrip(p1) - self.assert_(len(p1.children) == 2) - self.assert_(popped not in p1.children) - - p1.children[1] = 'changed-in-place' - self.assert_(p1.children[1] == 'changed-in-place') - inplace_id = p1._children[1].id - p1 = self.roundtrip(p1) - self.assert_(p1.children[1] == 'changed-in-place') - assert p1._children[1].id == inplace_id - - p1.children.append('changed-in-place') - self.assert_(p1.children.count('changed-in-place') == 2) - - p1.children.remove('changed-in-place') - self.assert_(p1.children.count('changed-in-place') == 1) - - p1 = self.roundtrip(p1) - self.assert_(p1.children.count('changed-in-place') == 1) - - p1._children = [] - self.assert_(len(p1.children) == 0) - - after = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] - p1.children = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] - self.assert_(len(p1.children) == 10) - self.assert_([c.name for c in p1._children] == after) - - p1.children[2:6] = ['x'] * 4 - after = ['a', 'b', 'x', 'x', 'x', 'x', 'g', 'h', 'i', 'j'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children[2:6] = ['y'] - after = ['a', 'b', 'y', 'g', 'h', 'i', 'j'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children[2:3] = ['z'] * 4 - after = ['a', 'b', 'z', 'z', 'z', 'z', 'g', 'h', 'i', 'j'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children[2::2] = ['O'] * 4 - after = ['a', 'b', 'O', 'z', 'O', 'z', 'O', 'h', 'O', 'j'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - self.assertRaises(TypeError, set, [p1.children]) - - p1.children *= 0 - after = [] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children += ['a', 'b'] - after = ['a', 'b'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children += ['c'] - after = ['a', 'b', 'c'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children *= 1 - after = ['a', 'b', 'c'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children *= 2 - after = ['a', 'b', 'c', 'a', 'b', 'c'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - p1.children = ['a'] - after = ['a'] - self.assert_(p1.children == after) - self.assert_([c.name for c in p1._children] == after) - - self.assert_((p1.children * 2) == ['a', 'a']) - self.assert_((2 * p1.children) == ['a', 'a']) - self.assert_((p1.children * 0) == []) - self.assert_((0 * p1.children) == []) - - self.assert_((p1.children + ['b']) == ['a', 'b']) - self.assert_((['b'] + p1.children) == ['b', 'a']) - - try: - p1.children + 123 - assert False - except TypeError: - assert True - -class DefaultTest(_CollectionOperations): - def __init__(self, *args, **kw): - super(DefaultTest, self).__init__(*args, **kw) - self.collection_class = None - - def test_sequence_ops(self): - self._test_sequence_ops() - - -class ListTest(_CollectionOperations): - def __init__(self, *args, **kw): - super(ListTest, self).__init__(*args, **kw) - self.collection_class = list - - def test_sequence_ops(self): - self._test_sequence_ops() - -class CustomListTest(ListTest): - def __init__(self, *args, **kw): - super(CustomListTest, self).__init__(*args, **kw) - self.collection_class = list - -# No-can-do until ticket #213 -class DictTest(_CollectionOperations): - pass - -class CustomDictTest(DictTest): - def __init__(self, *args, **kw): - super(DictTest, self).__init__(*args, **kw) - self.collection_class = DictCollection - - def test_mapping_ops(self): - Parent, Child = self.Parent, self.Child - - p1 = Parent('P1') - - self.assert_(not p1._children) - self.assert_(not p1.children) - - ch = Child('a', 'regular') - p1._children.append(ch) - - self.assert_(ch in p1._children.values()) - self.assert_(len(p1._children) == 1) - - self.assert_(p1.children) - self.assert_(len(p1.children) == 1) - self.assert_(ch not in p1.children) - self.assert_('a' in p1.children) - self.assert_(p1.children['a'] == 'regular') - self.assert_(p1._children['a'] == ch) - - p1.children['b'] = 'proxied' - - self.assert_('proxied' in p1.children.values()) - self.assert_('b' in p1.children) - self.assert_('proxied' not in p1._children) - self.assert_(len(p1.children) == 2) - self.assert_(len(p1._children) == 2) - - self.assert_(p1._children['a'].name == 'regular') - self.assert_(p1._children['b'].name == 'proxied') - - del p1._children['b'] - - self.assert_(len(p1._children) == 1) - self.assert_(len(p1.children) == 1) - self.assert_(p1._children['a'] == ch) - - del p1.children['a'] - - self.assert_(len(p1._children) == 0) - self.assert_(len(p1.children) == 0) - - p1.children = {'d': 'v d', 'e': 'v e', 'f': 'v f'} - self.assert_(len(p1._children) == 3) - self.assert_(len(p1.children) == 3) - - self.assert_(set(p1.children) == set(['d','e','f'])) - - del ch - p1 = self.roundtrip(p1) - self.assert_(len(p1._children) == 3) - self.assert_(len(p1.children) == 3) - - p1.children['e'] = 'changed-in-place' - self.assert_(p1.children['e'] == 'changed-in-place') - inplace_id = p1._children['e'].id - p1 = self.roundtrip(p1) - self.assert_(p1.children['e'] == 'changed-in-place') - self.assert_(p1._children['e'].id == inplace_id) - - p1._children = {} - self.assert_(len(p1.children) == 0) - - try: - p1._children = [] - self.assert_(False) - except TypeError: - self.assert_(True) - - try: - p1._children = None - self.assert_(False) - except TypeError: - self.assert_(True) - - self.assertRaises(TypeError, set, [p1.children]) - - -class SetTest(_CollectionOperations): - def __init__(self, *args, **kw): - super(SetTest, self).__init__(*args, **kw) - self.collection_class = set - - def test_set_operations(self): - Parent, Child = self.Parent, self.Child - - p1 = Parent('P1') - - self.assert_(not p1._children) - self.assert_(not p1.children) - - ch1 = Child('regular') - p1._children.add(ch1) - - self.assert_(ch1 in p1._children) - self.assert_(len(p1._children) == 1) - - self.assert_(p1.children) - self.assert_(len(p1.children) == 1) - self.assert_(ch1 not in p1.children) - self.assert_('regular' in p1.children) - - p1.children.add('proxied') - - self.assert_('proxied' in p1.children) - self.assert_('proxied' not in p1._children) - self.assert_(len(p1.children) == 2) - self.assert_(len(p1._children) == 2) - - self.assert_(set([o.name for o in p1._children]) == - set(['regular', 'proxied'])) - - ch2 = None - for o in p1._children: - if o.name == 'proxied': - ch2 = o - break - - p1._children.remove(ch2) - - self.assert_(len(p1._children) == 1) - self.assert_(len(p1.children) == 1) - self.assert_(p1._children == set([ch1])) - - p1.children.remove('regular') - - self.assert_(len(p1._children) == 0) - self.assert_(len(p1.children) == 0) - - p1.children = ['a','b','c'] - self.assert_(len(p1._children) == 3) - self.assert_(len(p1.children) == 3) - - del ch1 - p1 = self.roundtrip(p1) - - self.assert_(len(p1._children) == 3) - self.assert_(len(p1.children) == 3) - - self.assert_('a' in p1.children) - self.assert_('b' in p1.children) - self.assert_('d' not in p1.children) - - self.assert_(p1.children == set(['a','b','c'])) - - try: - p1.children.remove('d') - self.fail() - except KeyError: - pass - - self.assert_(len(p1.children) == 3) - p1.children.discard('d') - self.assert_(len(p1.children) == 3) - p1 = self.roundtrip(p1) - self.assert_(len(p1.children) == 3) - - popped = p1.children.pop() - self.assert_(len(p1.children) == 2) - self.assert_(popped not in p1.children) - p1 = self.roundtrip(p1) - self.assert_(len(p1.children) == 2) - self.assert_(popped not in p1.children) - - p1.children = ['a','b','c'] - p1 = self.roundtrip(p1) - self.assert_(p1.children == set(['a','b','c'])) - - p1.children.discard('b') - p1 = self.roundtrip(p1) - self.assert_(p1.children == set(['a', 'c'])) - - p1.children.remove('a') - p1 = self.roundtrip(p1) - self.assert_(p1.children == set(['c'])) - - p1._children = set() - self.assert_(len(p1.children) == 0) - - try: - p1._children = [] - self.assert_(False) - except TypeError: - self.assert_(True) - - try: - p1._children = None - self.assert_(False) - except TypeError: - self.assert_(True) - - self.assertRaises(TypeError, set, [p1.children]) - - - def test_set_comparisons(self): - Parent, Child = self.Parent, self.Child - - p1 = Parent('P1') - p1.children = ['a','b','c'] - control = set(['a','b','c']) - - for other in (set(['a','b','c']), set(['a','b','c','d']), - set(['a']), set(['a','b']), - set(['c','d']), set(['e', 'f', 'g']), - set()): - - self.assertEqual(p1.children.union(other), - control.union(other)) - self.assertEqual(p1.children.difference(other), - control.difference(other)) - self.assertEqual((p1.children - other), - (control - other)) - self.assertEqual(p1.children.intersection(other), - control.intersection(other)) - self.assertEqual(p1.children.symmetric_difference(other), - control.symmetric_difference(other)) - self.assertEqual(p1.children.issubset(other), - control.issubset(other)) - self.assertEqual(p1.children.issuperset(other), - control.issuperset(other)) - - self.assert_((p1.children == other) == (control == other)) - self.assert_((p1.children != other) == (control != other)) - self.assert_((p1.children < other) == (control < other)) - self.assert_((p1.children <= other) == (control <= other)) - self.assert_((p1.children > other) == (control > other)) - self.assert_((p1.children >= other) == (control >= other)) - - def test_set_mutation(self): - Parent, Child = self.Parent, self.Child - - # mutations - for op in ('update', 'intersection_update', - 'difference_update', 'symmetric_difference_update'): - for base in (['a', 'b', 'c'], []): - for other in (set(['a','b','c']), set(['a','b','c','d']), - set(['a']), set(['a','b']), - set(['c','d']), set(['e', 'f', 'g']), - set()): - p = Parent('p') - p.children = base[:] - control = set(base[:]) - - getattr(p.children, op)(other) - getattr(control, op)(other) - try: - self.assert_(p.children == control) - except: - print 'Test %s.%s(%s):' % (set(base), op, other) - print 'want', repr(control) - print 'got', repr(p.children) - raise - - p = self.roundtrip(p) - - try: - self.assert_(p.children == control) - except: - print 'Test %s.%s(%s):' % (base, op, other) - print 'want', repr(control) - print 'got', repr(p.children) - raise - - # in-place mutations - for op in ('|=', '-=', '&=', '^='): - for base in (['a', 'b', 'c'], []): - for other in (set(['a','b','c']), set(['a','b','c','d']), - set(['a']), set(['a','b']), - set(['c','d']), set(['e', 'f', 'g']), - frozenset(['e', 'f', 'g']), - set()): - p = Parent('p') - p.children = base[:] - control = set(base[:]) - - exec "p.children %s other" % op - exec "control %s other" % op - - try: - self.assert_(p.children == control) - except: - print 'Test %s %s %s:' % (set(base), op, other) - print 'want', repr(control) - print 'got', repr(p.children) - raise - - p = self.roundtrip(p) - - try: - self.assert_(p.children == control) - except: - print 'Test %s %s %s:' % (base, op, other) - print 'want', repr(control) - print 'got', repr(p.children) - raise - - -class CustomSetTest(SetTest): - def __init__(self, *args, **kw): - super(CustomSetTest, self).__init__(*args, **kw) - self.collection_class = SetCollection - -class CustomObjectTest(_CollectionOperations): - def __init__(self, *args, **kw): - super(CustomObjectTest, self).__init__(*args, **kw) - self.collection_class = ObjectCollection - - def test_basic(self): - Parent, Child = self.Parent, self.Child - - p = Parent('p1') - self.assert_(len(list(p.children)) == 0) - - p.children.append('child') - self.assert_(len(list(p.children)) == 1) - - p = self.roundtrip(p) - self.assert_(len(list(p.children)) == 1) - - # We didn't provide an alternate _AssociationList implementation for - # our ObjectCollection, so indexing will fail. - try: - v = p.children[1] - self.fail() - except TypeError: - pass - -class ScalarTest(TestBase): - def test_scalar_proxy(self): - metadata = MetaData(testing.db) - - parents_table = Table('Parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - children_table = Table('Children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, - ForeignKey('Parent.id')), - Column('foo', String(128)), - Column('bar', String(128)), - Column('baz', String(128))) - - class Parent(object): - foo = association_proxy('child', 'foo') - bar = association_proxy('child', 'bar', - creator=lambda v: Child(bar=v)) - baz = association_proxy('child', 'baz', - creator=lambda v: Child(baz=v)) - - def __init__(self, name): - self.name = name - - class Child(object): - def __init__(self, **kw): - for attr in kw: - setattr(self, attr, kw[attr]) - - mapper(Parent, parents_table, properties={ - 'child': relation(Child, lazy=False, - backref='parent', uselist=False)}) - mapper(Child, children_table) - - metadata.create_all() - session = create_session() - - def roundtrip(obj): - if obj not in session: - session.add(obj) - session.flush() - id, type_ = obj.id, type(obj) - session.expunge_all() - return session.query(type_).get(id) - - p = Parent('p') - - # No child - try: - v = p.foo - self.fail() - except: - pass - - p.child = Child(foo='a', bar='b', baz='c') - - self.assert_(p.foo == 'a') - self.assert_(p.bar == 'b') - self.assert_(p.baz == 'c') - - p.bar = 'x' - self.assert_(p.foo == 'a') - self.assert_(p.bar == 'x') - self.assert_(p.baz == 'c') - - p = roundtrip(p) - - self.assert_(p.foo == 'a') - self.assert_(p.bar == 'x') - self.assert_(p.baz == 'c') - - p.child = None - - # No child again - try: - v = p.foo - self.fail() - except: - pass - - # Bogus creator for this scalar type - try: - p.foo = 'zzz' - self.fail() - except TypeError: - pass - - p.bar = 'yyy' - - self.assert_(p.foo is None) - self.assert_(p.bar == 'yyy') - self.assert_(p.baz is None) - - del p.child - - p = roundtrip(p) - - self.assert_(p.child is None) - - p.baz = 'xxx' - - self.assert_(p.foo is None) - self.assert_(p.bar is None) - self.assert_(p.baz == 'xxx') - - p = roundtrip(p) - - self.assert_(p.foo is None) - self.assert_(p.bar is None) - self.assert_(p.baz == 'xxx') - - # Ensure an immediate __set__ works. - p2 = Parent('p2') - p2.bar = 'quux' - - -class LazyLoadTest(TestBase): - def setUp(self): - metadata = MetaData(testing.db) - - parents_table = Table('Parent', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - children_table = Table('Children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, - ForeignKey('Parent.id')), - Column('foo', String(128)), - Column('name', String(128))) - - class Parent(object): - children = association_proxy('_children', 'name') - - def __init__(self, name): - self.name = name - - class Child(object): - def __init__(self, name): - self.name = name - - - mapper(Child, children_table) - metadata.create_all() - - self.metadata = metadata - self.session = create_session() - self.Parent, self.Child = Parent, Child - self.table = parents_table - - def tearDown(self): - self.metadata.drop_all() - - def roundtrip(self, obj): - self.session.add(obj) - self.session.flush() - id, type_ = obj.id, type(obj) - self.session.expunge_all() - return self.session.query(type_).get(id) - - def test_lazy_list(self): - Parent, Child = self.Parent, self.Child - - mapper(Parent, self.table, properties={ - '_children': relation(Child, lazy=True, - collection_class=list)}) - - p = Parent('p') - p.children = ['a','b','c'] - - p = self.roundtrip(p) - - # Is there a better way to ensure that the association_proxy - # didn't convert a lazy load to an eager load? This does work though. - self.assert_('_children' not in p.__dict__) - self.assert_(len(p._children) == 3) - self.assert_('_children' in p.__dict__) - - def test_eager_list(self): - Parent, Child = self.Parent, self.Child - - mapper(Parent, self.table, properties={ - '_children': relation(Child, lazy=False, - collection_class=list)}) - - p = Parent('p') - p.children = ['a','b','c'] - - p = self.roundtrip(p) - - self.assert_('_children' in p.__dict__) - self.assert_(len(p._children) == 3) - - def test_lazy_scalar(self): - Parent, Child = self.Parent, self.Child - - mapper(Parent, self.table, properties={ - '_children': relation(Child, lazy=True, uselist=False)}) - - - p = Parent('p') - p.children = 'value' - - p = self.roundtrip(p) - - self.assert_('_children' not in p.__dict__) - self.assert_(p._children is not None) - - def test_eager_scalar(self): - Parent, Child = self.Parent, self.Child - - mapper(Parent, self.table, properties={ - '_children': relation(Child, lazy=False, uselist=False)}) - - - p = Parent('p') - p.children = 'value' - - p = self.roundtrip(p) - - self.assert_('_children' in p.__dict__) - self.assert_(p._children is not None) - - -class ReconstitutionTest(TestBase): - def setUp(self): - metadata = MetaData(testing.db) - parents = Table('parents', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(30))) - children = Table('children', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('parent_id', Integer, ForeignKey('parents.id')), - Column('name', String(30))) - metadata.create_all() - parents.insert().execute(name='p1') - - class Parent(object): - kids = association_proxy('children', 'name') - def __init__(self, name): - self.name = name - - class Child(object): - def __init__(self, name): - self.name = name - - mapper(Parent, parents, properties=dict(children=relation(Child))) - mapper(Child, children) - - self.metadata = metadata - self.Parent = Parent - - def tearDown(self): - self.metadata.drop_all() - - def test_weak_identity_map(self): - session = create_session(weak_identity_map=True) - - def add_child(parent_name, child_name): - parent = (session.query(self.Parent). - filter_by(name=parent_name)).one() - parent.kids.append(child_name) - - - add_child('p1', 'c1') - gc.collect() - add_child('p1', 'c2') - - session.flush() - p = session.query(self.Parent).filter_by(name='p1').one() - assert set(p.kids) == set(['c1', 'c2']), p.kids - - def test_copy(self): - import copy - p = self.Parent('p1') - p.kids.extend(['c1', 'c2']) - p_copy = copy.copy(p) - del p - gc.collect() - - assert set(p_copy.kids) == set(['c1', 'c2']), p.kids - - -if __name__ == "__main__": - testenv.main() diff --git a/test/ext/compiler.py b/test/ext/compiler.py deleted file mode 100644 index 370ea62ab..000000000 --- a/test/ext/compiler.py +++ /dev/null @@ -1,126 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.sql.expression import ClauseElement, ColumnClause -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import table, column -from testlib import * - -class UserDefinedTest(TestBase, AssertsCompiledSQL): - - def test_column(self): - - class MyThingy(ColumnClause): - def __init__(self, arg= None): - super(MyThingy, self).__init__(arg or 'MYTHINGY!') - - @compiles(MyThingy) - def visit_thingy(thingy, compiler, **kw): - return ">>%s<<" % thingy.name - - self.assert_compile( - select([column('foo'), MyThingy()]), - "SELECT foo, >>MYTHINGY!<<" - ) - - self.assert_compile( - select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), - "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" - ) - - def test_stateful(self): - class MyThingy(ColumnClause): - def __init__(self): - super(MyThingy, self).__init__('MYTHINGY!') - - @compiles(MyThingy) - def visit_thingy(thingy, compiler, **kw): - if not hasattr(compiler, 'counter'): - compiler.counter = 0 - compiler.counter += 1 - return str(compiler.counter) - - self.assert_compile( - select([column('foo'), MyThingy()]).order_by(desc(MyThingy())), - "SELECT foo, 1 ORDER BY 2 DESC" - ) - - self.assert_compile( - select([MyThingy(), MyThingy()]).where(MyThingy() == 5), - "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1" - ) - - def test_callout_to_compiler(self): - class InsertFromSelect(ClauseElement): - def __init__(self, table, select): - self.table = table - self.select = select - - @compiles(InsertFromSelect) - def visit_insert_from_select(element, compiler, **kw): - return "INSERT INTO %s (%s)" % ( - compiler.process(element.table, asfrom=True), - compiler.process(element.select) - ) - - t1 = table("mytable", column('x'), column('y'), column('z')) - self.assert_compile( - InsertFromSelect( - t1, - select([t1]).where(t1.c.x>5) - ), - "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" - ) - - def test_dialect_specific(self): - class AddThingy(ClauseElement): - __visit_name__ = 'add_thingy' - - class DropThingy(ClauseElement): - __visit_name__ = 'drop_thingy' - - @compiles(AddThingy, 'sqlite') - def visit_add_thingy(thingy, compiler, **kw): - return "ADD SPECIAL SL THINGY" - - @compiles(AddThingy) - def visit_add_thingy(thingy, compiler, **kw): - return "ADD THINGY" - - @compiles(DropThingy) - def visit_drop_thingy(thingy, compiler, **kw): - return "DROP THINGY" - - self.assert_compile(AddThingy(), - "ADD THINGY" - ) - - self.assert_compile(DropThingy(), - "DROP THINGY" - ) - - from sqlalchemy.databases import sqlite as base - self.assert_compile(AddThingy(), - "ADD SPECIAL SL THINGY", - dialect=base.dialect() - ) - - self.assert_compile(DropThingy(), - "DROP THINGY", - dialect=base.dialect() - ) - - @compiles(DropThingy, 'sqlite') - def visit_drop_thingy(thingy, compiler, **kw): - return "DROP SPECIAL SL THINGY" - - self.assert_compile(DropThingy(), - "DROP SPECIAL SL THINGY", - dialect=base.dialect() - ) - - self.assert_compile(DropThingy(), - "DROP THINGY", - ) - -if __name__ == '__main__': - testenv.main() diff --git a/test/ext/declarative.py b/test/ext/declarative.py deleted file mode 100644 index f5130b215..000000000 --- a/test/ext/declarative.py +++ /dev/null @@ -1,1538 +0,0 @@ -import testenv; testenv.configure_for_tests() - -from sqlalchemy.ext import declarative as decl -from sqlalchemy import exc -from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index -from testlib.sa.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred -from testlib.testing import eq_ - - -from orm._base import ComparableEntity, MappedTest - -class DeclarativeTestBase(testing.TestBase, testing.AssertsExecutionResults): - def setUp(self): - global Base - Base = decl.declarative_base(testing.db) - - def tearDown(self): - clear_mappers() - Base.metadata.drop_all() - -class DeclarativeTest(DeclarativeTestBase): - def test_basic(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - addresses = relation("Address", backref="user") - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - - id = Column(Integer, primary_key=True) - email = Column(String(50), key='_email') - user_id = Column('user_id', Integer, ForeignKey('users.id'), - key='_user_id') - - Base.metadata.create_all() - - eq_(Address.__table__.c['id'].name, 'id') - eq_(Address.__table__.c['_email'].name, 'email') - eq_(Address.__table__.c['_user_id'].name, 'user_id') - - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), [User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) - - a1 = sess.query(Address).filter(Address.email == 'two').one() - eq_(a1, Address(email='two')) - eq_(a1.user, User(name='u1')) - - def test_no_table(self): - def go(): - class User(Base): - id = Column('id', Integer, primary_key=True) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go) - - def test_cant_add_columns(self): - t = Table('t', Base.metadata, Column('id', Integer, primary_key=True), Column('data', String)) - def go(): - class User(Base): - __table__ = t - foo = Column(Integer, primary_key=True) - # can't specify new columns not already in the table - self.assertRaisesMessage(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go) - - # regular re-mapping works tho - class Bar(Base): - __table__ = t - some_data = t.c.data - - assert class_mapper(Bar).get_property('some_data').columns[0] is t.c.data - - def test_undefer_column_name(self): - # TODO: not sure if there was an explicit - # test for this elsewhere - foo = Column(Integer) - eq_(str(foo), '(no name)') - eq_(foo.key, None) - eq_(foo.name, None) - decl._undefer_column_name('foo', foo) - eq_(str(foo), 'foo') - eq_(foo.key, 'foo') - eq_(foo.name, 'foo') - - def test_recompile_on_othermapper(self): - """declarative version of the same test in mappers.py""" - - from sqlalchemy.orm import mapperlib - - class User(Base): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - class Address(Base): - __tablename__ = 'addresses' - - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) - user = relation("User", primaryjoin=user_id == User.id, - backref="addresses") - - assert mapperlib._new_mappers is True - u = User() - assert User.addresses - assert mapperlib._new_mappers is False - - def test_string_dependency_resolution(self): - from sqlalchemy.sql import desc - - class User(Base, ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - addresses = relation("Address", order_by="desc(Address.email)", - primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]", - backref=backref('user', primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]") - ) - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer) # note no foreign key - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='ed', addresses=[Address(email='abc'), Address(email='def'), Address(email='xyz')]) - sess.add(u1) - sess.flush() - sess.expunge_all() - self.assertEquals(sess.query(User).filter(User.name == 'ed').one(), - User(name='ed', addresses=[Address(email='xyz'), Address(email='def'), Address(email='abc')]) - ) - - class Foo(Base, ComparableEntity): - __tablename__ = 'foo' - id = Column(Integer, primary_key=True) - rel = relation("User", primaryjoin="User.addresses==Foo.id") - self.assertRaisesMessage(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers) - - def test_string_dependency_resolution_in_backref(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - addresses = relation("Address", - primaryjoin="User.id==Address.user_id", - backref="user" - ) - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - - compile_mappers() - eq_(str(User.addresses.property.primaryjoin), str(Address.user.property.primaryjoin)) - - - def test_uncompiled_attributes_in_relation(self): - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - addresses = relation("Address", order_by=Address.email, - foreign_keys=Address.user_id, - remote_side=Address.user_id, - ) - - # get the mapper for User. User mapper will compile, - # "addresses" relation will call upon Address.user_id for - # its clause element. Address.user_id is a _CompileOnAttr, - # which then calls class_mapper(Address). But ! We're already - # "in compilation", but class_mapper(Address) needs to initialize - # regardless, or COA's assertion fails - # and things generally go downhill from there. - class_mapper(User) - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='ed', addresses=[Address(email='abc'), Address(email='xyz'), Address(email='def')]) - sess.add(u1) - sess.flush() - sess.expunge_all() - self.assertEquals(sess.query(User).filter(User.name == 'ed').one(), - User(name='ed', addresses=[Address(email='abc'), Address(email='def'), Address(email='xyz')]) - ) - - def test_nice_dependency_error(self): - class User(Base): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - addresses = relation("Address") - - class Address(Base): - __tablename__ = 'addresses' - - id = Column(Integer, primary_key=True) - foo = sa.orm.column_property(User.id == 5) - - # this used to raise an error when accessing User.id but that's no longer the case - # since we got rid of _CompileOnAttr. - self.assertRaises(sa.exc.ArgumentError, compile_mappers) - - def test_nice_dependency_error_works_with_hasattr(self): - class User(Base): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - addresses = relation("Addresss") - - # hasattr() on a compile-loaded attribute - hasattr(User.addresses, 'property') - # the exeption is preserved - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) - - def test_custom_base(self): - class MyBase(object): - def foobar(self): - return "foobar" - Base = decl.declarative_base(cls=MyBase) - assert hasattr(Base, 'metadata') - assert Base().foobar() == "foobar" - - def test_index_doesnt_compile(self): - class User(Base): - __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - error = relation("Address") - - i = Index('my_index', User.name) - - # compile fails due to the nonexistent Addresses relation - self.assertRaises(sa.exc.InvalidRequestError, compile_mappers) - - # index configured - assert i in User.__table__.indexes - assert User.__table__.c.id not in set(i.columns) - assert User.__table__.c.name in set(i.columns) - - # tables create fine - Base.metadata.create_all() - - def test_add_prop(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - User.name = Column('name', String(50)) - User.addresses = relation("Address", backref="user") - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - - id = Column(Integer, primary_key=True) - Address.email = Column(String(50), key='_email') - Address.user_id = Column('user_id', Integer, ForeignKey('users.id'), - key='_user_id') - - Base.metadata.create_all() - - eq_(Address.__table__.c['id'].name, 'id') - eq_(Address.__table__.c['_email'].name, 'email') - eq_(Address.__table__.c['_user_id'].name, 'user_id') - - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), [User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) - - a1 = sess.query(Address).filter(Address.email == 'two').one() - eq_(a1, Address(email='two')) - eq_(a1.user, User(name='u1')) - - def test_eager_order_by(self): - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - addresses = relation("Address", order_by=Address.email) - - Base.metadata.create_all() - u1 = User(name='u1', addresses=[ - Address(email='two'), - Address(email='one'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - eq_(sess.query(User).options(eagerload(User.addresses)).all(), [User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) - - def test_order_by_multi(self): - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - addresses = relation("Address", order_by=(Address.email, Address.id)) - - Base.metadata.create_all() - u1 = User(name='u1', addresses=[ - Address(email='two'), - Address(email='one'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - u = sess.query(User).filter(User.name == 'u1').one() - a = u.addresses - - def test_as_declarative(self): - class User(ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - addresses = relation("Address", backref="user") - - class Address(ComparableEntity): - __tablename__ = 'addresses' - - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) - - reg = {} - decl.instrument_declarative(User, reg, Base.metadata) - decl.instrument_declarative(Address, reg, Base.metadata) - Base.metadata.create_all() - - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), [User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) - - def test_custom_mapper(self): - class MyExt(sa.orm.MapperExtension): - def create_instance(self): - return "CHECK" - - def mymapper(cls, tbl, **kwargs): - kwargs['extension'] = MyExt() - return sa.orm.mapper(cls, tbl, **kwargs) - - from sqlalchemy.orm.mapper import Mapper - class MyMapper(Mapper): - def __init__(self, *args, **kwargs): - kwargs['extension'] = MyExt() - Mapper.__init__(self, *args, **kwargs) - - from sqlalchemy.orm import scoping - ss = scoping.ScopedSession(create_session) - ss.extension = MyExt() - ss_mapper = ss.mapper - - for mapperfunc in (mymapper, MyMapper, ss_mapper): - base = decl.declarative_base() - class Foo(base): - __tablename__ = 'foo' - __mapper_cls__ = mapperfunc - id = Column(Integer, primary_key=True) - eq_(Foo.__mapper__.compile().extension.create_instance(), 'CHECK') - - base = decl.declarative_base(mapper=mapperfunc) - class Foo(base): - __tablename__ = 'foo' - id = Column(Integer, primary_key=True) - eq_(Foo.__mapper__.compile().extension.create_instance(), 'CHECK') - - - @testing.emits_warning('Ignoring declarative-like tuple value of ' - 'attribute id') - def test_oops(self): - def define(): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True), - name = Column('name', String(50)) - assert False - self.assertRaisesMessage( - sa.exc.ArgumentError, - "Mapper Mapper|User|users could not assemble any primary key", - define) - - def test_table_args(self): - class Foo(Base): - __tablename__ = 'foo' - __table_args__ = {'mysql_engine':'InnoDB'} - id = Column('id', Integer, primary_key=True) - - assert Foo.__table__.kwargs['mysql_engine'] == 'InnoDB' - - class Bar(Base): - __tablename__ = 'bar' - __table_args__ = (ForeignKeyConstraint(['id'], ['foo.id']), {'mysql_engine':'InnoDB'}) - id = Column('id', Integer, primary_key=True) - - assert Bar.__table__.c.id.references(Foo.__table__.c.id) - assert Bar.__table__.kwargs['mysql_engine'] == 'InnoDB' - - def test_expression(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - addresses = relation("Address", backref="user") - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) - - User.address_count = sa.orm.column_property( - sa.select([sa.func.count(Address.id)]). - where(Address.user_id == User.id).as_scalar()) - - Base.metadata.create_all() - - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), - [User(name='u1', address_count=2, addresses=[ - Address(email='one'), - Address(email='two')])]) - - def test_column(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - User.a = Column('a', String(10)) - User.b = Column(String(10)) - - Base.metadata.create_all() - - u1 = User(name='u1', a='a', b='b') - eq_(u1.a, 'a') - eq_(User.a.get_history(u1), (['a'], (), ())) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), - [User(name='u1', a='a', b='b')]) - - def test_column_properties(self): - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - adr_count = sa.orm.column_property( - sa.select([sa.func.count(Address.id)], Address.user_id == id). - as_scalar()) - addresses = relation(Address) - - Base.metadata.create_all() - - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), - [User(name='u1', adr_count=2, addresses=[ - Address(email='one'), - Address(email='two')])]) - - def test_column_properties_2(self): - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - # this is not "valid" but we want to test that Address.id doesnt - # get stuck into user's table - adr_count = Address.id - - eq_(set(User.__table__.c.keys()), set(['id', 'name'])) - eq_(set(Address.__table__.c.keys()), set(['id', 'email', 'user_id'])) - - def test_deferred(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column(Integer, primary_key=True) - name = sa.orm.deferred(Column(String(50))) - - Base.metadata.create_all() - sess = create_session() - sess.add(User(name='u1')) - sess.flush() - sess.expunge_all() - - u1 = sess.query(User).filter(User.name == 'u1').one() - assert 'name' not in u1.__dict__ - def go(): - eq_(u1.name, 'u1') - self.assert_sql_count(testing.db, go, 1) - - def test_synonym_inline(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - _name = Column('name', String(50)) - def _set_name(self, name): - self._name = "SOMENAME " + name - def _get_name(self): - return self._name - name = sa.orm.synonym('_name', - descriptor=property(_get_name, _set_name)) - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "SOMENAME someuser") - sess.add(u1) - sess.flush() - eq_(sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1) - - def test_synonym_no_descriptor(self): - from sqlalchemy.orm.properties import ColumnProperty - - class CustomCompare(ColumnProperty.Comparator): - __hash__ = None - def __eq__(self, other): - return self.__clause_element__() == other + ' FOO' - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - _name = Column('name', String(50)) - name = sa.orm.synonym('_name', comparator_factory=CustomCompare) - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser FOO') - sess.add(u1) - sess.flush() - eq_(sess.query(User).filter(User.name == "someuser").one(), u1) - - def test_synonym_added(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - _name = Column('name', String(50)) - def _set_name(self, name): - self._name = "SOMENAME " + name - def _get_name(self): - return self._name - name = property(_get_name, _set_name) - User.name = sa.orm.synonym('_name', descriptor=User.name) - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "SOMENAME someuser") - sess.add(u1) - sess.flush() - eq_(sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1) - - def test_reentrant_compile_via_foreignkey(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - addresses = relation("Address", backref="user") - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey(User.id)) - - # previous versions would force a re-entrant mapper compile - # via the User.id inside the ForeignKey but this is no - # longer the case - sa.orm.compile_mappers() - - eq_(str(Address.user_id.property.columns[0].foreign_keys[0]), "ForeignKey('users.id')") - - Base.metadata.create_all() - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), [User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) - - def test_relation_reference(self): - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - - id = Column('id', Integer, primary_key=True) - email = Column('email', String(50)) - user_id = Column('user_id', Integer, ForeignKey('users.id')) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - addresses = relation("Address", backref="user", - primaryjoin=id == Address.user_id) - - User.address_count = sa.orm.column_property( - sa.select([sa.func.count(Address.id)]). - where(Address.user_id == User.id).as_scalar()) - - Base.metadata.create_all() - - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), - [User(name='u1', address_count=2, addresses=[ - Address(email='one'), - Address(email='two')])]) - - def test_pk_with_fk_init(self): - class Bar(Base): - __tablename__ = 'bar' - - id = sa.Column(sa.Integer, sa.ForeignKey("foo.id"), primary_key=True) - ex = sa.Column(sa.Integer, primary_key=True) - - class Foo(Base): - __tablename__ = 'foo' - - id = sa.Column(sa.Integer, primary_key=True) - bars = sa.orm.relation(Bar) - - assert Bar.__mapper__.primary_key[0] is Bar.__table__.c.id - assert Bar.__mapper__.primary_key[1] is Bar.__table__.c.ex - - - def test_with_explicit_autoloaded(self): - meta = MetaData(testing.db) - t1 = Table('t1', meta, - Column('id', String(50), primary_key=True), - Column('data', String(50))) - meta.create_all() - try: - class MyObj(Base): - __table__ = Table('t1', Base.metadata, autoload=True) - - sess = create_session() - m = MyObj(id="someid", data="somedata") - sess.add(m) - sess.flush() - - eq_(t1.select().execute().fetchall(), [('someid', 'somedata')]) - finally: - meta.drop_all() - -class DeclarativeInheritanceTest(DeclarativeTestBase): - def test_custom_join_condition(self): - class Foo(Base): - __tablename__ = 'foo' - id = Column('id', Integer, primary_key=True) - - class Bar(Foo): - __tablename__ = 'bar' - id = Column('id', Integer, primary_key=True) - foo_id = Column('foo_id', Integer) - __mapper_args__ = {'inherit_condition':foo_id==Foo.id} - - # compile succeeds because inherit_condition is honored - compile_mappers() - - def test_joined(self): - class Company(Base, ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - employees = relation("Person") - - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - company_id = Column('company_id', Integer, - ForeignKey('companies.id')) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity':'engineer'} - id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) - primary_language = Column('primary_language', String(50)) - - class Manager(Person): - __tablename__ = 'managers' - __mapper_args__ = {'polymorphic_identity':'manager'} - id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) - golf_swing = Column('golf_swing', String(50)) - - Base.metadata.create_all() - - sess = create_session() - - c1 = Company(name="MegaCorp, Inc.", employees=[ - Engineer(name="dilbert", primary_language="java"), - Engineer(name="wally", primary_language="c++"), - Manager(name="dogbert", golf_swing="fore!") - ]) - - c2 = Company(name="Elbonia, Inc.", employees=[ - Engineer(name="vlad", primary_language="cobol") - ]) - - sess.add(c1) - sess.add(c2) - sess.flush() - sess.expunge_all() - - eq_((sess.query(Company). - filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language == 'cobol')).first()), - c2) - - # ensure that the Manager mapper was compiled - # with the Person id column as higher priority. - # this ensures that "id" will get loaded from the Person row - # and not the possibly non-present Manager row - assert Manager.id.property.columns == [Person.__table__.c.id, Manager.__table__.c.id] - - # assert that the "id" column is available without a second load. - # this would be the symptom of the previous step not being correct. - sess.expunge_all() - def go(): - assert sess.query(Manager).filter(Manager.name=='dogbert').one().id - self.assert_sql_count(testing.db, go, 1) - sess.expunge_all() - def go(): - assert sess.query(Person).filter(Manager.name=='dogbert').one().id - self.assert_sql_count(testing.db, go, 1) - - def test_subclass_mixin(self): - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - class MyMixin(object): - pass - - class Engineer(MyMixin, Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity':'engineer'} - id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) - primary_language = Column('primary_language', String(50)) - - assert class_mapper(Engineer).inherits is class_mapper(Person) - - def test_with_undefined_foreignkey(self): - class Parent(Base): - __tablename__ = 'parent' - id = Column('id', Integer, primary_key=True) - tp = Column('type', String(50)) - __mapper_args__ = dict(polymorphic_on = tp) - - class Child1(Parent): - __tablename__ = 'child1' - id = Column('id', Integer, ForeignKey('parent.id'), primary_key=True) - related_child2 = Column('c2', Integer, ForeignKey('child2.id')) - __mapper_args__ = dict(polymorphic_identity = 'child1') - - # no exception is raised by the ForeignKey to "child2" even though - # child2 doesn't exist yet - - class Child2(Parent): - __tablename__ = 'child2' - id = Column('id', Integer, ForeignKey('parent.id'), primary_key=True) - related_child1 = Column('c1', Integer) - __mapper_args__ = dict(polymorphic_identity = 'child2') - - sa.orm.compile_mappers() # no exceptions here - - def test_single_colsonbase(self): - """test single inheritance where all the columns are on the base class.""" - - class Company(Base, ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - employees = relation("Person") - - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - company_id = Column('company_id', Integer, - ForeignKey('companies.id')) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - primary_language = Column('primary_language', String(50)) - golf_swing = Column('golf_swing', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity':'engineer'} - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity':'manager'} - - Base.metadata.create_all() - - sess = create_session() - c1 = Company(name="MegaCorp, Inc.", employees=[ - Engineer(name="dilbert", primary_language="java"), - Engineer(name="wally", primary_language="c++"), - Manager(name="dogbert", golf_swing="fore!") - ]) - - c2 = Company(name="Elbonia, Inc.", employees=[ - Engineer(name="vlad", primary_language="cobol") - ]) - - sess.add(c1) - sess.add(c2) - sess.flush() - sess.expunge_all() - - eq_((sess.query(Person). - filter(Engineer.primary_language == 'cobol').first()), - Engineer(name='vlad')) - eq_((sess.query(Company). - filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language == 'cobol')).first()), - c2) - - def test_single_colsonsub(self): - """test single inheritance where the columns are local to their class. - - this is a newer usage. - - """ - - class Company(Base, ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - employees = relation("Person") - - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - company_id = Column(Integer, - ForeignKey('companies.id')) - name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity':'engineer'} - primary_language = Column(String(50)) - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity':'manager'} - golf_swing = Column(String(50)) - - # we have here a situation that is somewhat unique. - # the Person class is mapped to the "people" table, but it - # was mapped when the table did not include the "primary_language" - # or "golf_swing" columns. declarative will also manipulate - # the exclude_properties collection so that sibling classes - # don't cross-pollinate. - - assert Person.__table__.c.company_id - assert Person.__table__.c.golf_swing - assert Person.__table__.c.primary_language - assert Engineer.primary_language - assert Manager.golf_swing - assert not hasattr(Person, 'primary_language') - assert not hasattr(Person, 'golf_swing') - assert not hasattr(Engineer, 'golf_swing') - assert not hasattr(Manager, 'primary_language') - - Base.metadata.create_all() - - sess = create_session() - - e1 = Engineer(name="dilbert", primary_language="java") - e2 = Engineer(name="wally", primary_language="c++") - m1 = Manager(name="dogbert", golf_swing="fore!") - c1 = Company(name="MegaCorp, Inc.", employees=[e1, e2, m1]) - - e3 =Engineer(name="vlad", primary_language="cobol") - c2 = Company(name="Elbonia, Inc.", employees=[e3]) - sess.add(c1) - sess.add(c2) - sess.flush() - sess.expunge_all() - - eq_((sess.query(Person). - filter(Engineer.primary_language == 'cobol').first()), - Engineer(name='vlad')) - eq_((sess.query(Company). - filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language == 'cobol')).first()), - c2) - - eq_( - sess.query(Engineer).filter_by(primary_language='cobol').one(), - Engineer(name="vlad", primary_language="cobol") - ) - - def test_joined_from_single(self): - class Company(Base, ComparableEntity): - __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - employees = relation("Person") - - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - company_id = Column(Integer, ForeignKey('companies.id')) - name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity':'manager'} - golf_swing = Column(String(50)) - - class Engineer(Person): - __tablename__ = 'engineers' - __mapper_args__ = {'polymorphic_identity':'engineer'} - id = Column(Integer, ForeignKey('people.id'), primary_key=True) - primary_language = Column(String(50)) - - assert Person.__table__.c.golf_swing - assert not Person.__table__.c.has_key('primary_language') - assert Engineer.__table__.c.primary_language - assert Engineer.primary_language - assert Manager.golf_swing - assert not hasattr(Person, 'primary_language') - assert not hasattr(Person, 'golf_swing') - assert not hasattr(Engineer, 'golf_swing') - assert not hasattr(Manager, 'primary_language') - - Base.metadata.create_all() - - sess = create_session() - - e1 = Engineer(name="dilbert", primary_language="java") - e2 = Engineer(name="wally", primary_language="c++") - m1 = Manager(name="dogbert", golf_swing="fore!") - c1 = Company(name="MegaCorp, Inc.", employees=[e1, e2, m1]) - e3 =Engineer(name="vlad", primary_language="cobol") - c2 = Company(name="Elbonia, Inc.", employees=[e3]) - sess.add(c1) - sess.add(c2) - sess.flush() - sess.expunge_all() - - eq_((sess.query(Person).with_polymorphic(Engineer). - filter(Engineer.primary_language == 'cobol').first()), - Engineer(name='vlad')) - eq_((sess.query(Company). - filter(Company.employees.of_type(Engineer). - any(Engineer.primary_language == 'cobol')).first()), - c2) - - eq_( - sess.query(Engineer).filter_by(primary_language='cobol').one(), - Engineer(name="vlad", primary_language="cobol") - ) - - def test_add_deferred(self): - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - - Person.name = deferred(Column(String(10))) - - Base.metadata.create_all() - sess = create_session() - p = Person(name='ratbert') - - sess.add(p) - sess.flush() - sess.expunge_all() - eq_( - sess.query(Person).all(), - [ - Person(name='ratbert') - ] - ) - person = sess.query(Person).filter(Person.name == 'ratbert').one() - assert 'name' not in person.__dict__ - - def test_single_fksonsub(self): - """test single inheritance with a foreign key-holding column on a subclass. - - """ - - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity':'engineer'} - primary_language_id = Column(Integer, ForeignKey('languages.id')) - primary_language = relation("Language") - - class Language(Base, ComparableEntity): - __tablename__ = 'languages' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - - assert not hasattr(Person, 'primary_language_id') - - Base.metadata.create_all() - - sess = create_session() - - java, cpp, cobol = Language(name='java'),Language(name='cpp'), Language(name='cobol') - e1 = Engineer(name="dilbert", primary_language=java) - e2 = Engineer(name="wally", primary_language=cpp) - e3 =Engineer(name="vlad", primary_language=cobol) - sess.add_all([e1, e2, e3]) - sess.flush() - sess.expunge_all() - - eq_((sess.query(Person). - filter(Engineer.primary_language.has(Language.name=='cobol')).first()), - Engineer(name='vlad', primary_language=Language(name='cobol'))) - - eq_( - sess.query(Engineer).filter(Engineer.primary_language.has(Language.name=='cobol')).one(), - Engineer(name="vlad", primary_language=Language(name='cobol')) - ) - - eq_( - sess.query(Person).join(Engineer.primary_language).order_by(Language.name).all(), - [ - Engineer(name='vlad', primary_language=Language(name='cobol')), - Engineer(name='wally', primary_language=Language(name='cpp')), - Engineer(name='dilbert', primary_language=Language(name='java')), - ] - ) - - def test_single_three_levels(self): - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity':'engineer'} - primary_language = Column(String(50)) - - class JuniorEngineer(Engineer): - __mapper_args__ = {'polymorphic_identity':'junior_engineer'} - nerf_gun = Column(String(50)) - - class Manager(Person): - __mapper_args__ = {'polymorphic_identity':'manager'} - golf_swing = Column(String(50)) - - assert JuniorEngineer.nerf_gun - assert JuniorEngineer.primary_language - assert JuniorEngineer.name - assert Manager.golf_swing - assert Engineer.primary_language - assert not hasattr(Engineer, 'golf_swing') - assert not hasattr(Engineer, 'nerf_gun') - assert not hasattr(Manager, 'nerf_gun') - assert not hasattr(Manager, 'primary_language') - - def test_single_no_special_cols(self): - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - def go(): - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity':'engineer'} - primary_language = Column('primary_language', String(50)) - foo_bar = Column(Integer, primary_key=True) - self.assertRaisesMessage(sa.exc.ArgumentError, "place primary key", go) - - def test_single_no_table_args(self): - class Person(Base, ComparableEntity): - __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - discriminator = Column('type', String(50)) - __mapper_args__ = {'polymorphic_on':discriminator} - - def go(): - class Engineer(Person): - __mapper_args__ = {'polymorphic_identity':'engineer'} - primary_language = Column('primary_language', String(50)) - __table_args__ = () - self.assertRaisesMessage(sa.exc.ArgumentError, "place __table_args__", go) - - def test_concrete(self): - engineers = Table('engineers', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('primary_language', String(50)) - ) - managers = Table('managers', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('golf_swing', String(50)) - ) - - punion = polymorphic_union({ - 'engineer':engineers, - 'manager':managers - }, 'type', 'punion') - - class Person(Base, ComparableEntity): - __table__ = punion - __mapper_args__ = {'polymorphic_on':punion.c.type} - - class Engineer(Person): - __table__ = engineers - __mapper_args__ = {'polymorphic_identity':'engineer', 'concrete':True} - - class Manager(Person): - __table__ = managers - __mapper_args__ = {'polymorphic_identity':'manager', 'concrete':True} - - Base.metadata.create_all() - sess = create_session() - - e1 = Engineer(name="dilbert", primary_language="java") - e2 = Engineer(name="wally", primary_language="c++") - m1 = Manager(name="dogbert", golf_swing="fore!") - e3 = Engineer(name="vlad", primary_language="cobol") - - sess.add_all([e1, e2, m1, e3]) - sess.flush() - sess.expunge_all() - eq_( - sess.query(Person).order_by(Person.name).all(), - [ - Engineer(name='dilbert'), Manager(name='dogbert'), - Engineer(name='vlad'), Engineer(name='wally') - ] - ) - - -def produce_test(inline, stringbased): - class ExplicitJoinTest(MappedTest): - - def define_tables(self, metadata): - global User, Address - Base = decl.declarative_base(metadata=metadata) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - id = Column(Integer, primary_key=True) - name = Column(String(50)) - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) - email = Column(String(50)) - user_id = Column(Integer, ForeignKey('users.id')) - if inline: - if stringbased: - user = relation("User", primaryjoin="User.id==Address.user_id", backref="addresses") - else: - user = relation(User, primaryjoin=User.id==user_id, backref="addresses") - - if not inline: - compile_mappers() - if stringbased: - Address.user = relation("User", primaryjoin="User.id==Address.user_id", backref="addresses") - else: - Address.user = relation(User, primaryjoin=User.id==Address.user_id, backref="addresses") - - def insert_data(self): - params = [dict(zip(('id', 'name'), column_values)) for column_values in - [(7, 'jack'), - (8, 'ed'), - (9, 'fred'), - (10, 'chuck')] - ] - User.__table__.insert().execute(params) - - Address.__table__.insert().execute( - [dict(zip(('id', 'user_id', 'email'), column_values)) for column_values in - [(1, 7, "jack@bean.com"), - (2, 8, "ed@wood.com"), - (3, 8, "ed@bettyboop.com"), - (4, 8, "ed@lala.com"), - (5, 9, "fred@fred.com")] - ] - ) - - def test_aliased_join(self): - # this query will screw up if the aliasing - # enabled in query.join() gets applied to the right half of the join condition inside the any(). - # the join condition inside of any() comes from the "primaryjoin" of the relation, - # and should not be annotated with _orm_adapt. PropertyLoader.Comparator will annotate - # the left side with _orm_adapt, though. - sess = create_session() - eq_( - sess.query(User).join(User.addresses, aliased=True). - filter(Address.email=='ed@wood.com').filter(User.addresses.any(Address.email=='jack@bean.com')).all(), - [] - ) - - ExplicitJoinTest.__name__ = "ExplicitJoinTest%s%s" % (inline and 'Inline' or 'Separate', stringbased and 'String' or 'Literal') - return ExplicitJoinTest - -for inline in (True, False): - for stringbased in (True, False): - testclass = produce_test(inline, stringbased) - exec("%s = testclass" % testclass.__name__) - del testclass - -class DeclarativeReflectionTest(testing.TestBase): - def setUpAll(self): - global reflection_metadata - reflection_metadata = MetaData(testing.db) - - Table('users', reflection_metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - test_needs_fk=True) - Table('addresses', reflection_metadata, - Column('id', Integer, primary_key=True), - Column('email', String(50)), - Column('user_id', Integer, ForeignKey('users.id')), - test_needs_fk=True) - Table('imhandles', reflection_metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer), - Column('network', String(50)), - Column('handle', String(50)), - test_needs_fk=True) - - reflection_metadata.create_all() - - def setUp(self): - global Base - Base = decl.declarative_base(testing.db) - - def tearDown(self): - for t in reversed(reflection_metadata.sorted_tables): - t.delete().execute() - - def tearDownAll(self): - reflection_metadata.drop_all() - - def test_basic(self): - meta = MetaData(testing.db) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - __autoload__ = True - addresses = relation("Address", backref="user") - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - __autoload__ = True - - u1 = User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), [User(name='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) - - a1 = sess.query(Address).filter(Address.email == 'two').one() - eq_(a1, Address(email='two')) - eq_(a1.user, User(name='u1')) - - def test_rekey(self): - meta = MetaData(testing.db) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - __autoload__ = True - nom = Column('name', String(50), key='nom') - addresses = relation("Address", backref="user") - - class Address(Base, ComparableEntity): - __tablename__ = 'addresses' - __autoload__ = True - - u1 = User(nom='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), [User(nom='u1', addresses=[ - Address(email='one'), - Address(email='two'), - ])]) - - a1 = sess.query(Address).filter(Address.email == 'two').one() - eq_(a1, Address(email='two')) - eq_(a1.user, User(nom='u1')) - - self.assertRaises(TypeError, User, name='u3') - - def test_supplied_fk(self): - meta = MetaData(testing.db) - - class IMHandle(Base, ComparableEntity): - __tablename__ = 'imhandles' - __autoload__ = True - - user_id = Column('user_id', Integer, - ForeignKey('users.id')) - class User(Base, ComparableEntity): - __tablename__ = 'users' - __autoload__ = True - handles = relation("IMHandle", backref="user") - - u1 = User(name='u1', handles=[ - IMHandle(network='blabber', handle='foo'), - IMHandle(network='lol', handle='zomg') - ]) - sess = create_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_(sess.query(User).all(), [User(name='u1', handles=[ - IMHandle(network='blabber', handle='foo'), - IMHandle(network='lol', handle='zomg') - ])]) - - a1 = sess.query(IMHandle).filter(IMHandle.handle == 'zomg').one() - eq_(a1, IMHandle(network='lol', handle='zomg')) - eq_(a1.user, User(name='u1')) - - def test_synonym_for(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - @decl.synonym_for('name') - @property - def namesyn(self): - return self.name - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "someuser") - eq_(u1.namesyn, 'someuser') - sess.add(u1) - sess.flush() - - rt = sess.query(User).filter(User.namesyn == 'someuser').one() - eq_(rt, u1) - - def test_comparable_using(self): - class NameComparator(sa.orm.PropComparator): - @property - def upperself(self): - cls = self.prop.parent.class_ - col = getattr(cls, 'name') - return sa.func.upper(col) - - def operate(self, op, other, **kw): - return op(self.upperself, other, **kw) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - @decl.comparable_using(NameComparator) - @property - def uc_name(self): - return self.name is not None and self.name.upper() or None - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "someuser", u1.name) - eq_(u1.uc_name, 'SOMEUSER', u1.uc_name) - sess.add(u1) - sess.flush() - sess.expunge_all() - - rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one() - eq_(rt, u1) - sess.expunge_all() - - rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one() - eq_(rt, u1) - - -if __name__ == '__main__': - testing.main() diff --git a/test/ext/orderinglist.py b/test/ext/orderinglist.py deleted file mode 100644 index c111a02de..000000000 --- a/test/ext/orderinglist.py +++ /dev/null @@ -1,403 +0,0 @@ -import testenv; testenv.configure_for_tests() -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy.ext.orderinglist import * -from testlib.testing import eq_ -from testlib import * - - -metadata = None - -# order in whole steps -def step_numbering(step): - def f(index, collection): - return step * index - return f - -# almost fibonacci- skip the first 2 steps -# e.g. 1, 2, 3, 5, 8, ... instead of 0, 1, 1, 2, 3, ... -# otherwise ordering of the elements at '1' is undefined... ;) -def fibonacci_numbering(order_col): - def f(index, collection): - if index == 0: - return 1 - elif index == 1: - return 2 - else: - return (getattr(collection[index - 1], order_col) + - getattr(collection[index - 2], order_col)) - return f - -# 0 -> A, 1 -> B, ... 25 -> Z, 26 -> AA, 27 -> AB, ... -def alpha_ordering(index, collection): - s = '' - while index > 25: - d = index / 26 - s += chr((d % 26) + 64) - index -= d * 26 - s += chr(index + 65) - return s - -class OrderingListTest(TestBase): - def setUp(self): - global metadata, slides_table, bullets_table, Slide, Bullet - slides_table, bullets_table = None, None - Slide, Bullet = None, None - if metadata: - metadata.clear() - - def _setup(self, test_collection_class): - """Build a relation situation using the given test_collection_class - factory""" - - global metadata, slides_table, bullets_table, Slide, Bullet - - metadata = MetaData(testing.db) - slides_table = Table('test_Slides', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('name', String(128))) - bullets_table = Table('test_Bullets', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('slide_id', Integer, - ForeignKey('test_Slides.id')), - Column('position', Integer), - Column('text', String(128))) - - class Slide(object): - def __init__(self, name): - self.name = name - def __repr__(self): - return '' % self.name - - class Bullet(object): - def __init__(self, text): - self.text = text - def __repr__(self): - return '' % (self.text, self.position) - - mapper(Slide, slides_table, properties={ - 'bullets': relation(Bullet, lazy=False, - collection_class=test_collection_class, - backref='slide', - order_by=[bullets_table.c.position]) - }) - mapper(Bullet, bullets_table) - - metadata.create_all() - - def tearDown(self): - metadata.drop_all() - - def test_append_no_reorder(self): - self._setup(ordering_list('position', count_from=1, - reorder_on_append=False)) - - s1 = Slide('Slide #1') - - self.assert_(not s1.bullets) - self.assert_(len(s1.bullets) == 0) - - s1.bullets.append(Bullet('s1/b1')) - - self.assert_(s1.bullets) - self.assert_(len(s1.bullets) == 1) - self.assert_(s1.bullets[0].position == 1) - - s1.bullets.append(Bullet('s1/b2')) - - self.assert_(len(s1.bullets) == 2) - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - - bul = Bullet('s1/b100') - bul.position = 100 - s1.bullets.append(bul) - - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - self.assert_(s1.bullets[2].position == 100) - - s1.bullets.append(Bullet('s1/b4')) - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - self.assert_(s1.bullets[2].position == 100) - self.assert_(s1.bullets[3].position == 4) - - s1.bullets._reorder() - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - self.assert_(s1.bullets[2].position == 3) - self.assert_(s1.bullets[3].position == 4) - - session = create_session() - session.add(s1) - session.flush() - - id = s1.id - session.expunge_all() - del s1 - - srt = session.query(Slide).get(id) - - self.assert_(srt.bullets) - self.assert_(len(srt.bullets) == 4) - - titles = ['s1/b1','s1/b2','s1/b100','s1/b4'] - found = [b.text for b in srt.bullets] - - self.assert_(titles == found) - - def test_append_reorder(self): - self._setup(ordering_list('position', count_from=1, - reorder_on_append=True)) - - s1 = Slide('Slide #1') - - self.assert_(not s1.bullets) - self.assert_(len(s1.bullets) == 0) - - s1.bullets.append(Bullet('s1/b1')) - - self.assert_(s1.bullets) - self.assert_(len(s1.bullets) == 1) - self.assert_(s1.bullets[0].position == 1) - - s1.bullets.append(Bullet('s1/b2')) - - self.assert_(len(s1.bullets) == 2) - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - - bul = Bullet('s1/b100') - bul.position = 100 - s1.bullets.append(bul) - - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - self.assert_(s1.bullets[2].position == 3) - - s1.bullets.append(Bullet('s1/b4')) - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - self.assert_(s1.bullets[2].position == 3) - self.assert_(s1.bullets[3].position == 4) - - s1.bullets._reorder() - self.assert_(s1.bullets[0].position == 1) - self.assert_(s1.bullets[1].position == 2) - self.assert_(s1.bullets[2].position == 3) - self.assert_(s1.bullets[3].position == 4) - - s1.bullets._raw_append(Bullet('raw')) - self.assert_(s1.bullets[4].position is None) - - s1.bullets._reorder() - self.assert_(s1.bullets[4].position == 5) - session = create_session() - session.add(s1) - session.flush() - - id = s1.id - session.expunge_all() - del s1 - - srt = session.query(Slide).get(id) - - self.assert_(srt.bullets) - self.assert_(len(srt.bullets) == 5) - - titles = ['s1/b1','s1/b2','s1/b100','s1/b4', 'raw'] - found = [b.text for b in srt.bullets] - eq_(titles, found) - - srt.bullets._raw_append(Bullet('raw2')) - srt.bullets[-1].position = 6 - session.flush() - session.expunge_all() - - srt = session.query(Slide).get(id) - titles = ['s1/b1','s1/b2','s1/b100','s1/b4', 'raw', 'raw2'] - found = [b.text for b in srt.bullets] - eq_(titles, found) - - def test_insert(self): - self._setup(ordering_list('position')) - - s1 = Slide('Slide #1') - s1.bullets.append(Bullet('1')) - s1.bullets.append(Bullet('2')) - s1.bullets.append(Bullet('3')) - s1.bullets.append(Bullet('4')) - - self.assert_(s1.bullets[0].position == 0) - self.assert_(s1.bullets[1].position == 1) - self.assert_(s1.bullets[2].position == 2) - self.assert_(s1.bullets[3].position == 3) - - s1.bullets.insert(2, Bullet('insert_at_2')) - self.assert_(s1.bullets[0].position == 0) - self.assert_(s1.bullets[1].position == 1) - self.assert_(s1.bullets[2].position == 2) - self.assert_(s1.bullets[3].position == 3) - self.assert_(s1.bullets[4].position == 4) - - self.assert_(s1.bullets[1].text == '2') - self.assert_(s1.bullets[2].text == 'insert_at_2') - self.assert_(s1.bullets[3].text == '3') - - s1.bullets.insert(999, Bullet('999')) - - self.assert_(len(s1.bullets) == 6) - self.assert_(s1.bullets[5].position == 5) - - session = create_session() - session.add(s1) - session.flush() - - id = s1.id - session.expunge_all() - del s1 - - srt = session.query(Slide).get(id) - - self.assert_(srt.bullets) - self.assert_(len(srt.bullets) == 6) - - texts = ['1','2','insert_at_2','3','4','999'] - found = [b.text for b in srt.bullets] - - self.assert_(texts == found) - - def test_slice(self): - self._setup(ordering_list('position')) - - b = [ Bullet('1'), Bullet('2'), Bullet('3'), - Bullet('4'), Bullet('5'), Bullet('6') ] - s1 = Slide('Slide #1') - - # 1, 2, 3 - s1.bullets[0:3] = b[0:3] - for i in 0, 1, 2: - self.assert_(s1.bullets[i].position == i) - self.assert_(s1.bullets[i] == b[i]) - - # 1, 4, 5, 6, 3 - s1.bullets[1:2] = b[3:6] - for li, bi in (0,0), (1,3), (2,4), (3,5), (4,2): - self.assert_(s1.bullets[li].position == li) - self.assert_(s1.bullets[li] == b[bi]) - - # 1, 6, 3 - del s1.bullets[1:3] - for li, bi in (0,0), (1,5), (2,2): - self.assert_(s1.bullets[li].position == li) - self.assert_(s1.bullets[li] == b[bi]) - - session = create_session() - session.add(s1) - session.flush() - - id = s1.id - session.expunge_all() - del s1 - - srt = session.query(Slide).get(id) - - self.assert_(srt.bullets) - self.assert_(len(srt.bullets) == 3) - - texts = ['1', '6', '3'] - for i, text in enumerate(texts): - self.assert_(srt.bullets[i].position == i) - self.assert_(srt.bullets[i].text == text) - - def test_replace(self): - self._setup(ordering_list('position')) - - s1 = Slide('Slide #1') - s1.bullets = [ Bullet('1'), Bullet('2'), Bullet('3') ] - - self.assert_(len(s1.bullets) == 3) - self.assert_(s1.bullets[2].position == 2) - - session = create_session() - session.add(s1) - session.flush() - - new_bullet = Bullet('new 2') - self.assert_(new_bullet.position is None) - - # mark existing bullet as db-deleted before replacement. - #session.delete(s1.bullets[1]) - s1.bullets[1] = new_bullet - - self.assert_(new_bullet.position == 1) - self.assert_(len(s1.bullets) == 3) - - id = s1.id - - session.flush() - session.expunge_all() - - srt = session.query(Slide).get(id) - - self.assert_(srt.bullets) - self.assert_(len(srt.bullets) == 3) - - self.assert_(srt.bullets[1].text == 'new 2') - self.assert_(srt.bullets[2].text == '3') - - def test_funky_ordering(self): - class Pos(object): - def __init__(self): - self.position = None - - step_factory = ordering_list('position', - ordering_func=step_numbering(2)) - - stepped = step_factory() - stepped.append(Pos()) - stepped.append(Pos()) - stepped.append(Pos()) - stepped.append(Pos()) - - for li, pos in (0,0), (1,2), (2,4), (3,6): - self.assert_(stepped[li].position == pos) - - fib_factory = ordering_list('position', - ordering_func=fibonacci_numbering('position')) - - fibbed = fib_factory() - fibbed.append(Pos()) - fibbed.append(Pos()) - fibbed.append(Pos()) - fibbed.append(Pos()) - fibbed.append(Pos()) - - for li, pos in (0,1), (1,2), (2,3), (3,5), (4,8): - self.assert_(fibbed[li].position == pos) - - fibbed.insert(2, Pos()) - fibbed.insert(4, Pos()) - fibbed.insert(6, Pos()) - - for li, pos in (0,1), (1,2), (2,3), (3,5), (4,8), (5,13), (6,21), (7,34): - self.assert_(fibbed[li].position == pos) - - alpha_factory = ordering_list('position', - ordering_func=alpha_ordering) - alpha = alpha_factory() - alpha.append(Pos()) - alpha.append(Pos()) - alpha.append(Pos()) - - alpha.insert(1, Pos()) - - for li, pos in (0,'A'), (1,'B'), (2,'C'), (3,'D'): - self.assert_(alpha[li].position == pos) - - -if __name__ == "__main__": - testenv.main() diff --git a/test/ext/serializer.py b/test/ext/serializer.py deleted file mode 100644 index 048eccdfd..000000000 --- a/test/ext/serializer.py +++ /dev/null @@ -1,139 +0,0 @@ -import testenv; testenv.configure_for_tests() - -from sqlalchemy.ext import serializer -from sqlalchemy import exc -from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, select, desc, func, util -from testlib.sa.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased -from testlib.testing import eq_ - -from orm._base import ComparableEntity, MappedTest - - -class User(ComparableEntity): - pass - -class Address(ComparableEntity): - pass - -class SerializeTest(MappedTest): - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - def define_tables(self, metadata): - global users, addresses - users = Table('users', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)) - ) - addresses = Table('addresses', metadata, - Column('id', Integer, primary_key=True), - Column('email', String(50)), - Column('user_id', Integer, ForeignKey('users.id')), - ) - - def setup_mappers(self): - global Session - Session = scoped_session(sessionmaker()) - - mapper(User, users, properties={ - 'addresses':relation(Address, backref='user', order_by=addresses.c.id) - }) - mapper(Address, addresses) - - compile_mappers() - - def insert_data(self): - params = [dict(zip(('id', 'name'), column_values)) for column_values in - [(7, 'jack'), - (8, 'ed'), - (9, 'fred'), - (10, 'chuck')] - ] - users.insert().execute(params) - - addresses.insert().execute( - [dict(zip(('id', 'user_id', 'email'), column_values)) for column_values in - [(1, 7, "jack@bean.com"), - (2, 8, "ed@wood.com"), - (3, 8, "ed@bettyboop.com"), - (4, 8, "ed@lala.com"), - (5, 9, "fred@fred.com")] - ] - ) - - def test_tables(self): - assert serializer.loads(serializer.dumps(users), users.metadata, Session) is users - - def test_columns(self): - assert serializer.loads(serializer.dumps(users.c.name), users.metadata, Session) is users.c.name - - def test_mapper(self): - user_mapper = class_mapper(User) - assert serializer.loads(serializer.dumps(user_mapper), None, None) is user_mapper - - def test_attribute(self): - assert serializer.loads(serializer.dumps(User.name), None, None) is User.name - - def test_expression(self): - - expr = select([users]).select_from(users.join(addresses)).limit(5) - re_expr = serializer.loads(serializer.dumps(expr), users.metadata, None) - eq_( - str(expr), - str(re_expr) - ) - - assert re_expr.bind is testing.db - eq_( - re_expr.execute().fetchall(), - [(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')] - ) - - # fails due to pure Python pickle bug: http://bugs.python.org/issue998998 - @testing.fails_if(lambda: util.py3k) - def test_query(self): - q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses)) - eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) - - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) - def go(): - eq_(q2.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) - self.assert_sql_count(testing.db, go, 1) - - eq_(q2.join(User.addresses).filter(Address.email=='ed@bettyboop.com').value(func.count('*')), 1) - - u1 = Session.query(User).get(8) - - q = Session.query(Address).filter(Address.user==u1).order_by(desc(Address.email)) - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) - - eq_(q2.all(), [Address(email='ed@wood.com'), Address(email='ed@lala.com'), Address(email='ed@bettyboop.com')]) - - q = Session.query(User).join(User.addresses).filter(Address.email.like('%fred%')) - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) - eq_(q2.all(), [User(name='fred')]) - - eq_(list(q2.values(User.id, User.name)), [(9, u'fred')]) - - @testing.exclude('sqlite', '<=', (3, 5, 9), 'id comparison failing on the buildbot') - def test_aliases(self): - u7, u8, u9, u10 = Session.query(User).order_by(User.id).all() - - ualias = aliased(User) - q = Session.query(User, ualias).join((ualias, User.id < ualias.id)).filter(User.id<9).order_by(User.id, ualias.id) - eq_(list(q.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)]) - - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) - - eq_(list(q2.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)]) - - def test_any(self): - r = User.addresses.any(Address.email=='x') - ser = serializer.dumps(r) - x = serializer.loads(ser, users.metadata) - eq_(str(r), str(x)) - -if __name__ == '__main__': - testing.main() diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py new file mode 100644 index 000000000..742f98baf --- /dev/null +++ b/test/ext/test_associationproxy.py @@ -0,0 +1,885 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +import gc +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.orm.collections import collection +from sqlalchemy.ext.associationproxy import * +from sqlalchemy.test import * + + +class DictCollection(dict): + @collection.appender + def append(self, obj): + self[obj.foo] = obj + @collection.remover + def remove(self, obj): + del self[obj.foo] + +class SetCollection(set): + pass + +class ListCollection(list): + pass + +class ObjectCollection(object): + def __init__(self): + self.values = list() + @collection.appender + def append(self, obj): + self.values.append(obj) + @collection.remover + def remove(self, obj): + self.values.remove(obj) + def __iter__(self): + return iter(self.values) + +class _CollectionOperations(TestBase): + def setup(self): + collection_class = self.collection_class + + metadata = MetaData(testing.db) + + parents_table = Table('Parent', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(128))) + children_table = Table('Children', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('parent_id', Integer, + ForeignKey('Parent.id')), + Column('foo', String(128)), + Column('name', String(128))) + + class Parent(object): + children = association_proxy('_children', 'name') + + def __init__(self, name): + self.name = name + + class Child(object): + if collection_class and issubclass(collection_class, dict): + def __init__(self, foo, name): + self.foo = foo + self.name = name + else: + def __init__(self, name): + self.name = name + + mapper(Parent, parents_table, properties={ + '_children': relation(Child, lazy=False, + collection_class=collection_class)}) + mapper(Child, children_table) + + metadata.create_all() + + self.metadata = metadata + self.session = create_session() + self.Parent, self.Child = Parent, Child + + def teardown(self): + self.metadata.drop_all() + + def roundtrip(self, obj): + if obj not in self.session: + self.session.add(obj) + self.session.flush() + id, type_ = obj.id, type(obj) + self.session.expunge_all() + return self.session.query(type_).get(id) + + def _test_sequence_ops(self): + Parent, Child = self.Parent, self.Child + + p1 = Parent('P1') + + self.assert_(not p1._children) + self.assert_(not p1.children) + + ch = Child('regular') + p1._children.append(ch) + + self.assert_(ch in p1._children) + self.assert_(len(p1._children) == 1) + + self.assert_(p1.children) + self.assert_(len(p1.children) == 1) + self.assert_(ch not in p1.children) + self.assert_('regular' in p1.children) + + p1.children.append('proxied') + + self.assert_('proxied' in p1.children) + self.assert_('proxied' not in p1._children) + self.assert_(len(p1.children) == 2) + self.assert_(len(p1._children) == 2) + + self.assert_(p1._children[0].name == 'regular') + self.assert_(p1._children[1].name == 'proxied') + + del p1._children[1] + + self.assert_(len(p1._children) == 1) + self.assert_(len(p1.children) == 1) + self.assert_(p1._children[0] == ch) + + del p1.children[0] + + self.assert_(len(p1._children) == 0) + self.assert_(len(p1.children) == 0) + + p1.children = ['a','b','c'] + self.assert_(len(p1._children) == 3) + self.assert_(len(p1.children) == 3) + + del ch + p1 = self.roundtrip(p1) + + self.assert_(len(p1._children) == 3) + self.assert_(len(p1.children) == 3) + + popped = p1.children.pop() + self.assert_(len(p1.children) == 2) + self.assert_(popped not in p1.children) + p1 = self.roundtrip(p1) + self.assert_(len(p1.children) == 2) + self.assert_(popped not in p1.children) + + p1.children[1] = 'changed-in-place' + self.assert_(p1.children[1] == 'changed-in-place') + inplace_id = p1._children[1].id + p1 = self.roundtrip(p1) + self.assert_(p1.children[1] == 'changed-in-place') + assert p1._children[1].id == inplace_id + + p1.children.append('changed-in-place') + self.assert_(p1.children.count('changed-in-place') == 2) + + p1.children.remove('changed-in-place') + self.assert_(p1.children.count('changed-in-place') == 1) + + p1 = self.roundtrip(p1) + self.assert_(p1.children.count('changed-in-place') == 1) + + p1._children = [] + self.assert_(len(p1.children) == 0) + + after = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] + p1.children = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] + self.assert_(len(p1.children) == 10) + self.assert_([c.name for c in p1._children] == after) + + p1.children[2:6] = ['x'] * 4 + after = ['a', 'b', 'x', 'x', 'x', 'x', 'g', 'h', 'i', 'j'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children[2:6] = ['y'] + after = ['a', 'b', 'y', 'g', 'h', 'i', 'j'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children[2:3] = ['z'] * 4 + after = ['a', 'b', 'z', 'z', 'z', 'z', 'g', 'h', 'i', 'j'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children[2::2] = ['O'] * 4 + after = ['a', 'b', 'O', 'z', 'O', 'z', 'O', 'h', 'O', 'j'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + assert_raises(TypeError, set, [p1.children]) + + p1.children *= 0 + after = [] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children += ['a', 'b'] + after = ['a', 'b'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children += ['c'] + after = ['a', 'b', 'c'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children *= 1 + after = ['a', 'b', 'c'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children *= 2 + after = ['a', 'b', 'c', 'a', 'b', 'c'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + p1.children = ['a'] + after = ['a'] + self.assert_(p1.children == after) + self.assert_([c.name for c in p1._children] == after) + + self.assert_((p1.children * 2) == ['a', 'a']) + self.assert_((2 * p1.children) == ['a', 'a']) + self.assert_((p1.children * 0) == []) + self.assert_((0 * p1.children) == []) + + self.assert_((p1.children + ['b']) == ['a', 'b']) + self.assert_((['b'] + p1.children) == ['b', 'a']) + + try: + p1.children + 123 + assert False + except TypeError: + assert True + +class DefaultTest(_CollectionOperations): + def __init__(self, *args, **kw): + super(DefaultTest, self).__init__(*args, **kw) + self.collection_class = None + + def test_sequence_ops(self): + self._test_sequence_ops() + + +class ListTest(_CollectionOperations): + def __init__(self, *args, **kw): + super(ListTest, self).__init__(*args, **kw) + self.collection_class = list + + def test_sequence_ops(self): + self._test_sequence_ops() + +class CustomListTest(ListTest): + def __init__(self, *args, **kw): + super(CustomListTest, self).__init__(*args, **kw) + self.collection_class = list + +# No-can-do until ticket #213 +class DictTest(_CollectionOperations): + pass + +class CustomDictTest(DictTest): + def __init__(self, *args, **kw): + super(DictTest, self).__init__(*args, **kw) + self.collection_class = DictCollection + + def test_mapping_ops(self): + Parent, Child = self.Parent, self.Child + + p1 = Parent('P1') + + self.assert_(not p1._children) + self.assert_(not p1.children) + + ch = Child('a', 'regular') + p1._children.append(ch) + + self.assert_(ch in p1._children.values()) + self.assert_(len(p1._children) == 1) + + self.assert_(p1.children) + self.assert_(len(p1.children) == 1) + self.assert_(ch not in p1.children) + self.assert_('a' in p1.children) + self.assert_(p1.children['a'] == 'regular') + self.assert_(p1._children['a'] == ch) + + p1.children['b'] = 'proxied' + + self.assert_('proxied' in p1.children.values()) + self.assert_('b' in p1.children) + self.assert_('proxied' not in p1._children) + self.assert_(len(p1.children) == 2) + self.assert_(len(p1._children) == 2) + + self.assert_(p1._children['a'].name == 'regular') + self.assert_(p1._children['b'].name == 'proxied') + + del p1._children['b'] + + self.assert_(len(p1._children) == 1) + self.assert_(len(p1.children) == 1) + self.assert_(p1._children['a'] == ch) + + del p1.children['a'] + + self.assert_(len(p1._children) == 0) + self.assert_(len(p1.children) == 0) + + p1.children = {'d': 'v d', 'e': 'v e', 'f': 'v f'} + self.assert_(len(p1._children) == 3) + self.assert_(len(p1.children) == 3) + + self.assert_(set(p1.children) == set(['d','e','f'])) + + del ch + p1 = self.roundtrip(p1) + self.assert_(len(p1._children) == 3) + self.assert_(len(p1.children) == 3) + + p1.children['e'] = 'changed-in-place' + self.assert_(p1.children['e'] == 'changed-in-place') + inplace_id = p1._children['e'].id + p1 = self.roundtrip(p1) + self.assert_(p1.children['e'] == 'changed-in-place') + self.assert_(p1._children['e'].id == inplace_id) + + p1._children = {} + self.assert_(len(p1.children) == 0) + + try: + p1._children = [] + self.assert_(False) + except TypeError: + self.assert_(True) + + try: + p1._children = None + self.assert_(False) + except TypeError: + self.assert_(True) + + assert_raises(TypeError, set, [p1.children]) + + +class SetTest(_CollectionOperations): + def __init__(self, *args, **kw): + super(SetTest, self).__init__(*args, **kw) + self.collection_class = set + + def test_set_operations(self): + Parent, Child = self.Parent, self.Child + + p1 = Parent('P1') + + self.assert_(not p1._children) + self.assert_(not p1.children) + + ch1 = Child('regular') + p1._children.add(ch1) + + self.assert_(ch1 in p1._children) + self.assert_(len(p1._children) == 1) + + self.assert_(p1.children) + self.assert_(len(p1.children) == 1) + self.assert_(ch1 not in p1.children) + self.assert_('regular' in p1.children) + + p1.children.add('proxied') + + self.assert_('proxied' in p1.children) + self.assert_('proxied' not in p1._children) + self.assert_(len(p1.children) == 2) + self.assert_(len(p1._children) == 2) + + self.assert_(set([o.name for o in p1._children]) == + set(['regular', 'proxied'])) + + ch2 = None + for o in p1._children: + if o.name == 'proxied': + ch2 = o + break + + p1._children.remove(ch2) + + self.assert_(len(p1._children) == 1) + self.assert_(len(p1.children) == 1) + self.assert_(p1._children == set([ch1])) + + p1.children.remove('regular') + + self.assert_(len(p1._children) == 0) + self.assert_(len(p1.children) == 0) + + p1.children = ['a','b','c'] + self.assert_(len(p1._children) == 3) + self.assert_(len(p1.children) == 3) + + del ch1 + p1 = self.roundtrip(p1) + + self.assert_(len(p1._children) == 3) + self.assert_(len(p1.children) == 3) + + self.assert_('a' in p1.children) + self.assert_('b' in p1.children) + self.assert_('d' not in p1.children) + + self.assert_(p1.children == set(['a','b','c'])) + + try: + p1.children.remove('d') + self.fail() + except KeyError: + pass + + self.assert_(len(p1.children) == 3) + p1.children.discard('d') + self.assert_(len(p1.children) == 3) + p1 = self.roundtrip(p1) + self.assert_(len(p1.children) == 3) + + popped = p1.children.pop() + self.assert_(len(p1.children) == 2) + self.assert_(popped not in p1.children) + p1 = self.roundtrip(p1) + self.assert_(len(p1.children) == 2) + self.assert_(popped not in p1.children) + + p1.children = ['a','b','c'] + p1 = self.roundtrip(p1) + self.assert_(p1.children == set(['a','b','c'])) + + p1.children.discard('b') + p1 = self.roundtrip(p1) + self.assert_(p1.children == set(['a', 'c'])) + + p1.children.remove('a') + p1 = self.roundtrip(p1) + self.assert_(p1.children == set(['c'])) + + p1._children = set() + self.assert_(len(p1.children) == 0) + + try: + p1._children = [] + self.assert_(False) + except TypeError: + self.assert_(True) + + try: + p1._children = None + self.assert_(False) + except TypeError: + self.assert_(True) + + assert_raises(TypeError, set, [p1.children]) + + + def test_set_comparisons(self): + Parent, Child = self.Parent, self.Child + + p1 = Parent('P1') + p1.children = ['a','b','c'] + control = set(['a','b','c']) + + for other in (set(['a','b','c']), set(['a','b','c','d']), + set(['a']), set(['a','b']), + set(['c','d']), set(['e', 'f', 'g']), + set()): + + eq_(p1.children.union(other), + control.union(other)) + eq_(p1.children.difference(other), + control.difference(other)) + eq_((p1.children - other), + (control - other)) + eq_(p1.children.intersection(other), + control.intersection(other)) + eq_(p1.children.symmetric_difference(other), + control.symmetric_difference(other)) + eq_(p1.children.issubset(other), + control.issubset(other)) + eq_(p1.children.issuperset(other), + control.issuperset(other)) + + self.assert_((p1.children == other) == (control == other)) + self.assert_((p1.children != other) == (control != other)) + self.assert_((p1.children < other) == (control < other)) + self.assert_((p1.children <= other) == (control <= other)) + self.assert_((p1.children > other) == (control > other)) + self.assert_((p1.children >= other) == (control >= other)) + + def test_set_mutation(self): + Parent, Child = self.Parent, self.Child + + # mutations + for op in ('update', 'intersection_update', + 'difference_update', 'symmetric_difference_update'): + for base in (['a', 'b', 'c'], []): + for other in (set(['a','b','c']), set(['a','b','c','d']), + set(['a']), set(['a','b']), + set(['c','d']), set(['e', 'f', 'g']), + set()): + p = Parent('p') + p.children = base[:] + control = set(base[:]) + + getattr(p.children, op)(other) + getattr(control, op)(other) + try: + self.assert_(p.children == control) + except: + print 'Test %s.%s(%s):' % (set(base), op, other) + print 'want', repr(control) + print 'got', repr(p.children) + raise + + p = self.roundtrip(p) + + try: + self.assert_(p.children == control) + except: + print 'Test %s.%s(%s):' % (base, op, other) + print 'want', repr(control) + print 'got', repr(p.children) + raise + + # in-place mutations + for op in ('|=', '-=', '&=', '^='): + for base in (['a', 'b', 'c'], []): + for other in (set(['a','b','c']), set(['a','b','c','d']), + set(['a']), set(['a','b']), + set(['c','d']), set(['e', 'f', 'g']), + frozenset(['e', 'f', 'g']), + set()): + p = Parent('p') + p.children = base[:] + control = set(base[:]) + + exec "p.children %s other" % op + exec "control %s other" % op + + try: + self.assert_(p.children == control) + except: + print 'Test %s %s %s:' % (set(base), op, other) + print 'want', repr(control) + print 'got', repr(p.children) + raise + + p = self.roundtrip(p) + + try: + self.assert_(p.children == control) + except: + print 'Test %s %s %s:' % (base, op, other) + print 'want', repr(control) + print 'got', repr(p.children) + raise + + +class CustomSetTest(SetTest): + def __init__(self, *args, **kw): + super(CustomSetTest, self).__init__(*args, **kw) + self.collection_class = SetCollection + +class CustomObjectTest(_CollectionOperations): + def __init__(self, *args, **kw): + super(CustomObjectTest, self).__init__(*args, **kw) + self.collection_class = ObjectCollection + + def test_basic(self): + Parent, Child = self.Parent, self.Child + + p = Parent('p1') + self.assert_(len(list(p.children)) == 0) + + p.children.append('child') + self.assert_(len(list(p.children)) == 1) + + p = self.roundtrip(p) + self.assert_(len(list(p.children)) == 1) + + # We didn't provide an alternate _AssociationList implementation for + # our ObjectCollection, so indexing will fail. + try: + v = p.children[1] + self.fail() + except TypeError: + pass + +class ScalarTest(TestBase): + def test_scalar_proxy(self): + metadata = MetaData(testing.db) + + parents_table = Table('Parent', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(128))) + children_table = Table('Children', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('parent_id', Integer, + ForeignKey('Parent.id')), + Column('foo', String(128)), + Column('bar', String(128)), + Column('baz', String(128))) + + class Parent(object): + foo = association_proxy('child', 'foo') + bar = association_proxy('child', 'bar', + creator=lambda v: Child(bar=v)) + baz = association_proxy('child', 'baz', + creator=lambda v: Child(baz=v)) + + def __init__(self, name): + self.name = name + + class Child(object): + def __init__(self, **kw): + for attr in kw: + setattr(self, attr, kw[attr]) + + mapper(Parent, parents_table, properties={ + 'child': relation(Child, lazy=False, + backref='parent', uselist=False)}) + mapper(Child, children_table) + + metadata.create_all() + session = create_session() + + def roundtrip(obj): + if obj not in session: + session.add(obj) + session.flush() + id, type_ = obj.id, type(obj) + session.expunge_all() + return session.query(type_).get(id) + + p = Parent('p') + + # No child + try: + v = p.foo + self.fail() + except: + pass + + p.child = Child(foo='a', bar='b', baz='c') + + self.assert_(p.foo == 'a') + self.assert_(p.bar == 'b') + self.assert_(p.baz == 'c') + + p.bar = 'x' + self.assert_(p.foo == 'a') + self.assert_(p.bar == 'x') + self.assert_(p.baz == 'c') + + p = roundtrip(p) + + self.assert_(p.foo == 'a') + self.assert_(p.bar == 'x') + self.assert_(p.baz == 'c') + + p.child = None + + # No child again + try: + v = p.foo + self.fail() + except: + pass + + # Bogus creator for this scalar type + try: + p.foo = 'zzz' + self.fail() + except TypeError: + pass + + p.bar = 'yyy' + + self.assert_(p.foo is None) + self.assert_(p.bar == 'yyy') + self.assert_(p.baz is None) + + del p.child + + p = roundtrip(p) + + self.assert_(p.child is None) + + p.baz = 'xxx' + + self.assert_(p.foo is None) + self.assert_(p.bar is None) + self.assert_(p.baz == 'xxx') + + p = roundtrip(p) + + self.assert_(p.foo is None) + self.assert_(p.bar is None) + self.assert_(p.baz == 'xxx') + + # Ensure an immediate __set__ works. + p2 = Parent('p2') + p2.bar = 'quux' + + +class LazyLoadTest(TestBase): + def setup(self): + metadata = MetaData(testing.db) + + parents_table = Table('Parent', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(128))) + children_table = Table('Children', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('parent_id', Integer, + ForeignKey('Parent.id')), + Column('foo', String(128)), + Column('name', String(128))) + + class Parent(object): + children = association_proxy('_children', 'name') + + def __init__(self, name): + self.name = name + + class Child(object): + def __init__(self, name): + self.name = name + + + mapper(Child, children_table) + metadata.create_all() + + self.metadata = metadata + self.session = create_session() + self.Parent, self.Child = Parent, Child + self.table = parents_table + + def teardown(self): + self.metadata.drop_all() + + def roundtrip(self, obj): + self.session.add(obj) + self.session.flush() + id, type_ = obj.id, type(obj) + self.session.expunge_all() + return self.session.query(type_).get(id) + + def test_lazy_list(self): + Parent, Child = self.Parent, self.Child + + mapper(Parent, self.table, properties={ + '_children': relation(Child, lazy=True, + collection_class=list)}) + + p = Parent('p') + p.children = ['a','b','c'] + + p = self.roundtrip(p) + + # Is there a better way to ensure that the association_proxy + # didn't convert a lazy load to an eager load? This does work though. + self.assert_('_children' not in p.__dict__) + self.assert_(len(p._children) == 3) + self.assert_('_children' in p.__dict__) + + def test_eager_list(self): + Parent, Child = self.Parent, self.Child + + mapper(Parent, self.table, properties={ + '_children': relation(Child, lazy=False, + collection_class=list)}) + + p = Parent('p') + p.children = ['a','b','c'] + + p = self.roundtrip(p) + + self.assert_('_children' in p.__dict__) + self.assert_(len(p._children) == 3) + + def test_lazy_scalar(self): + Parent, Child = self.Parent, self.Child + + mapper(Parent, self.table, properties={ + '_children': relation(Child, lazy=True, uselist=False)}) + + + p = Parent('p') + p.children = 'value' + + p = self.roundtrip(p) + + self.assert_('_children' not in p.__dict__) + self.assert_(p._children is not None) + + def test_eager_scalar(self): + Parent, Child = self.Parent, self.Child + + mapper(Parent, self.table, properties={ + '_children': relation(Child, lazy=False, uselist=False)}) + + + p = Parent('p') + p.children = 'value' + + p = self.roundtrip(p) + + self.assert_('_children' in p.__dict__) + self.assert_(p._children is not None) + + +class ReconstitutionTest(TestBase): + def setup(self): + metadata = MetaData(testing.db) + parents = Table('parents', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(30))) + children = Table('children', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('parent_id', Integer, ForeignKey('parents.id')), + Column('name', String(30))) + metadata.create_all() + parents.insert().execute(name='p1') + + class Parent(object): + kids = association_proxy('children', 'name') + def __init__(self, name): + self.name = name + + class Child(object): + def __init__(self, name): + self.name = name + + mapper(Parent, parents, properties=dict(children=relation(Child))) + mapper(Child, children) + + self.metadata = metadata + self.Parent = Parent + + def teardown(self): + self.metadata.drop_all() + + def test_weak_identity_map(self): + session = create_session(weak_identity_map=True) + + def add_child(parent_name, child_name): + parent = (session.query(self.Parent). + filter_by(name=parent_name)).one() + parent.kids.append(child_name) + + + add_child('p1', 'c1') + gc.collect() + add_child('p1', 'c2') + + session.flush() + p = session.query(self.Parent).filter_by(name='p1').one() + assert set(p.kids) == set(['c1', 'c2']), p.kids + + def test_copy(self): + import copy + p = self.Parent('p1') + p.kids.extend(['c1', 'c2']) + p_copy = copy.copy(p) + del p + gc.collect() + + assert set(p_copy.kids) == set(['c1', 'c2']), p.kids + + diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py new file mode 100644 index 000000000..ce2549099 --- /dev/null +++ b/test/ext/test_compiler.py @@ -0,0 +1,123 @@ +from sqlalchemy import * +from sqlalchemy.sql.expression import ClauseElement, ColumnClause +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import table, column +from sqlalchemy.test import * + +class UserDefinedTest(TestBase, AssertsCompiledSQL): + + def test_column(self): + + class MyThingy(ColumnClause): + def __init__(self, arg= None): + super(MyThingy, self).__init__(arg or 'MYTHINGY!') + + @compiles(MyThingy) + def visit_thingy(thingy, compiler, **kw): + return ">>%s<<" % thingy.name + + self.assert_compile( + select([column('foo'), MyThingy()]), + "SELECT foo, >>MYTHINGY!<<" + ) + + self.assert_compile( + select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), + "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" + ) + + def test_stateful(self): + class MyThingy(ColumnClause): + def __init__(self): + super(MyThingy, self).__init__('MYTHINGY!') + + @compiles(MyThingy) + def visit_thingy(thingy, compiler, **kw): + if not hasattr(compiler, 'counter'): + compiler.counter = 0 + compiler.counter += 1 + return str(compiler.counter) + + self.assert_compile( + select([column('foo'), MyThingy()]).order_by(desc(MyThingy())), + "SELECT foo, 1 ORDER BY 2 DESC" + ) + + self.assert_compile( + select([MyThingy(), MyThingy()]).where(MyThingy() == 5), + "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1" + ) + + def test_callout_to_compiler(self): + class InsertFromSelect(ClauseElement): + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True), + compiler.process(element.select) + ) + + t1 = table("mytable", column('x'), column('y'), column('z')) + self.assert_compile( + InsertFromSelect( + t1, + select([t1]).where(t1.c.x>5) + ), + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" + ) + + def test_dialect_specific(self): + class AddThingy(ClauseElement): + __visit_name__ = 'add_thingy' + + class DropThingy(ClauseElement): + __visit_name__ = 'drop_thingy' + + @compiles(AddThingy, 'sqlite') + def visit_add_thingy(thingy, compiler, **kw): + return "ADD SPECIAL SL THINGY" + + @compiles(AddThingy) + def visit_add_thingy(thingy, compiler, **kw): + return "ADD THINGY" + + @compiles(DropThingy) + def visit_drop_thingy(thingy, compiler, **kw): + return "DROP THINGY" + + self.assert_compile(AddThingy(), + "ADD THINGY" + ) + + self.assert_compile(DropThingy(), + "DROP THINGY" + ) + + from sqlalchemy.databases import sqlite as base + self.assert_compile(AddThingy(), + "ADD SPECIAL SL THINGY", + dialect=base.dialect() + ) + + self.assert_compile(DropThingy(), + "DROP THINGY", + dialect=base.dialect() + ) + + @compiles(DropThingy, 'sqlite') + def visit_drop_thingy(thingy, compiler, **kw): + return "DROP SPECIAL SL THINGY" + + self.assert_compile(DropThingy(), + "DROP SPECIAL SL THINGY", + dialect=base.dialect() + ) + + self.assert_compile(DropThingy(), + "DROP THINGY", + ) + diff --git a/test/ext/test_declarative.py b/test/ext/test_declarative.py new file mode 100644 index 000000000..c49c00cec --- /dev/null +++ b/test/ext/test_declarative.py @@ -0,0 +1,1545 @@ + +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.ext import declarative as decl +from sqlalchemy import exc +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred +from sqlalchemy.test.testing import eq_ + + +from test.orm._base import ComparableEntity, MappedTest + +class DeclarativeTestBase(testing.TestBase, testing.AssertsExecutionResults): + def setup(self): + global Base + Base = decl.declarative_base(testing.db) + + def teardown(self): + clear_mappers() + Base.metadata.drop_all() + +class DeclarativeTest(DeclarativeTestBase): + def test_basic(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + email = Column(String(50), key='_email') + user_id = Column('user_id', Integer, ForeignKey('users.id'), + key='_user_id') + + Base.metadata.create_all() + + eq_(Address.__table__.c['id'].name, 'id') + eq_(Address.__table__.c['_email'].name, 'email') + eq_(Address.__table__.c['_user_id'].name, 'user_id') + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email == 'two').one() + eq_(a1, Address(email='two')) + eq_(a1.user, User(name='u1')) + + def test_no_table(self): + def go(): + class User(Base): + id = Column('id', Integer, primary_key=True) + assert_raises_message(sa.exc.InvalidRequestError, "does not have a __table__", go) + + def test_cant_add_columns(self): + t = Table('t', Base.metadata, Column('id', Integer, primary_key=True), Column('data', String)) + def go(): + class User(Base): + __table__ = t + foo = Column(Integer, primary_key=True) + # can't specify new columns not already in the table + assert_raises_message(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go) + + # regular re-mapping works tho + class Bar(Base): + __table__ = t + some_data = t.c.data + + assert class_mapper(Bar).get_property('some_data').columns[0] is t.c.data + + def test_undefer_column_name(self): + # TODO: not sure if there was an explicit + # test for this elsewhere + foo = Column(Integer) + eq_(str(foo), '(no name)') + eq_(foo.key, None) + eq_(foo.name, None) + decl._undefer_column_name('foo', foo) + eq_(str(foo), 'foo') + eq_(foo.key, 'foo') + eq_(foo.name, 'foo') + + def test_recompile_on_othermapper(self): + """declarative version of the same test in mappers.py""" + + from sqlalchemy.orm import mapperlib + + class User(Base): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + class Address(Base): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + user = relation("User", primaryjoin=user_id == User.id, + backref="addresses") + + assert mapperlib._new_mappers is True + u = User() + assert User.addresses + assert mapperlib._new_mappers is False + + def test_string_dependency_resolution(self): + from sqlalchemy.sql import desc + + class User(Base, ComparableEntity): + __tablename__ = 'users' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + addresses = relation("Address", order_by="desc(Address.email)", + primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]", + backref=backref('user', primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]") + ) + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer) # note no foreign key + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='ed', addresses=[Address(email='abc'), Address(email='def'), Address(email='xyz')]) + sess.add(u1) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).filter(User.name == 'ed').one(), + User(name='ed', addresses=[Address(email='xyz'), Address(email='def'), Address(email='abc')]) + ) + + class Foo(Base, ComparableEntity): + __tablename__ = 'foo' + id = Column(Integer, primary_key=True) + rel = relation("User", primaryjoin="User.addresses==Foo.id") + assert_raises_message(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers) + + def test_string_dependency_resolution_in_backref(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + addresses = relation("Address", + primaryjoin="User.id==Address.user_id", + backref="user" + ) + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + + compile_mappers() + eq_(str(User.addresses.property.primaryjoin), str(Address.user.property.primaryjoin)) + + + def test_uncompiled_attributes_in_relation(self): + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + addresses = relation("Address", order_by=Address.email, + foreign_keys=Address.user_id, + remote_side=Address.user_id, + ) + + # get the mapper for User. User mapper will compile, + # "addresses" relation will call upon Address.user_id for + # its clause element. Address.user_id is a _CompileOnAttr, + # which then calls class_mapper(Address). But ! We're already + # "in compilation", but class_mapper(Address) needs to initialize + # regardless, or COA's assertion fails + # and things generally go downhill from there. + class_mapper(User) + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='ed', addresses=[Address(email='abc'), Address(email='xyz'), Address(email='def')]) + sess.add(u1) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).filter(User.name == 'ed').one(), + User(name='ed', addresses=[Address(email='abc'), Address(email='def'), Address(email='xyz')]) + ) + + def test_nice_dependency_error(self): + class User(Base): + __tablename__ = 'users' + id = Column('id', Integer, primary_key=True) + addresses = relation("Address") + + class Address(Base): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + foo = sa.orm.column_property(User.id == 5) + + # this used to raise an error when accessing User.id but that's no longer the case + # since we got rid of _CompileOnAttr. + assert_raises(sa.exc.ArgumentError, compile_mappers) + + def test_nice_dependency_error_works_with_hasattr(self): + class User(Base): + __tablename__ = 'users' + id = Column('id', Integer, primary_key=True) + addresses = relation("Addresss") + + # hasattr() on a compile-loaded attribute + hasattr(User.addresses, 'property') + # the exeption is preserved + assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) + + def test_custom_base(self): + class MyBase(object): + def foobar(self): + return "foobar" + Base = decl.declarative_base(cls=MyBase) + assert hasattr(Base, 'metadata') + assert Base().foobar() == "foobar" + + def test_index_doesnt_compile(self): + class User(Base): + __tablename__ = 'users' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + error = relation("Address") + + i = Index('my_index', User.name) + + # compile fails due to the nonexistent Addresses relation + assert_raises(sa.exc.InvalidRequestError, compile_mappers) + + # index configured + assert i in User.__table__.indexes + assert User.__table__.c.id not in set(i.columns) + assert User.__table__.c.name in set(i.columns) + + # tables create fine + Base.metadata.create_all() + + def test_add_prop(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + User.name = Column('name', String(50)) + User.addresses = relation("Address", backref="user") + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + + id = Column(Integer, primary_key=True) + Address.email = Column(String(50), key='_email') + Address.user_id = Column('user_id', Integer, ForeignKey('users.id'), + key='_user_id') + + Base.metadata.create_all() + + eq_(Address.__table__.c['id'].name, 'id') + eq_(Address.__table__.c['_email'].name, 'email') + eq_(Address.__table__.c['_user_id'].name, 'user_id') + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email == 'two').one() + eq_(a1, Address(email='two')) + eq_(a1.user, User(name='u1')) + + def test_eager_order_by(self): + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", order_by=Address.email) + + Base.metadata.create_all() + u1 = User(name='u1', addresses=[ + Address(email='two'), + Address(email='one'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + eq_(sess.query(User).options(eagerload(User.addresses)).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + def test_order_by_multi(self): + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", order_by=(Address.email, Address.id)) + + Base.metadata.create_all() + u1 = User(name='u1', addresses=[ + Address(email='two'), + Address(email='one'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + u = sess.query(User).filter(User.name == 'u1').one() + a = u.addresses + + def test_as_declarative(self): + class User(ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(ComparableEntity): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + reg = {} + decl.instrument_declarative(User, reg, Base.metadata) + decl.instrument_declarative(Address, reg, Base.metadata) + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + def test_custom_mapper(self): + class MyExt(sa.orm.MapperExtension): + def create_instance(self): + return "CHECK" + + def mymapper(cls, tbl, **kwargs): + kwargs['extension'] = MyExt() + return sa.orm.mapper(cls, tbl, **kwargs) + + from sqlalchemy.orm.mapper import Mapper + class MyMapper(Mapper): + def __init__(self, *args, **kwargs): + kwargs['extension'] = MyExt() + Mapper.__init__(self, *args, **kwargs) + + from sqlalchemy.orm import scoping + ss = scoping.ScopedSession(create_session) + ss.extension = MyExt() + ss_mapper = ss.mapper + + for mapperfunc in (mymapper, MyMapper, ss_mapper): + base = decl.declarative_base() + class Foo(base): + __tablename__ = 'foo' + __mapper_cls__ = mapperfunc + id = Column(Integer, primary_key=True) + eq_(Foo.__mapper__.compile().extension.create_instance(), 'CHECK') + + base = decl.declarative_base(mapper=mapperfunc) + class Foo(base): + __tablename__ = 'foo' + id = Column(Integer, primary_key=True) + eq_(Foo.__mapper__.compile().extension.create_instance(), 'CHECK') + + + @testing.emits_warning('Ignoring declarative-like tuple value of ' + 'attribute id') + def test_oops(self): + def define(): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True), + name = Column('name', String(50)) + assert False + assert_raises_message( + sa.exc.ArgumentError, + "Mapper Mapper|User|users could not assemble any primary key", + define) + + def test_table_args(self): + class Foo(Base): + __tablename__ = 'foo' + __table_args__ = {'mysql_engine':'InnoDB'} + id = Column('id', Integer, primary_key=True) + + assert Foo.__table__.kwargs['mysql_engine'] == 'InnoDB' + + class Bar(Base): + __tablename__ = 'bar' + __table_args__ = (ForeignKeyConstraint(['id'], ['foo.id']), {'mysql_engine':'InnoDB'}) + id = Column('id', Integer, primary_key=True) + + assert Bar.__table__.c.id.references(Foo.__table__.c.id) + assert Bar.__table__.kwargs['mysql_engine'] == 'InnoDB' + + def test_expression(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + User.address_count = sa.orm.column_property( + sa.select([sa.func.count(Address.id)]). + where(Address.user_id == User.id).as_scalar()) + + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), + [User(name='u1', address_count=2, addresses=[ + Address(email='one'), + Address(email='two')])]) + + def test_column(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + User.a = Column('a', String(10)) + User.b = Column(String(10)) + + Base.metadata.create_all() + + u1 = User(name='u1', a='a', b='b') + eq_(u1.a, 'a') + eq_(User.a.get_history(u1), (['a'], (), ())) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), + [User(name='u1', a='a', b='b')]) + + def test_column_properties(self): + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + adr_count = sa.orm.column_property( + sa.select([sa.func.count(Address.id)], Address.user_id == id). + as_scalar()) + addresses = relation(Address) + + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), + [User(name='u1', adr_count=2, addresses=[ + Address(email='one'), + Address(email='two')])]) + + def test_column_properties_2(self): + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + # this is not "valid" but we want to test that Address.id doesnt + # get stuck into user's table + adr_count = Address.id + + eq_(set(User.__table__.c.keys()), set(['id', 'name'])) + eq_(set(Address.__table__.c.keys()), set(['id', 'email', 'user_id'])) + + def test_deferred(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = sa.orm.deferred(Column(String(50))) + + Base.metadata.create_all() + sess = create_session() + sess.add(User(name='u1')) + sess.flush() + sess.expunge_all() + + u1 = sess.query(User).filter(User.name == 'u1').one() + assert 'name' not in u1.__dict__ + def go(): + eq_(u1.name, 'u1') + self.assert_sql_count(testing.db, go, 1) + + def test_synonym_inline(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + _name = Column('name', String(50)) + def _set_name(self, name): + self._name = "SOMENAME " + name + def _get_name(self): + return self._name + name = sa.orm.synonym('_name', + descriptor=property(_get_name, _set_name)) + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "SOMENAME someuser") + sess.add(u1) + sess.flush() + eq_(sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1) + + def test_synonym_no_descriptor(self): + from sqlalchemy.orm.properties import ColumnProperty + + class CustomCompare(ColumnProperty.Comparator): + __hash__ = None + def __eq__(self, other): + return self.__clause_element__() == other + ' FOO' + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + _name = Column('name', String(50)) + name = sa.orm.synonym('_name', comparator_factory=CustomCompare) + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser FOO') + sess.add(u1) + sess.flush() + eq_(sess.query(User).filter(User.name == "someuser").one(), u1) + + def test_synonym_added(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + _name = Column('name', String(50)) + def _set_name(self, name): + self._name = "SOMENAME " + name + def _get_name(self): + return self._name + name = property(_get_name, _set_name) + User.name = sa.orm.synonym('_name', descriptor=User.name) + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "SOMENAME someuser") + sess.add(u1) + sess.flush() + eq_(sess.query(User).filter(User.name == "SOMENAME someuser").one(), u1) + + def test_reentrant_compile_via_foreignkey(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user") + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey(User.id)) + + # previous versions would force a re-entrant mapper compile + # via the User.id inside the ForeignKey but this is no + # longer the case + sa.orm.compile_mappers() + + eq_(str(Address.user_id.property.columns[0].foreign_keys[0]), "ForeignKey('users.id')") + + Base.metadata.create_all() + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + def test_relation_reference(self): + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + + id = Column('id', Integer, primary_key=True) + email = Column('email', String(50)) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + addresses = relation("Address", backref="user", + primaryjoin=id == Address.user_id) + + User.address_count = sa.orm.column_property( + sa.select([sa.func.count(Address.id)]). + where(Address.user_id == User.id).as_scalar()) + + Base.metadata.create_all() + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), + [User(name='u1', address_count=2, addresses=[ + Address(email='one'), + Address(email='two')])]) + + def test_pk_with_fk_init(self): + class Bar(Base): + __tablename__ = 'bar' + + id = sa.Column(sa.Integer, sa.ForeignKey("foo.id"), primary_key=True) + ex = sa.Column(sa.Integer, primary_key=True) + + class Foo(Base): + __tablename__ = 'foo' + + id = sa.Column(sa.Integer, primary_key=True) + bars = sa.orm.relation(Bar) + + assert Bar.__mapper__.primary_key[0] is Bar.__table__.c.id + assert Bar.__mapper__.primary_key[1] is Bar.__table__.c.ex + + + def test_with_explicit_autoloaded(self): + meta = MetaData(testing.db) + t1 = Table('t1', meta, + Column('id', String(50), primary_key=True), + Column('data', String(50))) + meta.create_all() + try: + class MyObj(Base): + __table__ = Table('t1', Base.metadata, autoload=True) + + sess = create_session() + m = MyObj(id="someid", data="somedata") + sess.add(m) + sess.flush() + + eq_(t1.select().execute().fetchall(), [('someid', 'somedata')]) + finally: + meta.drop_all() + +class DeclarativeInheritanceTest(DeclarativeTestBase): + def test_custom_join_condition(self): + class Foo(Base): + __tablename__ = 'foo' + id = Column('id', Integer, primary_key=True) + + class Bar(Foo): + __tablename__ = 'bar' + id = Column('id', Integer, primary_key=True) + foo_id = Column('foo_id', Integer) + __mapper_args__ = {'inherit_condition':foo_id==Foo.id} + + # compile succeeds because inherit_condition is honored + compile_mappers() + + def test_joined(self): + class Company(Base, ComparableEntity): + __tablename__ = 'companies' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + employees = relation("Person") + + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + company_id = Column('company_id', Integer, + ForeignKey('companies.id')) + name = Column('name', String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity':'engineer'} + id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column('primary_language', String(50)) + + class Manager(Person): + __tablename__ = 'managers' + __mapper_args__ = {'polymorphic_identity':'manager'} + id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) + golf_swing = Column('golf_swing', String(50)) + + Base.metadata.create_all() + + sess = create_session() + + c1 = Company(name="MegaCorp, Inc.", employees=[ + Engineer(name="dilbert", primary_language="java"), + Engineer(name="wally", primary_language="c++"), + Manager(name="dogbert", golf_swing="fore!") + ]) + + c2 = Company(name="Elbonia, Inc.", employees=[ + Engineer(name="vlad", primary_language="cobol") + ]) + + sess.add(c1) + sess.add(c2) + sess.flush() + sess.expunge_all() + + eq_((sess.query(Company). + filter(Company.employees.of_type(Engineer). + any(Engineer.primary_language == 'cobol')).first()), + c2) + + # ensure that the Manager mapper was compiled + # with the Person id column as higher priority. + # this ensures that "id" will get loaded from the Person row + # and not the possibly non-present Manager row + assert Manager.id.property.columns == [Person.__table__.c.id, Manager.__table__.c.id] + + # assert that the "id" column is available without a second load. + # this would be the symptom of the previous step not being correct. + sess.expunge_all() + def go(): + assert sess.query(Manager).filter(Manager.name=='dogbert').one().id + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + def go(): + assert sess.query(Person).filter(Manager.name=='dogbert').one().id + self.assert_sql_count(testing.db, go, 1) + + def test_subclass_mixin(self): + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class MyMixin(object): + pass + + class Engineer(MyMixin, Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity':'engineer'} + id = Column('id', Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column('primary_language', String(50)) + + assert class_mapper(Engineer).inherits is class_mapper(Person) + + def test_with_undefined_foreignkey(self): + class Parent(Base): + __tablename__ = 'parent' + id = Column('id', Integer, primary_key=True) + tp = Column('type', String(50)) + __mapper_args__ = dict(polymorphic_on = tp) + + class Child1(Parent): + __tablename__ = 'child1' + id = Column('id', Integer, ForeignKey('parent.id'), primary_key=True) + related_child2 = Column('c2', Integer, ForeignKey('child2.id')) + __mapper_args__ = dict(polymorphic_identity = 'child1') + + # no exception is raised by the ForeignKey to "child2" even though + # child2 doesn't exist yet + + class Child2(Parent): + __tablename__ = 'child2' + id = Column('id', Integer, ForeignKey('parent.id'), primary_key=True) + related_child1 = Column('c1', Integer) + __mapper_args__ = dict(polymorphic_identity = 'child2') + + sa.orm.compile_mappers() # no exceptions here + + def test_single_colsonbase(self): + """test single inheritance where all the columns are on the base class.""" + + class Company(Base, ComparableEntity): + __tablename__ = 'companies' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + employees = relation("Person") + + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + company_id = Column('company_id', Integer, + ForeignKey('companies.id')) + name = Column('name', String(50)) + discriminator = Column('type', String(50)) + primary_language = Column('primary_language', String(50)) + golf_swing = Column('golf_swing', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity':'engineer'} + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity':'manager'} + + Base.metadata.create_all() + + sess = create_session() + c1 = Company(name="MegaCorp, Inc.", employees=[ + Engineer(name="dilbert", primary_language="java"), + Engineer(name="wally", primary_language="c++"), + Manager(name="dogbert", golf_swing="fore!") + ]) + + c2 = Company(name="Elbonia, Inc.", employees=[ + Engineer(name="vlad", primary_language="cobol") + ]) + + sess.add(c1) + sess.add(c2) + sess.flush() + sess.expunge_all() + + eq_((sess.query(Person). + filter(Engineer.primary_language == 'cobol').first()), + Engineer(name='vlad')) + eq_((sess.query(Company). + filter(Company.employees.of_type(Engineer). + any(Engineer.primary_language == 'cobol')).first()), + c2) + + def test_single_colsonsub(self): + """test single inheritance where the columns are local to their class. + + this is a newer usage. + + """ + + class Company(Base, ComparableEntity): + __tablename__ = 'companies' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + employees = relation("Person") + + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + company_id = Column(Integer, + ForeignKey('companies.id')) + name = Column(String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity':'engineer'} + primary_language = Column(String(50)) + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity':'manager'} + golf_swing = Column(String(50)) + + # we have here a situation that is somewhat unique. + # the Person class is mapped to the "people" table, but it + # was mapped when the table did not include the "primary_language" + # or "golf_swing" columns. declarative will also manipulate + # the exclude_properties collection so that sibling classes + # don't cross-pollinate. + + assert Person.__table__.c.company_id + assert Person.__table__.c.golf_swing + assert Person.__table__.c.primary_language + assert Engineer.primary_language + assert Manager.golf_swing + assert not hasattr(Person, 'primary_language') + assert not hasattr(Person, 'golf_swing') + assert not hasattr(Engineer, 'golf_swing') + assert not hasattr(Manager, 'primary_language') + + Base.metadata.create_all() + + sess = create_session() + + e1 = Engineer(name="dilbert", primary_language="java") + e2 = Engineer(name="wally", primary_language="c++") + m1 = Manager(name="dogbert", golf_swing="fore!") + c1 = Company(name="MegaCorp, Inc.", employees=[e1, e2, m1]) + + e3 =Engineer(name="vlad", primary_language="cobol") + c2 = Company(name="Elbonia, Inc.", employees=[e3]) + sess.add(c1) + sess.add(c2) + sess.flush() + sess.expunge_all() + + eq_((sess.query(Person). + filter(Engineer.primary_language == 'cobol').first()), + Engineer(name='vlad')) + eq_((sess.query(Company). + filter(Company.employees.of_type(Engineer). + any(Engineer.primary_language == 'cobol')).first()), + c2) + + eq_( + sess.query(Engineer).filter_by(primary_language='cobol').one(), + Engineer(name="vlad", primary_language="cobol") + ) + + def test_joined_from_single(self): + class Company(Base, ComparableEntity): + __tablename__ = 'companies' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + employees = relation("Person") + + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + company_id = Column(Integer, ForeignKey('companies.id')) + name = Column(String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity':'manager'} + golf_swing = Column(String(50)) + + class Engineer(Person): + __tablename__ = 'engineers' + __mapper_args__ = {'polymorphic_identity':'engineer'} + id = Column(Integer, ForeignKey('people.id'), primary_key=True) + primary_language = Column(String(50)) + + assert Person.__table__.c.golf_swing + assert not Person.__table__.c.has_key('primary_language') + assert Engineer.__table__.c.primary_language + assert Engineer.primary_language + assert Manager.golf_swing + assert not hasattr(Person, 'primary_language') + assert not hasattr(Person, 'golf_swing') + assert not hasattr(Engineer, 'golf_swing') + assert not hasattr(Manager, 'primary_language') + + Base.metadata.create_all() + + sess = create_session() + + e1 = Engineer(name="dilbert", primary_language="java") + e2 = Engineer(name="wally", primary_language="c++") + m1 = Manager(name="dogbert", golf_swing="fore!") + c1 = Company(name="MegaCorp, Inc.", employees=[e1, e2, m1]) + e3 =Engineer(name="vlad", primary_language="cobol") + c2 = Company(name="Elbonia, Inc.", employees=[e3]) + sess.add(c1) + sess.add(c2) + sess.flush() + sess.expunge_all() + + eq_((sess.query(Person).with_polymorphic(Engineer). + filter(Engineer.primary_language == 'cobol').first()), + Engineer(name='vlad')) + eq_((sess.query(Company). + filter(Company.employees.of_type(Engineer). + any(Engineer.primary_language == 'cobol')).first()), + c2) + + eq_( + sess.query(Engineer).filter_by(primary_language='cobol').one(), + Engineer(name="vlad", primary_language="cobol") + ) + + def test_add_deferred(self): + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + + Person.name = deferred(Column(String(10))) + + Base.metadata.create_all() + sess = create_session() + p = Person(name='ratbert') + + sess.add(p) + sess.flush() + sess.expunge_all() + eq_( + sess.query(Person).all(), + [ + Person(name='ratbert') + ] + ) + person = sess.query(Person).filter(Person.name == 'ratbert').one() + assert 'name' not in person.__dict__ + + def test_single_fksonsub(self): + """test single inheritance with a foreign key-holding column on a subclass. + + """ + + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity':'engineer'} + primary_language_id = Column(Integer, ForeignKey('languages.id')) + primary_language = relation("Language") + + class Language(Base, ComparableEntity): + __tablename__ = 'languages' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + assert not hasattr(Person, 'primary_language_id') + + Base.metadata.create_all() + + sess = create_session() + + java, cpp, cobol = Language(name='java'),Language(name='cpp'), Language(name='cobol') + e1 = Engineer(name="dilbert", primary_language=java) + e2 = Engineer(name="wally", primary_language=cpp) + e3 =Engineer(name="vlad", primary_language=cobol) + sess.add_all([e1, e2, e3]) + sess.flush() + sess.expunge_all() + + eq_((sess.query(Person). + filter(Engineer.primary_language.has(Language.name=='cobol')).first()), + Engineer(name='vlad', primary_language=Language(name='cobol'))) + + eq_( + sess.query(Engineer).filter(Engineer.primary_language.has(Language.name=='cobol')).one(), + Engineer(name="vlad", primary_language=Language(name='cobol')) + ) + + eq_( + sess.query(Person).join(Engineer.primary_language).order_by(Language.name).all(), + [ + Engineer(name='vlad', primary_language=Language(name='cobol')), + Engineer(name='wally', primary_language=Language(name='cpp')), + Engineer(name='dilbert', primary_language=Language(name='java')), + ] + ) + + def test_single_three_levels(self): + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity':'engineer'} + primary_language = Column(String(50)) + + class JuniorEngineer(Engineer): + __mapper_args__ = {'polymorphic_identity':'junior_engineer'} + nerf_gun = Column(String(50)) + + class Manager(Person): + __mapper_args__ = {'polymorphic_identity':'manager'} + golf_swing = Column(String(50)) + + assert JuniorEngineer.nerf_gun + assert JuniorEngineer.primary_language + assert JuniorEngineer.name + assert Manager.golf_swing + assert Engineer.primary_language + assert not hasattr(Engineer, 'golf_swing') + assert not hasattr(Engineer, 'nerf_gun') + assert not hasattr(Manager, 'nerf_gun') + assert not hasattr(Manager, 'primary_language') + + def test_single_no_special_cols(self): + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + def go(): + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity':'engineer'} + primary_language = Column('primary_language', String(50)) + foo_bar = Column(Integer, primary_key=True) + assert_raises_message(sa.exc.ArgumentError, "place primary key", go) + + def test_single_no_table_args(self): + class Person(Base, ComparableEntity): + __tablename__ = 'people' + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + discriminator = Column('type', String(50)) + __mapper_args__ = {'polymorphic_on':discriminator} + + def go(): + class Engineer(Person): + __mapper_args__ = {'polymorphic_identity':'engineer'} + primary_language = Column('primary_language', String(50)) + __table_args__ = () + assert_raises_message(sa.exc.ArgumentError, "place __table_args__", go) + + def test_concrete(self): + engineers = Table('engineers', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('primary_language', String(50)) + ) + managers = Table('managers', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('golf_swing', String(50)) + ) + + punion = polymorphic_union({ + 'engineer':engineers, + 'manager':managers + }, 'type', 'punion') + + class Person(Base, ComparableEntity): + __table__ = punion + __mapper_args__ = {'polymorphic_on':punion.c.type} + + class Engineer(Person): + __table__ = engineers + __mapper_args__ = {'polymorphic_identity':'engineer', 'concrete':True} + + class Manager(Person): + __table__ = managers + __mapper_args__ = {'polymorphic_identity':'manager', 'concrete':True} + + Base.metadata.create_all() + sess = create_session() + + e1 = Engineer(name="dilbert", primary_language="java") + e2 = Engineer(name="wally", primary_language="c++") + m1 = Manager(name="dogbert", golf_swing="fore!") + e3 = Engineer(name="vlad", primary_language="cobol") + + sess.add_all([e1, e2, m1, e3]) + sess.flush() + sess.expunge_all() + eq_( + sess.query(Person).order_by(Person.name).all(), + [ + Engineer(name='dilbert'), Manager(name='dogbert'), + Engineer(name='vlad'), Engineer(name='wally') + ] + ) + + +def _produce_test(inline, stringbased): + class ExplicitJoinTest(MappedTest): + + @classmethod + def define_tables(cls, metadata): + global User, Address + Base = decl.declarative_base(metadata=metadata) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + if inline: + if stringbased: + user = relation("User", primaryjoin="User.id==Address.user_id", backref="addresses") + else: + user = relation(User, primaryjoin=User.id==user_id, backref="addresses") + + if not inline: + compile_mappers() + if stringbased: + Address.user = relation("User", primaryjoin="User.id==Address.user_id", backref="addresses") + else: + Address.user = relation(User, primaryjoin=User.id==Address.user_id, backref="addresses") + + @classmethod + def insert_data(cls): + params = [dict(zip(('id', 'name'), column_values)) for column_values in + [(7, 'jack'), + (8, 'ed'), + (9, 'fred'), + (10, 'chuck')] + ] + User.__table__.insert().execute(params) + + Address.__table__.insert().execute( + [dict(zip(('id', 'user_id', 'email'), column_values)) for column_values in + [(1, 7, "jack@bean.com"), + (2, 8, "ed@wood.com"), + (3, 8, "ed@bettyboop.com"), + (4, 8, "ed@lala.com"), + (5, 9, "fred@fred.com")] + ] + ) + + def test_aliased_join(self): + # this query will screw up if the aliasing + # enabled in query.join() gets applied to the right half of the join condition inside the any(). + # the join condition inside of any() comes from the "primaryjoin" of the relation, + # and should not be annotated with _orm_adapt. PropertyLoader.Comparator will annotate + # the left side with _orm_adapt, though. + sess = create_session() + eq_( + sess.query(User).join(User.addresses, aliased=True). + filter(Address.email=='ed@wood.com').filter(User.addresses.any(Address.email=='jack@bean.com')).all(), + [] + ) + + ExplicitJoinTest.__name__ = "ExplicitJoinTest%s%s" % (inline and 'Inline' or 'Separate', stringbased and 'String' or 'Literal') + return ExplicitJoinTest + +for inline in (True, False): + for stringbased in (True, False): + testclass = _produce_test(inline, stringbased) + exec("%s = testclass" % testclass.__name__) + del testclass + +class DeclarativeReflectionTest(testing.TestBase): + @classmethod + def setup_class(cls): + global reflection_metadata + reflection_metadata = MetaData(testing.db) + + Table('users', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + test_needs_fk=True) + Table('addresses', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('email', String(50)), + Column('user_id', Integer, ForeignKey('users.id')), + test_needs_fk=True) + Table('imhandles', reflection_metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer), + Column('network', String(50)), + Column('handle', String(50)), + test_needs_fk=True) + + reflection_metadata.create_all() + + def setup(self): + global Base + Base = decl.declarative_base(testing.db) + + def teardown(self): + for t in reversed(reflection_metadata.sorted_tables): + t.delete().execute() + + @classmethod + def teardown_class(cls): + reflection_metadata.drop_all() + + def test_basic(self): + meta = MetaData(testing.db) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + __autoload__ = True + addresses = relation("Address", backref="user") + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + __autoload__ = True + + u1 = User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), [User(name='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email == 'two').one() + eq_(a1, Address(email='two')) + eq_(a1.user, User(name='u1')) + + def test_rekey(self): + meta = MetaData(testing.db) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + __autoload__ = True + nom = Column('name', String(50), key='nom') + addresses = relation("Address", backref="user") + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + __autoload__ = True + + u1 = User(nom='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), [User(nom='u1', addresses=[ + Address(email='one'), + Address(email='two'), + ])]) + + a1 = sess.query(Address).filter(Address.email == 'two').one() + eq_(a1, Address(email='two')) + eq_(a1.user, User(nom='u1')) + + assert_raises(TypeError, User, name='u3') + + def test_supplied_fk(self): + meta = MetaData(testing.db) + + class IMHandle(Base, ComparableEntity): + __tablename__ = 'imhandles' + __autoload__ = True + + user_id = Column('user_id', Integer, + ForeignKey('users.id')) + class User(Base, ComparableEntity): + __tablename__ = 'users' + __autoload__ = True + handles = relation("IMHandle", backref="user") + + u1 = User(name='u1', handles=[ + IMHandle(network='blabber', handle='foo'), + IMHandle(network='lol', handle='zomg') + ]) + sess = create_session() + sess.add(u1) + sess.flush() + sess.expunge_all() + + eq_(sess.query(User).all(), [User(name='u1', handles=[ + IMHandle(network='blabber', handle='foo'), + IMHandle(network='lol', handle='zomg') + ])]) + + a1 = sess.query(IMHandle).filter(IMHandle.handle == 'zomg').one() + eq_(a1, IMHandle(network='lol', handle='zomg')) + eq_(a1.user, User(name='u1')) + + def test_synonym_for(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + @decl.synonym_for('name') + @property + def namesyn(self): + return self.name + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "someuser") + eq_(u1.namesyn, 'someuser') + sess.add(u1) + sess.flush() + + rt = sess.query(User).filter(User.namesyn == 'someuser').one() + eq_(rt, u1) + + def test_comparable_using(self): + class NameComparator(sa.orm.PropComparator): + @property + def upperself(self): + cls = self.prop.parent.class_ + col = getattr(cls, 'name') + return sa.func.upper(col) + + def operate(self, op, other, **kw): + return op(self.upperself, other, **kw) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True) + name = Column('name', String(50)) + + @decl.comparable_using(NameComparator) + @property + def uc_name(self): + return self.name is not None and self.name.upper() or None + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "someuser", u1.name) + eq_(u1.uc_name, 'SOMEUSER', u1.uc_name) + sess.add(u1) + sess.flush() + sess.expunge_all() + + rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one() + eq_(rt, u1) + sess.expunge_all() + + rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one() + eq_(rt, u1) + + +if __name__ == '__main__': + testing.main() diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py new file mode 100644 index 000000000..4adc77960 --- /dev/null +++ b/test/ext/test_orderinglist.py @@ -0,0 +1,400 @@ +from sqlalchemy import * +from sqlalchemy.orm import * +from sqlalchemy.ext.orderinglist import * +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test import * + + +metadata = None + +# order in whole steps +def step_numbering(step): + def f(index, collection): + return step * index + return f + +# almost fibonacci- skip the first 2 steps +# e.g. 1, 2, 3, 5, 8, ... instead of 0, 1, 1, 2, 3, ... +# otherwise ordering of the elements at '1' is undefined... ;) +def fibonacci_numbering(order_col): + def f(index, collection): + if index == 0: + return 1 + elif index == 1: + return 2 + else: + return (getattr(collection[index - 1], order_col) + + getattr(collection[index - 2], order_col)) + return f + +# 0 -> A, 1 -> B, ... 25 -> Z, 26 -> AA, 27 -> AB, ... +def alpha_ordering(index, collection): + s = '' + while index > 25: + d = index / 26 + s += chr((d % 26) + 64) + index -= d * 26 + s += chr(index + 65) + return s + +class OrderingListTest(TestBase): + def setup(self): + global metadata, slides_table, bullets_table, Slide, Bullet + slides_table, bullets_table = None, None + Slide, Bullet = None, None + if metadata: + metadata.clear() + + def _setup(self, test_collection_class): + """Build a relation situation using the given test_collection_class + factory""" + + global metadata, slides_table, bullets_table, Slide, Bullet + + metadata = MetaData(testing.db) + slides_table = Table('test_Slides', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(128))) + bullets_table = Table('test_Bullets', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('slide_id', Integer, + ForeignKey('test_Slides.id')), + Column('position', Integer), + Column('text', String(128))) + + class Slide(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return '' % self.name + + class Bullet(object): + def __init__(self, text): + self.text = text + def __repr__(self): + return '' % (self.text, self.position) + + mapper(Slide, slides_table, properties={ + 'bullets': relation(Bullet, lazy=False, + collection_class=test_collection_class, + backref='slide', + order_by=[bullets_table.c.position]) + }) + mapper(Bullet, bullets_table) + + metadata.create_all() + + def teardown(self): + metadata.drop_all() + + def test_append_no_reorder(self): + self._setup(ordering_list('position', count_from=1, + reorder_on_append=False)) + + s1 = Slide('Slide #1') + + self.assert_(not s1.bullets) + self.assert_(len(s1.bullets) == 0) + + s1.bullets.append(Bullet('s1/b1')) + + self.assert_(s1.bullets) + self.assert_(len(s1.bullets) == 1) + self.assert_(s1.bullets[0].position == 1) + + s1.bullets.append(Bullet('s1/b2')) + + self.assert_(len(s1.bullets) == 2) + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + + bul = Bullet('s1/b100') + bul.position = 100 + s1.bullets.append(bul) + + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + self.assert_(s1.bullets[2].position == 100) + + s1.bullets.append(Bullet('s1/b4')) + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + self.assert_(s1.bullets[2].position == 100) + self.assert_(s1.bullets[3].position == 4) + + s1.bullets._reorder() + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + self.assert_(s1.bullets[2].position == 3) + self.assert_(s1.bullets[3].position == 4) + + session = create_session() + session.add(s1) + session.flush() + + id = s1.id + session.expunge_all() + del s1 + + srt = session.query(Slide).get(id) + + self.assert_(srt.bullets) + self.assert_(len(srt.bullets) == 4) + + titles = ['s1/b1','s1/b2','s1/b100','s1/b4'] + found = [b.text for b in srt.bullets] + + self.assert_(titles == found) + + def test_append_reorder(self): + self._setup(ordering_list('position', count_from=1, + reorder_on_append=True)) + + s1 = Slide('Slide #1') + + self.assert_(not s1.bullets) + self.assert_(len(s1.bullets) == 0) + + s1.bullets.append(Bullet('s1/b1')) + + self.assert_(s1.bullets) + self.assert_(len(s1.bullets) == 1) + self.assert_(s1.bullets[0].position == 1) + + s1.bullets.append(Bullet('s1/b2')) + + self.assert_(len(s1.bullets) == 2) + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + + bul = Bullet('s1/b100') + bul.position = 100 + s1.bullets.append(bul) + + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + self.assert_(s1.bullets[2].position == 3) + + s1.bullets.append(Bullet('s1/b4')) + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + self.assert_(s1.bullets[2].position == 3) + self.assert_(s1.bullets[3].position == 4) + + s1.bullets._reorder() + self.assert_(s1.bullets[0].position == 1) + self.assert_(s1.bullets[1].position == 2) + self.assert_(s1.bullets[2].position == 3) + self.assert_(s1.bullets[3].position == 4) + + s1.bullets._raw_append(Bullet('raw')) + self.assert_(s1.bullets[4].position is None) + + s1.bullets._reorder() + self.assert_(s1.bullets[4].position == 5) + session = create_session() + session.add(s1) + session.flush() + + id = s1.id + session.expunge_all() + del s1 + + srt = session.query(Slide).get(id) + + self.assert_(srt.bullets) + self.assert_(len(srt.bullets) == 5) + + titles = ['s1/b1','s1/b2','s1/b100','s1/b4', 'raw'] + found = [b.text for b in srt.bullets] + eq_(titles, found) + + srt.bullets._raw_append(Bullet('raw2')) + srt.bullets[-1].position = 6 + session.flush() + session.expunge_all() + + srt = session.query(Slide).get(id) + titles = ['s1/b1','s1/b2','s1/b100','s1/b4', 'raw', 'raw2'] + found = [b.text for b in srt.bullets] + eq_(titles, found) + + def test_insert(self): + self._setup(ordering_list('position')) + + s1 = Slide('Slide #1') + s1.bullets.append(Bullet('1')) + s1.bullets.append(Bullet('2')) + s1.bullets.append(Bullet('3')) + s1.bullets.append(Bullet('4')) + + self.assert_(s1.bullets[0].position == 0) + self.assert_(s1.bullets[1].position == 1) + self.assert_(s1.bullets[2].position == 2) + self.assert_(s1.bullets[3].position == 3) + + s1.bullets.insert(2, Bullet('insert_at_2')) + self.assert_(s1.bullets[0].position == 0) + self.assert_(s1.bullets[1].position == 1) + self.assert_(s1.bullets[2].position == 2) + self.assert_(s1.bullets[3].position == 3) + self.assert_(s1.bullets[4].position == 4) + + self.assert_(s1.bullets[1].text == '2') + self.assert_(s1.bullets[2].text == 'insert_at_2') + self.assert_(s1.bullets[3].text == '3') + + s1.bullets.insert(999, Bullet('999')) + + self.assert_(len(s1.bullets) == 6) + self.assert_(s1.bullets[5].position == 5) + + session = create_session() + session.add(s1) + session.flush() + + id = s1.id + session.expunge_all() + del s1 + + srt = session.query(Slide).get(id) + + self.assert_(srt.bullets) + self.assert_(len(srt.bullets) == 6) + + texts = ['1','2','insert_at_2','3','4','999'] + found = [b.text for b in srt.bullets] + + self.assert_(texts == found) + + def test_slice(self): + self._setup(ordering_list('position')) + + b = [ Bullet('1'), Bullet('2'), Bullet('3'), + Bullet('4'), Bullet('5'), Bullet('6') ] + s1 = Slide('Slide #1') + + # 1, 2, 3 + s1.bullets[0:3] = b[0:3] + for i in 0, 1, 2: + self.assert_(s1.bullets[i].position == i) + self.assert_(s1.bullets[i] == b[i]) + + # 1, 4, 5, 6, 3 + s1.bullets[1:2] = b[3:6] + for li, bi in (0,0), (1,3), (2,4), (3,5), (4,2): + self.assert_(s1.bullets[li].position == li) + self.assert_(s1.bullets[li] == b[bi]) + + # 1, 6, 3 + del s1.bullets[1:3] + for li, bi in (0,0), (1,5), (2,2): + self.assert_(s1.bullets[li].position == li) + self.assert_(s1.bullets[li] == b[bi]) + + session = create_session() + session.add(s1) + session.flush() + + id = s1.id + session.expunge_all() + del s1 + + srt = session.query(Slide).get(id) + + self.assert_(srt.bullets) + self.assert_(len(srt.bullets) == 3) + + texts = ['1', '6', '3'] + for i, text in enumerate(texts): + self.assert_(srt.bullets[i].position == i) + self.assert_(srt.bullets[i].text == text) + + def test_replace(self): + self._setup(ordering_list('position')) + + s1 = Slide('Slide #1') + s1.bullets = [ Bullet('1'), Bullet('2'), Bullet('3') ] + + self.assert_(len(s1.bullets) == 3) + self.assert_(s1.bullets[2].position == 2) + + session = create_session() + session.add(s1) + session.flush() + + new_bullet = Bullet('new 2') + self.assert_(new_bullet.position is None) + + # mark existing bullet as db-deleted before replacement. + #session.delete(s1.bullets[1]) + s1.bullets[1] = new_bullet + + self.assert_(new_bullet.position == 1) + self.assert_(len(s1.bullets) == 3) + + id = s1.id + + session.flush() + session.expunge_all() + + srt = session.query(Slide).get(id) + + self.assert_(srt.bullets) + self.assert_(len(srt.bullets) == 3) + + self.assert_(srt.bullets[1].text == 'new 2') + self.assert_(srt.bullets[2].text == '3') + + def test_funky_ordering(self): + class Pos(object): + def __init__(self): + self.position = None + + step_factory = ordering_list('position', + ordering_func=step_numbering(2)) + + stepped = step_factory() + stepped.append(Pos()) + stepped.append(Pos()) + stepped.append(Pos()) + stepped.append(Pos()) + + for li, pos in (0,0), (1,2), (2,4), (3,6): + self.assert_(stepped[li].position == pos) + + fib_factory = ordering_list('position', + ordering_func=fibonacci_numbering('position')) + + fibbed = fib_factory() + fibbed.append(Pos()) + fibbed.append(Pos()) + fibbed.append(Pos()) + fibbed.append(Pos()) + fibbed.append(Pos()) + + for li, pos in (0,1), (1,2), (2,3), (3,5), (4,8): + self.assert_(fibbed[li].position == pos) + + fibbed.insert(2, Pos()) + fibbed.insert(4, Pos()) + fibbed.insert(6, Pos()) + + for li, pos in (0,1), (1,2), (2,3), (3,5), (4,8), (5,13), (6,21), (7,34): + self.assert_(fibbed[li].position == pos) + + alpha_factory = ordering_list('position', + ordering_func=alpha_ordering) + alpha = alpha_factory() + alpha.append(Pos()) + alpha.append(Pos()) + alpha.append(Pos()) + + alpha.insert(1, Pos()) + + for li, pos in (0,'A'), (1,'B'), (2,'C'), (3,'D'): + self.assert_(alpha[li].position == pos) + + diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py new file mode 100644 index 000000000..b8a8e3fef --- /dev/null +++ b/test/ext/test_serializer.py @@ -0,0 +1,144 @@ + +from sqlalchemy.ext import serializer +from sqlalchemy import exc +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import MetaData, Integer, String, ForeignKey, select, desc, func, util +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased +from sqlalchemy.test.testing import eq_ + +from test.orm._base import ComparableEntity, MappedTest + + +class User(ComparableEntity): + pass + +class Address(ComparableEntity): + pass + +class SerializeTest(MappedTest): + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + global users, addresses + users = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) + addresses = Table('addresses', metadata, + Column('id', Integer, primary_key=True), + Column('email', String(50)), + Column('user_id', Integer, ForeignKey('users.id')), + ) + + @classmethod + def setup_mappers(cls): + global Session + Session = scoped_session(sessionmaker()) + + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', order_by=addresses.c.id) + }) + mapper(Address, addresses) + + compile_mappers() + + @classmethod + def insert_data(cls): + params = [dict(zip(('id', 'name'), column_values)) for column_values in + [(7, 'jack'), + (8, 'ed'), + (9, 'fred'), + (10, 'chuck')] + ] + users.insert().execute(params) + + addresses.insert().execute( + [dict(zip(('id', 'user_id', 'email'), column_values)) for column_values in + [(1, 7, "jack@bean.com"), + (2, 8, "ed@wood.com"), + (3, 8, "ed@bettyboop.com"), + (4, 8, "ed@lala.com"), + (5, 9, "fred@fred.com")] + ] + ) + + def test_tables(self): + assert serializer.loads(serializer.dumps(users), users.metadata, Session) is users + + def test_columns(self): + assert serializer.loads(serializer.dumps(users.c.name), users.metadata, Session) is users.c.name + + def test_mapper(self): + user_mapper = class_mapper(User) + assert serializer.loads(serializer.dumps(user_mapper), None, None) is user_mapper + + def test_attribute(self): + assert serializer.loads(serializer.dumps(User.name), None, None) is User.name + + def test_expression(self): + + expr = select([users]).select_from(users.join(addresses)).limit(5) + re_expr = serializer.loads(serializer.dumps(expr), users.metadata, None) + eq_( + str(expr), + str(re_expr) + ) + + assert re_expr.bind is testing.db + eq_( + re_expr.execute().fetchall(), + [(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')] + ) + + # fails due to pure Python pickle bug: http://bugs.python.org/issue998998 + @testing.fails_if(lambda: util.py3k) + def test_query(self): + q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses)) + eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) + + q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + def go(): + eq_(q2.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) + self.assert_sql_count(testing.db, go, 1) + + eq_(q2.join(User.addresses).filter(Address.email=='ed@bettyboop.com').value(func.count('*')), 1) + + u1 = Session.query(User).get(8) + + q = Session.query(Address).filter(Address.user==u1).order_by(desc(Address.email)) + q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + + eq_(q2.all(), [Address(email='ed@wood.com'), Address(email='ed@lala.com'), Address(email='ed@bettyboop.com')]) + + q = Session.query(User).join(User.addresses).filter(Address.email.like('%fred%')) + q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + eq_(q2.all(), [User(name='fred')]) + + eq_(list(q2.values(User.id, User.name)), [(9, u'fred')]) + + @testing.exclude('sqlite', '<=', (3, 5, 9), 'id comparison failing on the buildbot') + def test_aliases(self): + u7, u8, u9, u10 = Session.query(User).order_by(User.id).all() + + ualias = aliased(User) + q = Session.query(User, ualias).join((ualias, User.id < ualias.id)).filter(User.id<9).order_by(User.id, ualias.id) + eq_(list(q.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)]) + + q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + + eq_(list(q2.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)]) + + def test_any(self): + r = User.addresses.any(Address.email=='x') + ser = serializer.dumps(r) + x = serializer.loads(ser, users.metadata) + eq_(str(r), str(x)) + +if __name__ == '__main__': + testing.main() -- cgit v1.2.1