summaryrefslogtreecommitdiff
path: root/test/orm
diff options
context:
space:
mode:
Diffstat (limited to 'test/orm')
-rw-r--r--test/orm/attributes.py242
-rw-r--r--test/orm/lazy_relations.py34
-rw-r--r--test/orm/mapper.py8
-rw-r--r--test/orm/unitofwork.py3
4 files changed, 208 insertions, 79 deletions
diff --git a/test/orm/attributes.py b/test/orm/attributes.py
index 4e41f0a29..b321dc50a 100644
--- a/test/orm/attributes.py
+++ b/test/orm/attributes.py
@@ -150,82 +150,9 @@ class AttributesTest(PersistTest):
self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
self.assert_(len(attributes.get_history(u, 'addresses').unchanged_items()) == 1)
- def test_backref(self):
- class Student(object):pass
- class Course(object):pass
-
- attributes.register_class(Student)
- attributes.register_class(Course)
- attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True)
- attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True)
-
- s = Student()
- c = Course()
- s.courses.append(c)
- self.assert_(c.students == [s])
- s.courses.remove(c)
- self.assert_(c.students == [])
-
- (s1, s2, s3) = (Student(), Student(), Student())
-
- c.students = [s1, s2, s3]
- self.assert_(s2.courses == [c])
- self.assert_(s1.courses == [c])
- print "--------------------------------"
- print s1
- print s1.courses
- print c
- print c.students
- s1.courses.remove(c)
- self.assert_(c.students == [s2,s3])
- class Post(object):pass
- class Blog(object):pass
-
- attributes.register_class(Post)
- attributes.register_class(Blog)
- attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
- attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
- b = Blog()
- (p1, p2, p3) = (Post(), Post(), Post())
- b.posts.append(p1)
- b.posts.append(p2)
- b.posts.append(p3)
- self.assert_(b.posts == [p1, p2, p3])
- self.assert_(p2.blog is b)
- p3.blog = None
- self.assert_(b.posts == [p1, p2])
- p4 = Post()
- p4.blog = b
- self.assert_(b.posts == [p1, p2, p4])
-
- p4.blog = b
- p4.blog = b
- self.assert_(b.posts == [p1, p2, p4])
-
- # assert no failure removing None
- p5 = Post()
- p5.blog = None
- del p5.blog
-
- class Port(object):pass
- class Jack(object):pass
- attributes.register_class(Port)
- attributes.register_class(Jack)
- attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True)
- attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True)
- p = Port()
- j = Jack()
- p.jack = j
- self.assert_(j.port is p)
- self.assert_(p.jack is not None)
-
- j.port = None
- self.assert_(p.jack is None)
-
def test_lazytrackparent(self):
"""test that the "hasparent" flag works properly when lazy loaders and backrefs are used"""
-
class Post(object):pass
class Blog(object):pass
@@ -449,6 +376,173 @@ class AttributesTest(PersistTest):
assert True
except exceptions.ArgumentError, e:
assert False
-
+
+
+class BackrefTest(PersistTest):
+
+ def test_manytomany(self):
+ class Student(object):pass
+ class Course(object):pass
+
+ attributes.register_class(Student)
+ attributes.register_class(Course)
+ attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True)
+ attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True)
+
+ s = Student()
+ c = Course()
+ s.courses.append(c)
+ self.assert_(c.students == [s])
+ s.courses.remove(c)
+ self.assert_(c.students == [])
+
+ (s1, s2, s3) = (Student(), Student(), Student())
+
+ c.students = [s1, s2, s3]
+ self.assert_(s2.courses == [c])
+ self.assert_(s1.courses == [c])
+ print "--------------------------------"
+ print s1
+ print s1.courses
+ print c
+ print c.students
+ s1.courses.remove(c)
+ self.assert_(c.students == [s2,s3])
+
+ def test_onetomany(self):
+ class Post(object):pass
+ class Blog(object):pass
+
+ attributes.register_class(Post)
+ attributes.register_class(Blog)
+ attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
+ attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
+ b = Blog()
+ (p1, p2, p3) = (Post(), Post(), Post())
+ b.posts.append(p1)
+ b.posts.append(p2)
+ b.posts.append(p3)
+ self.assert_(b.posts == [p1, p2, p3])
+ self.assert_(p2.blog is b)
+
+ p3.blog = None
+ self.assert_(b.posts == [p1, p2])
+ p4 = Post()
+ p4.blog = b
+ self.assert_(b.posts == [p1, p2, p4])
+
+ p4.blog = b
+ p4.blog = b
+ self.assert_(b.posts == [p1, p2, p4])
+
+ # assert no failure removing None
+ p5 = Post()
+ p5.blog = None
+ del p5.blog
+
+ def test_onetoone(self):
+ class Port(object):pass
+ class Jack(object):pass
+ attributes.register_class(Port)
+ attributes.register_class(Jack)
+ attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True)
+ attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True)
+ p = Port()
+ j = Jack()
+ p.jack = j
+ self.assert_(j.port is p)
+ self.assert_(p.jack is not None)
+
+ j.port = None
+ self.assert_(p.jack is None)
+
+class DeferredBackrefTest(PersistTest):
+ def setUp(self):
+ global Post, Blog, called, lazy_load
+
+ class Post(object):
+ def __init__(self, name):
+ self.name = name
+ def __eq__(self, other):
+ return other.name == self.name
+
+ class Blog(object):
+ def __init__(self, name):
+ self.name = name
+ def __eq__(self, other):
+ return other.name == self.name
+
+ called = [0]
+
+ lazy_load = []
+ def lazy_posts(instance):
+ def load():
+ called[0] += 1
+ return lazy_load
+ return load
+
+ attributes.register_class(Post)
+ attributes.register_class(Blog)
+ attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
+ attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True)
+
+ def test_lazy_add(self):
+ global lazy_load
+
+ p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3")
+ lazy_load = [p1, p2, p3]
+
+ b = Blog("blog 1")
+ p = Post("post 4")
+ p.blog = b
+ p = Post("post 5")
+ p.blog = b
+ # setting blog doesnt call 'posts' callable
+ assert called[0] == 0
+
+ # calling backref calls the callable, populates extra posts
+ assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")]
+ assert called[0] == 1
+
+ def test_lazy_remove(self):
+ global lazy_load
+ called[0] = 0
+ lazy_load = []
+
+ b = Blog("blog 1")
+ p = Post("post 1")
+ p.blog = b
+ assert called[0] == 0
+
+ lazy_load = [p]
+
+ p.blog = None
+ p2 = Post("post 2")
+ p2.blog = b
+ assert called[0] == 0
+ assert b.posts == [p2]
+ assert called[0] == 1
+
+ def test_normal_load(self):
+ global lazy_load
+ lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")]
+ called[0] = 0
+
+ b = Blog("blog 1")
+
+ # assign without using backref system
+ p2.__dict__['blog'] = b
+
+ assert b.posts == [Post("post 1"), Post("post 2"), Post("post 3")]
+ assert called[0] == 1
+ p2.blog = None
+ p4 = Post("post 4")
+ p4.blog = b
+ assert b.posts == [Post("post 1"), Post("post 3"), Post("post 4")]
+ assert called[0] == 1
+
+ called[0] = 0
+ lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")]
+
if __name__ == "__main__":
testbase.main()
diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py
index 97eda3006..487eb7716 100644
--- a/test/orm/lazy_relations.py
+++ b/test/orm/lazy_relations.py
@@ -272,7 +272,41 @@ class LazyTest(FixtureTest):
u1 = sess.query(User).get(7)
assert a.user is u1
+
+ def test_backrefs_dont_lazyload(self):
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref='user')
+ })
+ mapper(Address, addresses)
+ sess = create_session()
+ ad = sess.query(Address).filter_by(id=1).one()
+ assert ad.user.id == 7
+ def go():
+ ad.user = None
+ assert ad.user is None
+ self.assert_sql_count(testbase.db, go, 0)
+
+ u1 = sess.query(User).filter_by(id=7).one()
+ def go():
+ assert ad not in u1.addresses
+ self.assert_sql_count(testbase.db, go, 1)
+
+ sess.expire(u1, ['addresses'])
+ def go():
+ assert ad in u1.addresses
+ self.assert_sql_count(testbase.db, go, 1)
+ sess.expire(u1, ['addresses'])
+ ad2 = Address()
+ def go():
+ ad2.user = u1
+ assert ad2.user is u1
+ self.assert_sql_count(testbase.db, go, 0)
+
+ def go():
+ assert ad2 in u1.addresses
+ self.assert_sql_count(testbase.db, go, 1)
+
class M2OGetTest(FixtureTest):
keep_mappers = False
keep_data = True
diff --git a/test/orm/mapper.py b/test/orm/mapper.py
index 36f056156..df1b6bba1 100644
--- a/test/orm/mapper.py
+++ b/test/orm/mapper.py
@@ -1322,15 +1322,17 @@ class RequirementsTest(AssertMixin):
h1.h1s.append(H1())
s.flush()
-
+ self.assertEquals(t1.count().scalar(), 4)
+
h6 = H6()
h6.h1a = h1
h6.h1b = h1
h6 = H6()
h6.h1a = h1
- h6.h1b = H1()
-
+ h6.h1b = x = H1()
+ assert x in s
+
h6.h1b.h2s.append(H2())
s.flush()
diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py
index 158813cd7..11d731377 100644
--- a/test/orm/unitofwork.py
+++ b/test/orm/unitofwork.py
@@ -33,8 +33,7 @@ class HistoryTest(ORMTest):
u = User(_sa_session=s)
a = Address(_sa_session=s)
a.user = u
- #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items())
- #print repr(u.addresses.added_items())
+
self.assert_(u.addresses == [a])
s.commit()