# Tests for user-defined class instrumentation: classes that opt out of (or
# customize) SQLAlchemy's default attribute instrumentation through the
# __sa_instrumentation_manager__ hook.
#
# NOTE(review): this module uses Python 2 syntax (print statements below);
# keep it Python 2 compatible unless the whole suite is ported.

# Kept as the very first statement, presumably so test configuration happens
# before the sqlalchemy imports below -- confirm before reordering.
import testenv; testenv.configure_for_tests()
# NOTE(review): `pickle` and `collection` are not referenced anywhere in this
# file as shown; verify before removing.
import pickle
from sqlalchemy import util
import sqlalchemy.orm.attributes as attributes
from sqlalchemy.orm.collections import collection
from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import InstrumentationManager
# Presumably supplies `testing`, used in the __main__ guard at the bottom.
from testlib import *
from orm import _base


class MyTypesManager(InstrumentationManager):
    """Instrumentation manager that relocates where SA keeps its data.

    Attribute values live in ``instance._goofy_dict`` and the SA state object
    in ``instance.__dict__['_my_state']``.  The descriptor hooks are no-ops,
    so this manager installs no class-level descriptors.  Instrumented access
    on MyClass is instead routed through its __getattr__/__setattr__/
    __delattr__.
    """

    def instrument_attribute(self, class_, key, attr):
        """No-op: leave the class attribute untouched."""
        pass

    def install_descriptor(self, class_, key, attr):
        """No-op: do not place a descriptor on the class."""
        pass

    def uninstall_descriptor(self, class_, key):
        """No-op counterpart of install_descriptor."""
        pass

    def instrument_collection_class(self, class_, key, collection_class):
        """Always use MyListLike, ignoring the requested collection_class."""
        return MyListLike

    def get_instance_dict(self, class_, instance):
        """Return the dict SA should read and write attribute values in."""
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        """Create the per-instance value dict.

        Writes through ``__dict__`` directly so MyClass.__setattr__ (which
        would itself need ``_goofy_dict``) is bypassed.
        """
        instance.__dict__['_goofy_dict'] = {}

    def install_state(self, class_, instance, state):
        """Store the SA state object on the instance under '_my_state'."""
        instance.__dict__['_my_state'] = state

    def state_getter(self, class_):
        """Return a callable mapping an instance to its stored SA state."""
        return lambda instance: instance.__dict__['_my_state']


class MyListLike(list):
    """Plain list subclass used as the collection type by MyTypesManager.

    It defines ``_sa_iterator``/``_sa_appender``/``_sa_remover`` directly
    instead of using @collection decorators.  Presumably these are the names
    SA's collection machinery looks up -- TODO confirm.  ``self._sa_adapter``
    is never assigned here; presumably SA attaches it.
    """
    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    def _sa_appender(self, item, _sa_initiator=None):
        # _sa_initiator=False suppresses the append event (the list is still
        # appended to).  Any return value of fire_append_event is not used.
        if _sa_initiator is not False:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)
    append = _sa_appender
    def _sa_remover(self, item, _sa_initiator=None):
        # The pre-remove event fires unconditionally, even when
        # _sa_initiator is False.  Only the remove event itself is suppressed.
        # NOTE(review): both events fire before list.remove(), which raises
        # ValueError if `item` is absent, after the events have already
        # been sent.
        self._sa_adapter.fire_pre_remove_event(_sa_initiator)
        if _sa_initiator is not False:
            self._sa_adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)
    remove = _sa_remover


class MyBaseClass(object):
    """Base class that opts into the stock InstrumentationManager via the hook."""
    __sa_instrumentation_manager__ = InstrumentationManager


class MyClass(object):
    """Base class instrumented by MyTypesManager.

    Because that manager installs no descriptors, this class routes
    instrumented attribute access through the attributes API
    (get_attribute/set_attribute/del_attribute).  It keeps every
    non-instrumented attribute in ``_goofy_dict`` rather than ``__dict__``.
    """

    # This proves that a staticmethod will work here; don't
    # flatten this back to a class assignment!
    def __sa_instrumentation_manager__(cls):
        return MyTypesManager(cls)

    __sa_instrumentation_manager__ = staticmethod(__sa_instrumentation_manager__)

    # This proves SA can handle a class with non-string dict keys
    # (the bare except keeps class creation working if the class namespace
    # rejects the write).
    try:
        locals()[42] = 99   # Don't remove this line!
    except:
        pass

    def __init__(self, **kwargs):
        """Assign each keyword argument via setattr, i.e. through __setattr__."""
        for k in kwargs:
            setattr(self, k, kwargs[k])

    def __getattr__(self, key):
        """Fallback lookup: SA-instrumented keys first, then ``_goofy_dict``.

        Raises AttributeError for keys found in neither.
        """
        # Python only calls this after normal lookup fails.  `_goofy_dict`
        # itself lives in __dict__ (see MyTypesManager.initialize_instance_dict),
        # so the lookup below normally resolves without re-entering here.
        # NOTE(review): if `_goofy_dict` were ever missing from __dict__,
        # `self._goofy_dict` would recurse into __getattr__ indefinitely.
        if is_instrumented(self, key):
            return get_attribute(self, key)
        else:
            try:
                return self._goofy_dict[key]
            except KeyError:
                raise AttributeError(key)

    def __setattr__(self, key, value):
        """Write instrumented keys through SA, everything else into ``_goofy_dict``."""
        if is_instrumented(self, key):
            set_attribute(self, key, value)
        else:
            self._goofy_dict[key] = value

    def __hasattr__(self, key):
        """Report whether `key` is instrumented or present in ``_goofy_dict``."""
        # NOTE(review): __hasattr__ is not a special method Python invokes.
        # The builtin hasattr() goes through __getattr__ instead, so this is
        # only reached by explicit calls.
        if is_instrumented(self, key):
            return True
        else:
            return key in self._goofy_dict

    def __delattr__(self, key):
        """Delete through SA for instrumented keys, else from ``_goofy_dict``."""
        if is_instrumented(self, key):
            del_attribute(self, key)
        else:
            del self._goofy_dict[key]


class UserDefinedExtensionTest(_base.ORMTest):
    """Exercise attribute instrumentation across default and custom managers.

    Most tests loop over three bases: plain ``object``, MyBaseClass (stock
    manager) and MyClass (MyTypesManager).
    """

    def tearDownAll(self):
        """Clear mappers and reinstall the 'native' lookup strategy."""
        # Presumably undoes a non-native lookup strategy switched on by
        # registering classes with custom managers -- TODO confirm.
        clear_mappers()
        attributes._install_lookup_strategy(util.symbol('native'))

    def test_instance_dict(self):
        """Values for a MyClass subclass land in _goofy_dict, not __dict__."""
        class User(MyClass):
            pass

        attributes.register_class(User)
        attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
        attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
        attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)

        u = User()
        u.user_id = 7
        u.user_name = 'john'
        u.email_address = 'lala@123.com'

        # __dict__ holds only the two bookkeeping entries installed by
        # MyTypesManager; the attribute values sit inside _goofy_dict.
        self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}})

    def test_basic(self):
        """Scalar set/get survives commit_all and later modification, per base."""
        for base in (object, MyBaseClass, MyClass):
            class User(base):
                pass

            attributes.register_class(User)
            attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
            attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
            attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)

            u = User()
            u.user_id = 7
            u.user_name = 'john'
            u.email_address = 'lala@123.com'

            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
            attributes.instance_state(u).commit_all()
            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')

            u.user_name = 'heythere'
            u.email_address = 'foo@bar.com'
            self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')

    def test_deferred(self):
        """Expired attributes reload through the manager's deferred_scalar_loader."""
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):pass

            data = {'a':'this is a', 'b':12}
            # Loader used on expired attributes: copies the requested keys
            # from `data` into the state's dict.
            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            attributes.register_class(Foo)
            manager = attributes.manager_of_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
            attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)

            assert Foo in attributes.instrumentation_registry.state_finders
            f = Foo()
            # Expire everything; the next reads go through `loader`.
            attributes.instance_state(f).expire_attributes(None)
            self.assertEquals(f.a, "this is a")
            self.assertEquals(f.b, 12)

            # A pending change is discarded by expiry and reloaded.
            f.a = "this is some new a"
            attributes.instance_state(f).expire_attributes(None)
            self.assertEquals(f.a, "this is a")
            self.assertEquals(f.b, 12)

            # A value set after expiry wins over the loader.
            attributes.instance_state(f).expire_attributes(None)
            f.a = "this is another new a"
            self.assertEquals(f.a, "this is another new a")
            self.assertEquals(f.b, 12)

            attributes.instance_state(f).expire_attributes(None)
            self.assertEquals(f.a, "this is a")
            self.assertEquals(f.b, 12)

            # Deleting a scalar makes it read back as None, also after commit.
            del f.a
            self.assertEquals(f.a, None)
            self.assertEquals(f.b, 12)

            attributes.instance_state(f).commit_all()
            self.assertEquals(f.a, None)
            self.assertEquals(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic"""
        # Bar re-registers 'element' with its own callable and must see
        # func2's value; 'element2' is registered only on Foo and must be
        # inherited by Bar.

        for base in (object, MyBaseClass, MyClass):
            class Foo(base):pass
            class Bar(Foo):pass

            attributes.register_class(Foo)
            attributes.register_class(Bar)

            def func1():
                print "func1"
                return "this is the foo attr"
            def func2():
                print "func2"
                return "this is the bar attr"
            def func3():
                print "func3"
                return "this is the shared attr"
            # Each callable_ returns the function object itself; the asserts
            # below expect the attribute to equal that function's return value.
            attributes.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1, useobject=True)
            attributes.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3, useobject=True)
            attributes.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2, useobject=True)

            x = Foo()
            y = Bar()
            assert x.element == 'this is the foo attr'
            assert y.element == 'this is the bar attr', y.element
            assert x.element2 == 'this is the shared attr'
            assert y.element2 == 'this is the shared attr'

    def test_collection_with_backref(self):
        """Blog.posts and Post.blog stay in sync via GenericBackrefExtension."""
        for base in (object, MyBaseClass, MyClass):
            class Post(base):pass
            class Blog(base):pass

            attributes.register_class(Post)
            attributes.register_class(Blog)
            attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
            attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            self.assert_(b.posts == [p1, p2, p3])
            self.assert_(p2.blog is b)

            # Clearing the scalar side removes the post from the collection.
            p3.blog = None
            self.assert_(b.posts == [p1, p2])
            p4 = Post()
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # Re-assigning the same parent must not duplicate the entry.
            p4.blog = b
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        """get_history tracks scalar and collection changes across commits."""
        # Judging by the expected values below, each history is an
        # (added, unchanged, deleted) triple.
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass
            class Bar(base):
                pass

            attributes.register_class(Foo)
            attributes.register_class(Bar)
            attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
            attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
            attributes.register_attribute(Bar, "name", uselist=False, useobject=False)


            f1 = Foo()
            f1.name = 'f1'

            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ()))

            b1 = Bar()
            b1.name = 'b1'
            f1.bars.append(b1)
            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))

            # After commit, current values move to the "unchanged" slot.
            attributes.instance_state(f1).commit_all()
            attributes.instance_state(b1).commit_all()

            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))

            f1.name = 'f1mod'
            b2 = Bar()
            b2.name = 'b2'
            f1.bars.append(b2)
            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1']))
            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
            f1.bars.remove(b1)
            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))

    def test_null_instrumentation(self):
        """With the stock manager, class attributes equal manager.get_inst(key)."""
        class Foo(MyBaseClass):
            pass
        attributes.register_class(Foo)
        attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
        attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)

        assert Foo.name == attributes.manager_of_class(Foo).get_inst('name')
        assert Foo.bars == attributes.manager_of_class(Foo).get_inst('bars')

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object): pass
        class Known(MyBaseClass): pass

        attributes.register_class(Known)
        k, u = Known(), Unknown()

        # Unregistered classes and None have no manager.
        assert attributes.manager_of_class(Unknown) is None
        assert attributes.manager_of_class(Known) is not None
        assert attributes.manager_of_class(None) is None

        # instance_state raises (rather than returning None) for instances
        # without state.
        assert attributes.instance_state(k) is not None
        self.assertRaises((AttributeError, KeyError), attributes.instance_state, u)
        self.assertRaises((AttributeError, KeyError), attributes.instance_state, None)


if __name__ == '__main__':
    testing.main()