diff options
| -rw-r--r-- | CHANGES | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 57 | ||||
| -rw-r--r-- | test/orm/mapper.py | 114 |
3 files changed, 124 insertions, 65 deletions
@@ -82,6 +82,24 @@ CHANGES issued directly by the ORM in the form of UPDATE statements, by setting the flag "passive_cascades=False". + - inheriting mappers now inherit the MapperExtensions of their parent + mapper directly, so that all methods for a particular MapperExtension + are called for subclasses as well. As always, any MapperExtension + can return either EXT_CONTINUE to continue extension processing + or EXT_STOP to stop processing. The order of mapper resolution is: + <extensions declared on the classes mapper> <extensions declared on the + classes' parent mapper> <globally declared extensions>. + + Note that if you instantiate the same extension class separately + and then apply it individually for two mappers in the same inheritance + chain, the extension will be applied twice to the inheriting class, + and each method will be called twice. + + To apply a mapper extension explicitly to each inheriting class but + have each method called only once per operation, use the same + instance of the extension for both mappers. + [ticket:490] + - new synonym() behavior: an attribute will be placed on the mapped class, if one does not exist already, in all cases. if a property already exists on the class, the synonym will decorate the property diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 2c3d3ea40..8c375ea39 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -153,8 +153,8 @@ class Mapper(object): self.__should_log_debug = logging.is_debug_enabled(self.logger) self._compile_class() - self._compile_extensions() self._compile_inheritance() + self._compile_extensions() self._compile_tables() self._compile_properties() self._compile_pks() @@ -281,12 +281,19 @@ class Mapper(object): for ext_obj in util.to_list(extension): # local MapperExtensions have already instrumented the class extlist.add(ext_obj) - - for ext in global_extensions: - if isinstance(ext, type): - ext = ext() - extlist.add(ext) - ext.instrument_class(self, self.class_) + + if self.inherits is not None: + for ext in self.inherits.extension: + if ext not in extlist: + extlist.add(ext) + ext.instrument_class(self, self.class_) + else: + for ext in global_extensions: + if isinstance(ext, type): + ext = ext() + if ext not in extlist: + extlist.add(ext) + ext.instrument_class(self, self.class_) self.extension = ExtensionCarrier() for ext in extlist: @@ -960,14 +967,13 @@ class Mapper(object): if not postupdate: # call before_XXX extensions for state, connection, has_identity in tups: + mapper = _state_mapper(state) if not has_identity: - for mapper in _state_mapper(state).iterate_to_root(): - if 'before_insert' in mapper.extension.methods: - mapper.extension.before_insert(mapper, connection, state.obj()) + if 'before_insert' in mapper.extension.methods: + mapper.extension.before_insert(mapper, connection, state.obj()) else: - for mapper in _state_mapper(state).iterate_to_root(): - if 'before_update' in mapper.extension.methods: - mapper.extension.before_update(mapper, connection, state.obj()) + if 'before_update' in mapper.extension.methods: + mapper.extension.before_update(mapper, connection, state.obj()) for state, connection, has_identity in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), @@ -1131,13 +1137,14 @@ class Mapper(object): if not postupdate: # call after_XXX extensions for state, connection in inserted_objects: - for mapper in _state_mapper(state).iterate_to_root(): - if 'after_insert' in mapper.extension.methods: - mapper.extension.after_insert(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'after_insert' in mapper.extension.methods: + mapper.extension.after_insert(mapper, connection, state.obj()) + for state, connection in updated_objects: - for mapper in _state_mapper(state).iterate_to_root(): - if 'after_update' in mapper.extension.methods: - mapper.extension.after_update(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'after_update' in mapper.extension.methods: + mapper.extension.after_update(mapper, connection, state.obj()) def _postfetch(self, connection, table, state, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated @@ -1177,9 +1184,9 @@ class Mapper(object): tups = [(state, connection) for state in states] for (state, connection) in tups: - for mapper in _state_mapper(state).iterate_to_root(): - if 'before_delete' in mapper.extension.methods: - mapper.extension.before_delete(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'before_delete' in mapper.extension.methods: + mapper.extension.before_delete(mapper, connection, state.obj()) deleted_objects = util.Set() table_to_mapper = {} @@ -1225,9 +1232,9 @@ class Mapper(object): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) for state, connection in deleted_objects: - for mapper in _state_mapper(state).iterate_to_root(): - if 'after_delete' in mapper.extension.methods: - mapper.extension.after_delete(mapper, connection, state.obj()) + mapper = _state_mapper(state) + if 'after_delete' in mapper.extension.methods: + mapper.extension.after_delete(mapper, connection, state.obj()) def _register_dependencies(self, uowcommit): """Register ``DependencyProcessor`` instances with a diff --git a/test/orm/mapper.py b/test/orm/mapper.py index a1ca19a64..662ac4a29 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1131,83 +1131,73 @@ class NoLoadTest(MapperSuperTest): {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, ) -class MapperExtensionTest(MapperSuperTest): +class MapperExtensionTest(PersistTest): def setUpAll(self): tables.create() - def tearDownAll(self): - tables.drop() - def tearDown(self): - clear_mappers() - tables.delete() - def setUp(self): - tables.data() - - def test_create_instance(self): - class Ext(MapperExtension): - def create_instance(self, *args, **kwargs): - return User() - m = mapper(Address, addresses) - m = mapper(User, users, extension=Ext(), properties = dict( - addresses = relation(Address, lazy=True), - )) - - q = create_session().query(m) - l = q.select(); - self.assert_result(l, User, *user_address_result) - - def test_methods(self): - """test that common user-defined methods get called.""" - - methods = set() + + global methods, Ext + + methods = [] + class Ext(MapperExtension): def load(self, query, *args, **kwargs): - methods.add('load') + methods.append('load') return EXT_CONTINUE def get(self, query, *args, **kwargs): - methods.add('get') + methods.append('get') return EXT_CONTINUE def translate_row(self, mapper, context, row): - methods.add('translate_row') + methods.append('translate_row') return EXT_CONTINUE def create_instance(self, mapper, selectcontext, row, class_): - methods.add('create_instance') + methods.append('create_instance') return EXT_CONTINUE def append_result(self, mapper, selectcontext, row, instance, result, **flags): - methods.add('append_result') + methods.append('append_result') return EXT_CONTINUE def populate_instance(self, mapper, selectcontext, row, instance, **flags): - methods.add('populate_instance') + methods.append('populate_instance') return EXT_CONTINUE def before_insert(self, mapper, connection, instance): - methods.add('before_insert') + methods.append('before_insert') return EXT_CONTINUE def after_insert(self, mapper, connection, instance): - methods.add('after_insert') + methods.append('after_insert') return EXT_CONTINUE def before_update(self, mapper, connection, instance): - methods.add('before_update') + methods.append('before_update') return EXT_CONTINUE def after_update(self, mapper, connection, instance): - methods.add('after_update') + methods.append('after_update') return EXT_CONTINUE def before_delete(self, mapper, connection, instance): - methods.add('before_delete') + methods.append('before_delete') return EXT_CONTINUE def after_delete(self, mapper, connection, instance): - methods.add('after_delete') + methods.append('after_delete') return EXT_CONTINUE + def tearDown(self): + clear_mappers() + methods[:] = [] + tables.delete() + + def tearDownAll(self): + tables.drop() + + def test_basic(self): + """test that common user-defined methods get called.""" mapper(User, users, extension=Ext()) sess = create_session() u = User() @@ -1220,10 +1210,54 @@ class MapperExtensionTest(MapperSuperTest): sess.flush() sess.delete(u) sess.flush() - self.assertEquals(methods, set(['load', 'before_delete', 'create_instance', 'translate_row', 'get', - 'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance'])) + self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', + 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete']) + def test_inheritance(self): + # test using inheritance + class AdminUser(User): + pass + + mapper(User, users, extension=Ext()) + mapper(AdminUser, addresses, inherits=User) + + sess = create_session() + am = AdminUser() + sess.save(am) + sess.flush() + am = sess.query(AdminUser).load(am.user_id) + sess.clear() + am = sess.query(AdminUser).get(am.user_id) + am.user_name = 'foobar' + sess.flush() + sess.delete(am) + sess.flush() + self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', + 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete']) + + def test_inheritance_with_dupes(self): + # test using inheritance, same extension on both mappers + class AdminUser(User): + pass + ext = Ext() + mapper(User, users, extension=ext) + mapper(AdminUser, addresses, inherits=User, extension=ext) + + sess = create_session() + am = AdminUser() + sess.save(am) + sess.flush() + am = sess.query(AdminUser).load(am.user_id) + sess.clear() + am = sess.query(AdminUser).get(am.user_id) + am.user_name = 'foobar' + sess.flush() + sess.delete(am) + sess.flush() + self.assertEquals(methods, ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'get', + 'translate_row', 'create_instance', 'populate_instance', 'before_update', 'after_update', 'before_delete', 'after_delete']) + class RequirementsTest(AssertMixin): """Tests the contract for user classes.""" |
