summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/unitofwork.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/unitofwork.py')
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py369
1 files changed, 92 insertions, 277 deletions
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 66b68770d..4edfeefdc 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -17,16 +17,19 @@ unique against their primary key identity using an *identity map*
pattern. The Unit of Work then maintains lists of objects that are
new, dirty, or deleted and provides the capability to flush all those
changes at once.
+
"""
-import StringIO, weakref
-from sqlalchemy import util, logging, topological, exceptions
+import StringIO
+
+from sqlalchemy import util, log, topological
from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper, _state_mapper, has_identity
+from sqlalchemy.orm.mapper import _state_mapper
# Load lazily
object_session = None
+_state_session = None
class UOWEventHandler(interfaces.AttributeExtension):
"""An event handler added to all relation attributes which handles
@@ -38,33 +41,33 @@ class UOWEventHandler(interfaces.AttributeExtension):
self.class_ = class_
self.cascade = cascade
- def _target_mapper(self, obj):
- prop = object_mapper(obj).get_property(self.key)
+ def _target_mapper(self, state):
+ prop = _state_mapper(state).get_property(self.key)
return prop.mapper
- def append(self, obj, item, initiator):
+ def append(self, state, item, initiator):
# process "save_update" cascade rules for when an instance is appended to the list of another instance
- sess = object_session(obj)
+ sess = _state_session(state)
if sess:
if self.cascade.save_update and item not in sess:
- sess.save_or_update(item, entity_name=self._target_mapper(obj).entity_name)
+ sess.save_or_update(item, entity_name=self._target_mapper(state).entity_name)
- def remove(self, obj, item, initiator):
- sess = object_session(obj)
+ def remove(self, state, item, initiator):
+ sess = _state_session(state)
if sess:
# expunge pending orphans
if self.cascade.delete_orphan and item in sess.new:
- if self._target_mapper(obj)._is_orphan(item):
+ if self._target_mapper(state)._is_orphan(attributes.instance_state(item)):
sess.expunge(item)
- def set(self, obj, newvalue, oldvalue, initiator):
+ def set(self, state, newvalue, oldvalue, initiator):
# process "save_update" cascade rules for when an instance is attached to another instance
if oldvalue is newvalue:
return
- sess = object_session(obj)
+ sess = _state_session(state)
if sess:
if newvalue is not None and self.cascade.save_update and newvalue not in sess:
- sess.save_or_update(newvalue, entity_name=self._target_mapper(obj).entity_name)
+ sess.save_or_update(newvalue, entity_name=self._target_mapper(state).entity_name)
if self.cascade.delete_orphan and oldvalue in sess.new:
sess.expunge(oldvalue)
@@ -86,184 +89,6 @@ def register_attribute(class_, key, *args, **kwargs):
-class UnitOfWork(object):
- """Main UOW object which stores lists of dirty/new/deleted objects.
-
- Provides top-level *flush* functionality as well as the
- default transaction boundaries involved in a write
- operation.
- """
-
- def __init__(self, session):
- if session.weak_identity_map:
- self.identity_map = attributes.WeakInstanceDict()
- else:
- self.identity_map = attributes.StrongInstanceDict()
-
- self.new = {} # InstanceState->object, strong refs object
- self.deleted = {} # same
- self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
-
- def _remove_deleted(self, state):
- if '_instance_key' in state.dict:
- del self.identity_map[state.dict['_instance_key']]
- self.deleted.pop(state, None)
- self.new.pop(state, None)
-
- def _is_valid(self, state):
- if '_instance_key' in state.dict:
- return state.dict['_instance_key'] in self.identity_map
- else:
- return state in self.new
-
- def _register_clean(self, state):
- """register the given object as 'clean' (i.e. persistent) within this unit of work, after
- a save operation has taken place."""
-
- mapper = _state_mapper(state)
- instance_key = mapper._identity_key_from_state(state)
-
- if '_instance_key' not in state.dict:
- state.dict['_instance_key'] = instance_key
-
- elif state.dict['_instance_key'] != instance_key:
- # primary key switch
- del self.identity_map[state.dict['_instance_key']]
- state.dict['_instance_key'] = instance_key
-
- if hasattr(state, 'insert_order'):
- delattr(state, 'insert_order')
-
- o = state.obj()
- # prevent against last minute dereferences of the object
- # TODO: identify a code path where state.obj() is None
- if o is not None:
- self.identity_map[state.dict['_instance_key']] = o
- state.commit_all()
-
- # remove from new last, might be the last strong ref
- self.new.pop(state, None)
-
- def register_new(self, obj):
- """register the given object as 'new' (i.e. unsaved) within this unit of work."""
-
- if hasattr(obj, '_instance_key'):
- raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj))
- if obj._state not in self.new:
- self.new[obj._state] = obj
- obj._state.insert_order = len(self.new)
-
- def register_deleted(self, obj):
- """register the given persistent object as 'to be deleted' within this unit of work."""
-
- self.deleted[obj._state] = obj
-
- def locate_dirty(self):
- """return a set of all persistent instances within this unit of work which
- either contain changes or are marked as deleted.
- """
-
- # a little bit of inlining for speed
- return util.IdentitySet([x for x in self.identity_map.values()
- if x._state not in self.deleted
- and (
- x._state.modified
- or (x.__class__._class_state.has_mutable_scalars and x._state.is_modified())
- )
- ])
-
- def flush(self, session, objects=None):
- """create a dependency tree of all pending SQL operations within this unit of work and execute."""
-
- dirty = [x for x in self.identity_map.all_states()
- if x.modified
- or (x.class_._class_state.has_mutable_scalars and x.is_modified())
- ]
-
- if not dirty and not self.deleted and not self.new:
- return
-
- deleted = util.Set(self.deleted)
- new = util.Set(self.new)
-
- dirty = util.Set(dirty).difference(deleted)
-
- flush_context = UOWTransaction(self, session)
-
- if session.extension is not None:
- session.extension.before_flush(session, flush_context, objects)
-
- # create the set of all objects we want to operate upon
- if objects:
- # specific list passed in
- objset = util.Set([o._state for o in objects])
- else:
- # or just everything
- objset = util.Set(self.identity_map.all_states()).union(new)
-
- # store objects whose fate has been decided
- processed = util.Set()
-
- # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted.
- for state in new.union(dirty).intersection(objset).difference(deleted):
- if state in processed:
- continue
-
- obj = state.obj()
- is_orphan = _state_mapper(state)._is_orphan(obj)
- if is_orphan and not has_identity(obj):
- raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
- (
- obj,
- ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans])
- ))
- flush_context.register_object(state, isdelete=is_orphan)
- processed.add(state)
-
- # put all remaining deletes into the flush context.
- for state in deleted.intersection(objset).difference(processed):
- flush_context.register_object(state, isdelete=True)
-
- if len(flush_context.tasks) == 0:
- return
-
- session.create_transaction(autoflush=False)
- flush_context.transaction = session.transaction
- try:
- flush_context.execute()
-
- if session.extension is not None:
- session.extension.after_flush(session, flush_context)
- session.commit()
- except:
- session.rollback()
- raise
-
- flush_context.post_exec()
-
- if session.extension is not None:
- session.extension.after_flush_postexec(session, flush_context)
-
- def prune_identity_map(self):
- """Removes unreferenced instances cached in a strong-referencing identity map.
-
- Note that this method is only meaningful if "weak_identity_map"
- on the parent Session is set to False and therefore this UnitOfWork's
- identity map is a regular dictionary
-
- Removes any object in the identity map that is not referenced
- in user code or scheduled for a unit of work operation. Returns
- the number of objects pruned.
- """
-
- if isinstance(self.identity_map, attributes.WeakInstanceDict):
- return 0
- ref_count = len(self.identity_map)
- dirty = self.locate_dirty()
- keepers = weakref.WeakValueDictionary(self.identity_map)
- self.identity_map.clear()
- self.identity_map.update(keepers)
- return ref_count - len(self.identity_map)
class UOWTransaction(object):
"""Handles the details of organizing and executing transaction
@@ -275,8 +100,7 @@ class UOWTransaction(object):
packages.
"""
- def __init__(self, uow, session):
- self.uow = uow
+ def __init__(self, session):
self.session = session
self.mapper_flush_opts = session._mapper_flush_opts
@@ -291,7 +115,7 @@ class UOWTransaction(object):
# information.
self.attributes = {}
- self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
+ self.logger = log.instance_logger(self, echoflag=session.echo_uow)
def get_attribute_history(self, state, key, passive=True):
hashkey = ("history", state, key)
@@ -310,19 +134,18 @@ class UOWTransaction(object):
(added, unchanged, deleted) = attributes.get_history(state, key, passive=passive)
self.attributes[hashkey] = (added, unchanged, deleted, passive)
- if added is None:
+ if added is None or not state.get_impl(key).uses_objects:
return (added, unchanged, deleted)
else:
return (
- [getattr(c, '_state', c) for c in added],
- [getattr(c, '_state', c) for c in unchanged],
- [getattr(c, '_state', c) for c in deleted],
+ [c is not None and attributes.instance_state(c) or None for c in added],
+ [c is not None and attributes.instance_state(c) or None for c in unchanged],
+ [c is not None and attributes.instance_state(c) or None for c in deleted],
)
-
- def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs):
+ def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None):
# if object is not in the overall session, do nothing
- if not self.uow._is_valid(state):
+ if not self.session._contains_state(state):
if self._should_log_debug:
self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state)))
return
@@ -331,12 +154,12 @@ class UOWTransaction(object):
self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate))
mapper = _state_mapper(state)
-
+
task = self.get_task_by_mapper(mapper)
if postupdate:
task.append_postupdate(state, post_update_cols)
else:
- task.append(state, listonly, isdelete=isdelete, **kwargs)
+ task.append(state, listonly=listonly, isdelete=isdelete)
def set_row_switch(self, state):
"""mark a deleted object as a 'row switch'.
@@ -451,22 +274,26 @@ class UOWTransaction(object):
import uowdumper
uowdumper.UOWDumper(tasks, buf)
return buf.getvalue()
-
- def post_exec(self):
+
+ def elements(self):
+ """return an iterator of all UOWTaskElements within this UOWTransaction."""
+ for task in self.tasks.values():
+ for elem in task.elements:
+ yield elem
+ elements = property(elements)
+
+ def finalize_flush_changes(self):
"""mark processed objects as clean / deleted after a successful flush().
this method is called within the flush() method after the
execute() method has succeeded and the transaction has been committed.
"""
- for task in self.tasks.values():
- for elem in task.elements:
- if elem.state is None:
- continue
- if elem.isdelete:
- self.uow._remove_deleted(elem.state)
- else:
- self.uow._register_clean(elem.state)
+ for elem in self.elements:
+ if elem.isdelete:
+ self.session._remove_newly_deleted(elem.state)
+ else:
+ self.session._register_newly_persistent(elem.state)
def _sort_dependencies(self):
nodes = topological.sort_with_cycles(self.dependencies,
@@ -489,10 +316,9 @@ class UOWTransaction(object):
class UOWTask(object):
"""Represents all of the objects in the UOWTransaction which correspond to
- a particular mapper. This is the primary class of three classes used to generate
- the elements of the dependency graph.
+ a particular mapper.
+
"""
-
def __init__(self, uowtransaction, mapper, base_task=None):
self.uowtransaction = uowtransaction
@@ -515,6 +341,7 @@ class UOWTask(object):
# mapping of InstanceState -> UOWTaskElement
self._objects = {}
+ self.dependent_tasks = []
self.dependencies = util.Set()
self.cyclical_dependencies = util.Set()
@@ -564,11 +391,6 @@ class UOWTask(object):
rec.update(listonly, isdelete)
- def _append_cyclical_childtask(self, task):
- if "cyclical" not in self._objects:
- self._objects["cyclical"] = UOWTaskElement(None)
- self._objects["cyclical"].childtasks.append(task)
-
def append_postupdate(self, state, post_update_cols):
"""issue a 'post update' UPDATE statement via this object's mapper immediately.
@@ -577,8 +399,8 @@ class UOWTask(object):
"""
# postupdates are UPDATED immeditely (for now)
- # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns
- # instead of __eq__
+ # convert post_update_cols list to a Set so that __hash__() is used to compare columns
+ # instead of __eq__()
self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
def __contains__(self, state):
@@ -607,26 +429,42 @@ class UOWTask(object):
for rec in callable(task):
yield rec
return property(collection)
-
- elements = property(lambda self:self._objects.values())
- polymorphic_elements = _polymorphic_collection(lambda task:task.elements)
-
- polymorphic_tosave_elements = property(lambda self: [rec for rec in self.polymorphic_elements
- if not rec.isdelete])
-
- polymorphic_todelete_elements = property(lambda self:[rec for rec in self.polymorphic_elements
- if rec.isdelete])
+ def _elements(self):
+ return self._objects.values()
+ elements = property(_elements)
+
+ polymorphic_elements = _polymorphic_collection(_elements)
- polymorphic_tosave_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
- if rec.state is not None and not rec.listonly and rec.isdelete is False])
+ def polymorphic_tosave_elements(self):
+ return [rec for rec in self.polymorphic_elements if not rec.isdelete]
+ polymorphic_tosave_elements = property(polymorphic_tosave_elements)
+
+ def polymorphic_todelete_elements(self):
+ return [rec for rec in self.polymorphic_elements if rec.isdelete]
+ polymorphic_todelete_elements = property(polymorphic_todelete_elements)
+
+ def polymorphic_tosave_objects(self):
+ return [
+ rec.state for rec in self.polymorphic_elements
+ if rec.state is not None and not rec.listonly and rec.isdelete is False
+ ]
+ polymorphic_tosave_objects = property(polymorphic_tosave_objects)
- polymorphic_todelete_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
- if rec.state is not None and not rec.listonly and rec.isdelete is True])
+ def polymorphic_todelete_objects(self):
+ return [
+ rec.state for rec in self.polymorphic_elements
+ if rec.state is not None and not rec.listonly and rec.isdelete is True
+ ]
+ polymorphic_todelete_objects = property(polymorphic_todelete_objects)
- polymorphic_dependencies = _polymorphic_collection(lambda task:task.dependencies)
+ def polymorphic_dependencies(self):
+ return self.dependencies
+ polymorphic_dependencies = _polymorphic_collection(polymorphic_dependencies)
- polymorphic_cyclical_dependencies = _polymorphic_collection(lambda task:task.cyclical_dependencies)
+ def polymorphic_cyclical_dependencies(self):
+ return self.cyclical_dependencies
+ polymorphic_cyclical_dependencies = _polymorphic_collection(polymorphic_cyclical_dependencies)
def _sort_circular_dependencies(self, trans, cycles):
"""Create a hierarchical tree of *subtasks*
@@ -741,7 +579,7 @@ class UOWTask(object):
if t is None:
t = UOWTask(self.uowtransaction, originating_task.mapper)
nexttasks[originating_task] = t
- parenttask._append_cyclical_childtask(t)
+ parenttask.dependent_tasks.append(t)
t.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete)
if state in dependencies:
@@ -777,29 +615,17 @@ class UOWTask(object):
return ret
def __repr__(self):
- if self.mapper is not None:
- if self.mapper.__class__.__name__ == 'Mapper':
- name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.description
- else:
- name = repr(self.mapper)
- else:
- name = '(none)'
- return ("UOWTask(%s) Mapper: '%s'" % (hex(id(self)), name))
+ return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper))
class UOWTaskElement(object):
- """An element within a UOWTask.
-
- Corresponds to a single object instance to be saved, deleted, or
- just part of the transaction as a placeholder for further
- dependencies (i.e. 'listonly').
-
- may also store additional sub-UOWTasks.
+ """Corresponds to a single InstanceState to be saved, deleted,
+ or otherwise marked as having dependencies. A collection of
+ UOWTaskElements are held by a UOWTask.
+
"""
-
def __init__(self, state):
self.state = state
self.listonly = True
- self.childtasks = []
self.isdelete = False
self.__preprocessed = {}
@@ -835,11 +661,11 @@ class UOWTaskElement(object):
class UOWDependencyProcessor(object):
"""In between the saving and deleting of objects, process
- *dependent* data, such as filling in a foreign key on a child item
+ dependent data, such as filling in a foreign key on a child item
from a new primary key, or deleting association rows before a
delete. This object acts as a proxy to a DependencyProcessor.
+
"""
-
def __init__(self, processor, targettask):
self.processor = processor
self.targettask = targettask
@@ -877,12 +703,12 @@ class UOWDependencyProcessor(object):
return elem.state
ret = False
- elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None and not elem.is_preprocessed(self)]
+ elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if not elem.is_preprocessed(self)]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
- elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None and not elem.is_preprocessed(self)]
+ elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if not elem.is_preprocessed(self)]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
@@ -892,9 +718,9 @@ class UOWDependencyProcessor(object):
"""process all objects contained within this ``UOWDependencyProcessor``s target task."""
if not delete:
- self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None], trans, delete=False)
+ self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements], trans, delete=False)
else:
- self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None], trans, delete=True)
+ self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements], trans, delete=True)
def get_object_dependencies(self, state, trans, passive):
return trans.get_attribute_history(state, self.processor.key, passive=passive)
@@ -907,7 +733,6 @@ class UOWDependencyProcessor(object):
when toplogically sorting on a per-instance basis.
"""
-
return self.processor.whose_dependent_on_who(state1, state2)
def branch(self, task):
@@ -917,7 +742,6 @@ class UOWDependencyProcessor(object):
is broken up into many individual ``UOWTask`` objects.
"""
-
return UOWDependencyProcessor(self.processor, task)
@@ -944,13 +768,11 @@ class UOWExecutor(object):
def execute_save_steps(self, trans, task):
self.save_objects(trans, task)
self.execute_cyclical_dependencies(trans, task, False)
- self.execute_per_element_childtasks(trans, task, False)
self.execute_dependencies(trans, task, False)
self.execute_dependencies(trans, task, True)
-
+
def execute_delete_steps(self, trans, task):
self.execute_cyclical_dependencies(trans, task, True)
- self.execute_per_element_childtasks(trans, task, True)
self.delete_objects(trans, task)
def execute_dependencies(self, trans, task, isdelete=None):
@@ -964,12 +786,5 @@ class UOWExecutor(object):
def execute_cyclical_dependencies(self, trans, task, isdelete):
for dep in task.polymorphic_cyclical_dependencies:
self.execute_dependency(trans, dep, isdelete)
-
- def execute_per_element_childtasks(self, trans, task, isdelete):
- for element in task.polymorphic_tosave_elements + task.polymorphic_todelete_elements:
- self.execute_element_childtasks(trans, element, isdelete)
-
- def execute_element_childtasks(self, trans, element, isdelete):
- for child in element.childtasks:
- self.execute(trans, [child], isdelete)
-
+ for t in task.dependent_tasks:
+ self.execute(trans, [t], isdelete)