summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-06-26 19:55:48 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-06-26 19:55:48 +0000
commit49090a7a9434245b03a7f867add2401d3a78fead (patch)
treec22b55212f6a0b843be63d35cff3b4fbb5f39c5a
parent22fcdbe81f528a584eb499b9dbf2270365926b71 (diff)
downloadsqlalchemy-49090a7a9434245b03a7f867add2401d3a78fead.tar.gz
fixed attribute manager's ability to traverse the full set of managed attributes for a descendant class, + 2 unit tests
-rw-r--r--lib/sqlalchemy/attributes.py5
-rw-r--r--test/base/attributes.py16
-rw-r--r--test/orm/inheritance.py44
3 files changed, 61 insertions, 4 deletions
diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py
index b7ad5249b..2bf336398 100644
--- a/lib/sqlalchemy/attributes.py
+++ b/lib/sqlalchemy/attributes.py
@@ -519,7 +519,7 @@ class AttributeHistory(object):
else:
self._deleted_items = []
self._unchanged_items = []
- #print "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
+ #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
def __iter__(self):
return iter(self._current)
def added_items(self):
@@ -566,7 +566,8 @@ class AttributeManager(object):
"""returns an iterator of all InstrumentedAttribute objects associated with the given class."""
if not isinstance(class_, type):
raise repr(class_) + " is not a type"
- for value in class_.__dict__.values():
+ for key in dir(class_):
+ value = getattr(class_, key)
if isinstance(value, InstrumentedAttribute):
yield value
diff --git a/test/base/attributes.py b/test/base/attributes.py
index 4b8bfd39a..19eedd0f6 100644
--- a/test/base/attributes.py
+++ b/test/base/attributes.py
@@ -183,6 +183,22 @@ class AttributesTest(PersistTest):
assert x.element2 == 'this is the shared attr'
assert y.element2 == 'this is the shared attr'
+ def testinheritance2(self):
+ """test that the attribute manager can properly traverse the managed attributes of an object,
+ if the object is of a descendant class with managed attributes in the parent class"""
+ class Foo(object):pass
+ class Bar(Foo):pass
+ manager = attributes.AttributeManager()
+ manager.register_attribute(Foo, 'element', uselist=False)
+ x = Bar()
+ x.element = 'this is the element'
+ hist = manager.get_history(x, 'element')
+ assert hist.added_items() == ['this is the element']
+ manager.commit(x)
+ hist = manager.get_history(x, 'element')
+ assert hist.added_items() == []
+ assert hist.unchanged_items() == ['this is the element']
+
def testlazyhistory(self):
"""tests that history functions work with lazy-loading attributes"""
class Foo(object):pass
diff --git a/test/orm/inheritance.py b/test/orm/inheritance.py
index 842a63a26..bca0ffde3 100644
--- a/test/orm/inheritance.py
+++ b/test/orm/inheritance.py
@@ -442,8 +442,11 @@ class InheritTest7(testbase.AssertMixin):
metadata.create_all()
def tearDownAll(self):
metadata.drop_all()
-
- def testbasic(self):
+ def tearDown(self):
+ for t in metadata.table_iterator(reverse=True):
+ t.delete().execute()
+
+ def testone(self):
class User(object):pass
class Role(object):pass
class Admin(User):pass
@@ -469,6 +472,43 @@ class InheritTest7(testbase.AssertMixin):
sess.flush()
assert user_roles.count().scalar() == 1
+
+ def testtwo(self):
+ class User(object):
+ def __init__(self, email=None, password=None):
+ self.email = email
+ self.password = password
+
+ class Role(object):
+ def __init__(self, description=None):
+ self.description = description
+
+ class Admin(User):pass
+
+ role_mapper = mapper(Role, roles)
+ user_mapper = mapper(User, users, properties = {
+ 'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+ }
+ )
+
+ admin_mapper = mapper(Admin, admins, inherits=user_mapper)
+
+ # create roles
+ adminrole = Role('admin')
+
+ sess = create_session()
+ sess.save(adminrole)
+ sess.flush()
+
+ # create admin user
+ a = Admin(email='tim', password='admin')
+ a.roles.append(adminrole)
+ sess.save(a)
+ sess.flush()
+
+ a.password = 'sadmin'
+ sess.flush()
+ assert user_roles.count().scalar() == 1
if __name__ == "__main__":
testbase.main()