diff options
Diffstat (limited to 'test/orm')
| -rw-r--r-- | test/orm/attributes.py | 242 | ||||
| -rw-r--r-- | test/orm/lazy_relations.py | 34 | ||||
| -rw-r--r-- | test/orm/mapper.py | 8 | ||||
| -rw-r--r-- | test/orm/unitofwork.py | 3 |
4 files changed, 208 insertions, 79 deletions
diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 4e41f0a29..b321dc50a 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -150,82 +150,9 @@ class AttributesTest(PersistTest): self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') self.assert_(len(attributes.get_history(u, 'addresses').unchanged_items()) == 1) - def test_backref(self): - class Student(object):pass - class Course(object):pass - - attributes.register_class(Student) - attributes.register_class(Course) - attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True) - attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True) - - s = Student() - c = Course() - s.courses.append(c) - self.assert_(c.students == [s]) - s.courses.remove(c) - self.assert_(c.students == []) - - (s1, s2, s3) = (Student(), Student(), Student()) - - c.students = [s1, s2, s3] - self.assert_(s2.courses == [c]) - self.assert_(s1.courses == [c]) - print "--------------------------------" - print s1 - print s1.courses - print c - print c.students - s1.courses.remove(c) - self.assert_(c.students == [s2,s3]) - class Post(object):pass - class Blog(object):pass - - attributes.register_class(Post) - attributes.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) - b = Blog() - (p1, p2, p3) = (Post(), Post(), Post()) - b.posts.append(p1) - b.posts.append(p2) - b.posts.append(p3) - self.assert_(b.posts == [p1, p2, p3]) - self.assert_(p2.blog is b) - p3.blog = None - self.assert_(b.posts == [p1, p2]) - p4 = Post() - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - p4.blog = b - p4.blog = b - self.assert_(b.posts == [p1, p2, p4]) - - # assert no failure removing None - p5 = Post() - p5.blog = None - del p5.blog - - class Port(object):pass - class Jack(object):pass - attributes.register_class(Port) - attributes.register_class(Jack) - attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True) - attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True) - p = Port() - j = Jack() - p.jack = j - self.assert_(j.port is p) - self.assert_(p.jack is not None) - - j.port = None - self.assert_(p.jack is None) - def test_lazytrackparent(self): """test that the "hasparent" flag works properly when lazy loaders and backrefs are used""" - class Post(object):pass class Blog(object):pass @@ -449,6 +376,173 @@ class AttributesTest(PersistTest): assert True except exceptions.ArgumentError, e: assert False - + + +class BackrefTest(PersistTest): + + def test_manytomany(self): + class Student(object):pass + class Course(object):pass + + attributes.register_class(Student) + attributes.register_class(Course) + attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True) + attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True) + + s = Student() + c = Course() + s.courses.append(c) + self.assert_(c.students == [s]) + s.courses.remove(c) + self.assert_(c.students == []) + + (s1, s2, s3) = (Student(), Student(), Student()) + + c.students = [s1, s2, s3] + self.assert_(s2.courses == [c]) + self.assert_(s1.courses == [c]) + print "--------------------------------" + print s1 + print s1.courses + print c + print c.students + s1.courses.remove(c) + self.assert_(c.students == [s2,s3]) + + def test_onetomany(self): + class Post(object):pass + class Blog(object):pass + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) + b = Blog() + (p1, p2, p3) = (Post(), Post(), Post()) + b.posts.append(p1) + b.posts.append(p2) + b.posts.append(p3) + self.assert_(b.posts == [p1, p2, p3]) + self.assert_(p2.blog is b) + + p3.blog = None + self.assert_(b.posts == [p1, p2]) + p4 = Post() + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + p4.blog = b + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + # assert no failure removing None + p5 = Post() + p5.blog = None + del p5.blog + + def test_onetoone(self): + class Port(object):pass + class Jack(object):pass + attributes.register_class(Port) + attributes.register_class(Jack) + attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True) + attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True) + p = Port() + j = Jack() + p.jack = j + self.assert_(j.port is p) + self.assert_(p.jack is not None) + + j.port = None + self.assert_(p.jack is None) + +class DeferredBackrefTest(PersistTest): + def setUp(self): + global Post, Blog, called, lazy_load + + class Post(object): + def __init__(self, name): + self.name = name + def __eq__(self, other): + return other.name == self.name + + class Blog(object): + def __init__(self, name): + self.name = name + def __eq__(self, other): + return other.name == self.name + + called = [0] + + lazy_load = [] + def lazy_posts(instance): + def load(): + called[0] += 1 + return lazy_load + return load + + attributes.register_class(Post) + attributes.register_class(Blog) + attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True) + + def test_lazy_add(self): + global lazy_load + + p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3") + lazy_load = [p1, p2, p3] + + b = Blog("blog 1") + p = Post("post 4") + p.blog = b + p = Post("post 5") + p.blog = b + # setting blog doesnt call 'posts' callable + assert called[0] == 0 + + # calling backref calls the callable, populates extra posts + assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")] + assert called[0] == 1 + + def test_lazy_remove(self): + global lazy_load + called[0] = 0 + lazy_load = [] + + b = Blog("blog 1") + p = Post("post 1") + p.blog = b + assert called[0] == 0 + + lazy_load = [p] + + p.blog = None + p2 = Post("post 2") + p2.blog = b + assert called[0] == 0 + assert b.posts == [p2] + assert called[0] == 1 + + def test_normal_load(self): + global lazy_load + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + called[0] = 0 + + b = Blog("blog 1") + + # assign without using backref system + p2.__dict__['blog'] = b + + assert b.posts == [Post("post 1"), Post("post 2"), Post("post 3")] + assert called[0] == 1 + p2.blog = None + p4 = Post("post 4") + p4.blog = b + assert b.posts == [Post("post 1"), Post("post 3"), Post("post 4")] + assert called[0] == 1 + + called[0] = 0 + lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")] + if __name__ == "__main__": testbase.main() diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py index 97eda3006..487eb7716 100644 --- a/test/orm/lazy_relations.py +++ b/test/orm/lazy_relations.py @@ -272,7 +272,41 @@ class LazyTest(FixtureTest): u1 = sess.query(User).get(7) assert a.user is u1 + + def test_backrefs_dont_lazyload(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user') + }) + mapper(Address, addresses) + sess = create_session() + ad = sess.query(Address).filter_by(id=1).one() + assert ad.user.id == 7 + def go(): + ad.user = None + assert ad.user is None + self.assert_sql_count(testbase.db, go, 0) + + u1 = sess.query(User).filter_by(id=7).one() + def go(): + assert ad not in u1.addresses + self.assert_sql_count(testbase.db, go, 1) + + sess.expire(u1, ['addresses']) + def go(): + assert ad in u1.addresses + self.assert_sql_count(testbase.db, go, 1) + sess.expire(u1, ['addresses']) + ad2 = Address() + def go(): + ad2.user = u1 + assert ad2.user is u1 + self.assert_sql_count(testbase.db, go, 0) + + def go(): + assert ad2 in u1.addresses + self.assert_sql_count(testbase.db, go, 1) + class M2OGetTest(FixtureTest): keep_mappers = False keep_data = True diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 36f056156..df1b6bba1 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1322,15 +1322,17 @@ class RequirementsTest(AssertMixin): h1.h1s.append(H1()) s.flush() - + self.assertEquals(t1.count().scalar(), 4) + h6 = H6() h6.h1a = h1 h6.h1b = h1 h6 = H6() h6.h1a = h1 - h6.h1b = H1() - + h6.h1b = x = H1() + assert x in s + h6.h1b.h2s.append(H2()) s.flush() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 158813cd7..11d731377 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -33,8 +33,7 @@ class HistoryTest(ORMTest): u = User(_sa_session=s) a = Address(_sa_session=s) a.user = u - #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items()) - #print repr(u.addresses.added_items()) + self.assert_(u.addresses == [a]) s.commit() |
