diff options
Diffstat (limited to 'lib/sqlalchemy/orm/unitofwork.py')
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 369 |
1 files changed, 92 insertions, 277 deletions
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 66b68770d..4edfeefdc 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -17,16 +17,19 @@ unique against their primary key identity using an *identity map* pattern. The Unit of Work then maintains lists of objects that are new, dirty, or deleted and provides the capability to flush all those changes at once. + """ -import StringIO, weakref -from sqlalchemy import util, logging, topological, exceptions +import StringIO + +from sqlalchemy import util, log, topological from sqlalchemy.orm import attributes, interfaces from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.mapper import object_mapper, _state_mapper, has_identity +from sqlalchemy.orm.mapper import _state_mapper # Load lazily object_session = None +_state_session = None class UOWEventHandler(interfaces.AttributeExtension): """An event handler added to all relation attributes which handles @@ -38,33 +41,33 @@ class UOWEventHandler(interfaces.AttributeExtension): self.class_ = class_ self.cascade = cascade - def _target_mapper(self, obj): - prop = object_mapper(obj).get_property(self.key) + def _target_mapper(self, state): + prop = _state_mapper(state).get_property(self.key) return prop.mapper - def append(self, obj, item, initiator): + def append(self, state, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance - sess = object_session(obj) + sess = _state_session(state) if sess: if self.cascade.save_update and item not in sess: - sess.save_or_update(item, entity_name=self._target_mapper(obj).entity_name) + sess.save_or_update(item, entity_name=self._target_mapper(state).entity_name) - def remove(self, obj, item, initiator): - sess = object_session(obj) + def remove(self, state, item, initiator): + sess = _state_session(state) if sess: # expunge pending orphans if self.cascade.delete_orphan and item in sess.new: - if self._target_mapper(obj)._is_orphan(item): + if self._target_mapper(state)._is_orphan(attributes.instance_state(item)): sess.expunge(item) - def set(self, obj, newvalue, oldvalue, initiator): + def set(self, state, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance if oldvalue is newvalue: return - sess = object_session(obj) + sess = _state_session(state) if sess: if newvalue is not None and self.cascade.save_update and newvalue not in sess: - sess.save_or_update(newvalue, entity_name=self._target_mapper(obj).entity_name) + sess.save_or_update(newvalue, entity_name=self._target_mapper(state).entity_name) if self.cascade.delete_orphan and oldvalue in sess.new: sess.expunge(oldvalue) @@ -86,184 +89,6 @@ def register_attribute(class_, key, *args, **kwargs): -class UnitOfWork(object): - """Main UOW object which stores lists of dirty/new/deleted objects. - - Provides top-level *flush* functionality as well as the - default transaction boundaries involved in a write - operation. - """ - - def __init__(self, session): - if session.weak_identity_map: - self.identity_map = attributes.WeakInstanceDict() - else: - self.identity_map = attributes.StrongInstanceDict() - - self.new = {} # InstanceState->object, strong refs object - self.deleted = {} # same - self.logger = logging.instance_logger(self, echoflag=session.echo_uow) - - def _remove_deleted(self, state): - if '_instance_key' in state.dict: - del self.identity_map[state.dict['_instance_key']] - self.deleted.pop(state, None) - self.new.pop(state, None) - - def _is_valid(self, state): - if '_instance_key' in state.dict: - return state.dict['_instance_key'] in self.identity_map - else: - return state in self.new - - def _register_clean(self, state): - """register the given object as 'clean' (i.e. persistent) within this unit of work, after - a save operation has taken place.""" - - mapper = _state_mapper(state) - instance_key = mapper._identity_key_from_state(state) - - if '_instance_key' not in state.dict: - state.dict['_instance_key'] = instance_key - - elif state.dict['_instance_key'] != instance_key: - # primary key switch - del self.identity_map[state.dict['_instance_key']] - state.dict['_instance_key'] = instance_key - - if hasattr(state, 'insert_order'): - delattr(state, 'insert_order') - - o = state.obj() - # prevent against last minute dereferences of the object - # TODO: identify a code path where state.obj() is None - if o is not None: - self.identity_map[state.dict['_instance_key']] = o - state.commit_all() - - # remove from new last, might be the last strong ref - self.new.pop(state, None) - - def register_new(self, obj): - """register the given object as 'new' (i.e. unsaved) within this unit of work.""" - - if hasattr(obj, '_instance_key'): - raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) - if obj._state not in self.new: - self.new[obj._state] = obj - obj._state.insert_order = len(self.new) - - def register_deleted(self, obj): - """register the given persistent object as 'to be deleted' within this unit of work.""" - - self.deleted[obj._state] = obj - - def locate_dirty(self): - """return a set of all persistent instances within this unit of work which - either contain changes or are marked as deleted. - """ - - # a little bit of inlining for speed - return util.IdentitySet([x for x in self.identity_map.values() - if x._state not in self.deleted - and ( - x._state.modified - or (x.__class__._class_state.has_mutable_scalars and x._state.is_modified()) - ) - ]) - - def flush(self, session, objects=None): - """create a dependency tree of all pending SQL operations within this unit of work and execute.""" - - dirty = [x for x in self.identity_map.all_states() - if x.modified - or (x.class_._class_state.has_mutable_scalars and x.is_modified()) - ] - - if not dirty and not self.deleted and not self.new: - return - - deleted = util.Set(self.deleted) - new = util.Set(self.new) - - dirty = util.Set(dirty).difference(deleted) - - flush_context = UOWTransaction(self, session) - - if session.extension is not None: - session.extension.before_flush(session, flush_context, objects) - - # create the set of all objects we want to operate upon - if objects: - # specific list passed in - objset = util.Set([o._state for o in objects]) - else: - # or just everything - objset = util.Set(self.identity_map.all_states()).union(new) - - # store objects whose fate has been decided - processed = util.Set() - - # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted. - for state in new.union(dirty).intersection(objset).difference(deleted): - if state in processed: - continue - - obj = state.obj() - is_orphan = _state_mapper(state)._is_orphan(obj) - if is_orphan and not has_identity(obj): - raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" % - ( - obj, - ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans]) - )) - flush_context.register_object(state, isdelete=is_orphan) - processed.add(state) - - # put all remaining deletes into the flush context. - for state in deleted.intersection(objset).difference(processed): - flush_context.register_object(state, isdelete=True) - - if len(flush_context.tasks) == 0: - return - - session.create_transaction(autoflush=False) - flush_context.transaction = session.transaction - try: - flush_context.execute() - - if session.extension is not None: - session.extension.after_flush(session, flush_context) - session.commit() - except: - session.rollback() - raise - - flush_context.post_exec() - - if session.extension is not None: - session.extension.after_flush_postexec(session, flush_context) - - def prune_identity_map(self): - """Removes unreferenced instances cached in a strong-referencing identity map. - - Note that this method is only meaningful if "weak_identity_map" - on the parent Session is set to False and therefore this UnitOfWork's - identity map is a regular dictionary - - Removes any object in the identity map that is not referenced - in user code or scheduled for a unit of work operation. Returns - the number of objects pruned. - """ - - if isinstance(self.identity_map, attributes.WeakInstanceDict): - return 0 - ref_count = len(self.identity_map) - dirty = self.locate_dirty() - keepers = weakref.WeakValueDictionary(self.identity_map) - self.identity_map.clear() - self.identity_map.update(keepers) - return ref_count - len(self.identity_map) class UOWTransaction(object): """Handles the details of organizing and executing transaction @@ -275,8 +100,7 @@ class UOWTransaction(object): packages. """ - def __init__(self, uow, session): - self.uow = uow + def __init__(self, session): self.session = session self.mapper_flush_opts = session._mapper_flush_opts @@ -291,7 +115,7 @@ class UOWTransaction(object): # information. self.attributes = {} - self.logger = logging.instance_logger(self, echoflag=session.echo_uow) + self.logger = log.instance_logger(self, echoflag=session.echo_uow) def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) @@ -310,19 +134,18 @@ class UOWTransaction(object): (added, unchanged, deleted) = attributes.get_history(state, key, passive=passive) self.attributes[hashkey] = (added, unchanged, deleted, passive) - if added is None: + if added is None or not state.get_impl(key).uses_objects: return (added, unchanged, deleted) else: return ( - [getattr(c, '_state', c) for c in added], - [getattr(c, '_state', c) for c in unchanged], - [getattr(c, '_state', c) for c in deleted], + [c is not None and attributes.instance_state(c) or None for c in added], + [c is not None and attributes.instance_state(c) or None for c in unchanged], + [c is not None and attributes.instance_state(c) or None for c in deleted], ) - - def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs): + def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None): # if object is not in the overall session, do nothing - if not self.uow._is_valid(state): + if not self.session._contains_state(state): if self._should_log_debug: self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state))) return @@ -331,12 +154,12 @@ class UOWTransaction(object): self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate)) mapper = _state_mapper(state) - + task = self.get_task_by_mapper(mapper) if postupdate: task.append_postupdate(state, post_update_cols) else: - task.append(state, listonly, isdelete=isdelete, **kwargs) + task.append(state, listonly=listonly, isdelete=isdelete) def set_row_switch(self, state): """mark a deleted object as a 'row switch'. @@ -451,22 +274,26 @@ class UOWTransaction(object): import uowdumper uowdumper.UOWDumper(tasks, buf) return buf.getvalue() - - def post_exec(self): + + def elements(self): + """return an iterator of all UOWTaskElements within this UOWTransaction.""" + for task in self.tasks.values(): + for elem in task.elements: + yield elem + elements = property(elements) + + def finalize_flush_changes(self): """mark processed objects as clean / deleted after a successful flush(). this method is called within the flush() method after the execute() method has succeeded and the transaction has been committed. """ - for task in self.tasks.values(): - for elem in task.elements: - if elem.state is None: - continue - if elem.isdelete: - self.uow._remove_deleted(elem.state) - else: - self.uow._register_clean(elem.state) + for elem in self.elements: + if elem.isdelete: + self.session._remove_newly_deleted(elem.state) + else: + self.session._register_newly_persistent(elem.state) def _sort_dependencies(self): nodes = topological.sort_with_cycles(self.dependencies, @@ -489,10 +316,9 @@ class UOWTransaction(object): class UOWTask(object): """Represents all of the objects in the UOWTransaction which correspond to - a particular mapper. This is the primary class of three classes used to generate - the elements of the dependency graph. + a particular mapper. + """ - def __init__(self, uowtransaction, mapper, base_task=None): self.uowtransaction = uowtransaction @@ -515,6 +341,7 @@ class UOWTask(object): # mapping of InstanceState -> UOWTaskElement self._objects = {} + self.dependent_tasks = [] self.dependencies = util.Set() self.cyclical_dependencies = util.Set() @@ -564,11 +391,6 @@ class UOWTask(object): rec.update(listonly, isdelete) - def _append_cyclical_childtask(self, task): - if "cyclical" not in self._objects: - self._objects["cyclical"] = UOWTaskElement(None) - self._objects["cyclical"].childtasks.append(task) - def append_postupdate(self, state, post_update_cols): """issue a 'post update' UPDATE statement via this object's mapper immediately. @@ -577,8 +399,8 @@ class UOWTask(object): """ # postupdates are UPDATED immeditely (for now) - # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns - # instead of __eq__ + # convert post_update_cols list to a Set so that __hash__() is used to compare columns + # instead of __eq__() self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols)) def __contains__(self, state): @@ -607,26 +429,42 @@ class UOWTask(object): for rec in callable(task): yield rec return property(collection) - - elements = property(lambda self:self._objects.values()) - polymorphic_elements = _polymorphic_collection(lambda task:task.elements) - - polymorphic_tosave_elements = property(lambda self: [rec for rec in self.polymorphic_elements - if not rec.isdelete]) - - polymorphic_todelete_elements = property(lambda self:[rec for rec in self.polymorphic_elements - if rec.isdelete]) + def _elements(self): + return self._objects.values() + elements = property(_elements) + + polymorphic_elements = _polymorphic_collection(_elements) - polymorphic_tosave_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements - if rec.state is not None and not rec.listonly and rec.isdelete is False]) + def polymorphic_tosave_elements(self): + return [rec for rec in self.polymorphic_elements if not rec.isdelete] + polymorphic_tosave_elements = property(polymorphic_tosave_elements) + + def polymorphic_todelete_elements(self): + return [rec for rec in self.polymorphic_elements if rec.isdelete] + polymorphic_todelete_elements = property(polymorphic_todelete_elements) + + def polymorphic_tosave_objects(self): + return [ + rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is False + ] + polymorphic_tosave_objects = property(polymorphic_tosave_objects) - polymorphic_todelete_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements - if rec.state is not None and not rec.listonly and rec.isdelete is True]) + def polymorphic_todelete_objects(self): + return [ + rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is True + ] + polymorphic_todelete_objects = property(polymorphic_todelete_objects) - polymorphic_dependencies = _polymorphic_collection(lambda task:task.dependencies) + def polymorphic_dependencies(self): + return self.dependencies + polymorphic_dependencies = _polymorphic_collection(polymorphic_dependencies) - polymorphic_cyclical_dependencies = _polymorphic_collection(lambda task:task.cyclical_dependencies) + def polymorphic_cyclical_dependencies(self): + return self.cyclical_dependencies + polymorphic_cyclical_dependencies = _polymorphic_collection(polymorphic_cyclical_dependencies) def _sort_circular_dependencies(self, trans, cycles): """Create a hierarchical tree of *subtasks* @@ -741,7 +579,7 @@ class UOWTask(object): if t is None: t = UOWTask(self.uowtransaction, originating_task.mapper) nexttasks[originating_task] = t - parenttask._append_cyclical_childtask(t) + parenttask.dependent_tasks.append(t) t.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) if state in dependencies: @@ -777,29 +615,17 @@ class UOWTask(object): return ret def __repr__(self): - if self.mapper is not None: - if self.mapper.__class__.__name__ == 'Mapper': - name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.description - else: - name = repr(self.mapper) - else: - name = '(none)' - return ("UOWTask(%s) Mapper: '%s'" % (hex(id(self)), name)) + return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper)) class UOWTaskElement(object): - """An element within a UOWTask. - - Corresponds to a single object instance to be saved, deleted, or - just part of the transaction as a placeholder for further - dependencies (i.e. 'listonly'). - - may also store additional sub-UOWTasks. + """Corresponds to a single InstanceState to be saved, deleted, + or otherwise marked as having dependencies. A collection of + UOWTaskElements are held by a UOWTask. + """ - def __init__(self, state): self.state = state self.listonly = True - self.childtasks = [] self.isdelete = False self.__preprocessed = {} @@ -835,11 +661,11 @@ class UOWTaskElement(object): class UOWDependencyProcessor(object): """In between the saving and deleting of objects, process - *dependent* data, such as filling in a foreign key on a child item + dependent data, such as filling in a foreign key on a child item from a new primary key, or deleting association rows before a delete. This object acts as a proxy to a DependencyProcessor. + """ - def __init__(self, processor, targettask): self.processor = processor self.targettask = targettask @@ -877,12 +703,12 @@ class UOWDependencyProcessor(object): return elem.state ret = False - elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None and not elem.is_preprocessed(self)] + elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if not elem.is_preprocessed(self)] if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) - elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None and not elem.is_preprocessed(self)] + elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if not elem.is_preprocessed(self)] if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) @@ -892,9 +718,9 @@ class UOWDependencyProcessor(object): """process all objects contained within this ``UOWDependencyProcessor``s target task.""" if not delete: - self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None], trans, delete=False) + self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements], trans, delete=False) else: - self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None], trans, delete=True) + self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements], trans, delete=True) def get_object_dependencies(self, state, trans, passive): return trans.get_attribute_history(state, self.processor.key, passive=passive) @@ -907,7 +733,6 @@ class UOWDependencyProcessor(object): when toplogically sorting on a per-instance basis. """ - return self.processor.whose_dependent_on_who(state1, state2) def branch(self, task): @@ -917,7 +742,6 @@ class UOWDependencyProcessor(object): is broken up into many individual ``UOWTask`` objects. """ - return UOWDependencyProcessor(self.processor, task) @@ -944,13 +768,11 @@ class UOWExecutor(object): def execute_save_steps(self, trans, task): self.save_objects(trans, task) self.execute_cyclical_dependencies(trans, task, False) - self.execute_per_element_childtasks(trans, task, False) self.execute_dependencies(trans, task, False) self.execute_dependencies(trans, task, True) - + def execute_delete_steps(self, trans, task): self.execute_cyclical_dependencies(trans, task, True) - self.execute_per_element_childtasks(trans, task, True) self.delete_objects(trans, task) def execute_dependencies(self, trans, task, isdelete=None): @@ -964,12 +786,5 @@ class UOWExecutor(object): def execute_cyclical_dependencies(self, trans, task, isdelete): for dep in task.polymorphic_cyclical_dependencies: self.execute_dependency(trans, dep, isdelete) - - def execute_per_element_childtasks(self, trans, task, isdelete): - for element in task.polymorphic_tosave_elements + task.polymorphic_todelete_elements: - self.execute_element_childtasks(trans, element, isdelete) - - def execute_element_childtasks(self, trans, element, isdelete): - for child in element.childtasks: - self.execute(trans, [child], isdelete) - + for t in task.dependent_tasks: + self.execute(trans, [t], isdelete) |
