summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-05-17 18:17:46 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-05-17 18:17:46 +0000
commit2be867ffac8881a4a20ca5387063ed207ac876dc (patch)
tree30b4a8d0663febea019679442c1f98360eb5ce26 /lib/sqlalchemy
parent6515e84d4c7084d9276922786291f6e047b70b84 (diff)
downloadsqlalchemy-2be867ffac8881a4a20ca5387063ed207ac876dc.tar.gz
- Significant performance enhancements regarding Sessions/flush()
in conjunction with large mapper graphs, large numbers of objects: - The Session's "weak referencing" behavior is now *full* - no strong references whatsoever are made to a mapped object or related items/collections in its __dict__. Backrefs and other cycles in objects no longer affect the Session's ability to lose all references to unmodified objects. Objects with pending changes still are maintained strongly until flush. [ticket:1398] The implementation also improves performance by moving the "resurrection" process of garbage collected items to only be relevant for mappings that map "mutable" attributes (i.e. PickleType, composite attrs). This removes overhead from the gc process and simplifies internal behavior. If a "mutable" attribute change is the sole change on an object which is then dereferenced, the mapper will not have access to other attribute state when the UPDATE is issued. This may present itself differently to some MapperExtensions. The change also affects the internal attribute API, but not the AttributeExtension interface nor any of the publically documented attribute functions. - The unit of work no longer genererates a graph of "dependency" processors for the full graph of mappers during flush(), instead creating such processors only for those mappers which represent objects with pending changes. This saves a tremendous number of method calls in the context of a large interconnected graph of mappers. - Cached a wasteful "table sort" operation that previously occured multiple times per flush, also removing significant method call count from flush(). - Other redundant behaviors have been simplified in mapper._save_obj().
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/attributes.py672
-rw-r--r--lib/sqlalchemy/orm/collections.py7
-rw-r--r--lib/sqlalchemy/orm/dependency.py39
-rw-r--r--lib/sqlalchemy/orm/dynamic.py46
-rw-r--r--lib/sqlalchemy/orm/identity.py89
-rw-r--r--lib/sqlalchemy/orm/interfaces.py24
-rw-r--r--lib/sqlalchemy/orm/mapper.py144
-rw-r--r--lib/sqlalchemy/orm/properties.py18
-rw-r--r--lib/sqlalchemy/orm/query.py16
-rw-r--r--lib/sqlalchemy/orm/session.py33
-rw-r--r--lib/sqlalchemy/orm/state.py429
-rw-r--r--lib/sqlalchemy/orm/strategies.py34
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py20
13 files changed, 887 insertions, 684 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 68aa0d93a..4fa41ff3b 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -20,14 +20,13 @@ import types
import weakref
from sqlalchemy import util
-from sqlalchemy.util import EMPTY_SET
from sqlalchemy.orm import interfaces, collections, exc
import sqlalchemy.exceptions as sa_exc
# lazy imports
_entity_info = None
identity_equal = None
-
+state = None
PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT')
ATTR_WAS_SET = util.symbol('ATTR_WAS_SET')
@@ -105,7 +104,7 @@ class QueryableAttribute(interfaces.PropComparator):
self.parententity = parententity
def get_history(self, instance, **kwargs):
- return self.impl.get_history(instance_state(instance), **kwargs)
+ return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs)
def __selectable__(self):
# TODO: conditionally attach this method based on clause_element ?
@@ -148,15 +147,15 @@ class InstrumentedAttribute(QueryableAttribute):
"""Public-facing descriptor, placed in the mapped class dictionary."""
def __set__(self, instance, value):
- self.impl.set(instance_state(instance), value, None)
+ self.impl.set(instance_state(instance), instance_dict(instance), value, None)
def __delete__(self, instance):
- self.impl.delete(instance_state(instance))
+ self.impl.delete(instance_state(instance), instance_dict(instance))
def __get__(self, instance, owner):
if instance is None:
return self
- return self.impl.get(instance_state(instance))
+ return self.impl.get(instance_state(instance), instance_dict(instance))
class _ProxyImpl(object):
accepts_scalar_loader = False
@@ -335,7 +334,7 @@ class AttributeImpl(object):
else:
state.callables[self.key] = callable_
- def get_history(self, state, passive=PASSIVE_OFF):
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
raise NotImplementedError()
def _get_callable(self, state):
@@ -346,13 +345,13 @@ class AttributeImpl(object):
else:
return None
- def initialize(self, state):
+ def initialize(self, state, dict_):
"""Initialize this attribute on the given object instance with an empty value."""
- state.dict[self.key] = None
+ dict_[self.key] = None
return None
- def get(self, state, passive=PASSIVE_OFF):
+ def get(self, state, dict_, passive=PASSIVE_OFF):
"""Retrieve a value from the given object.
If a callable is assembled on this object's attribute, and
@@ -361,7 +360,7 @@ class AttributeImpl(object):
"""
try:
- return state.dict[self.key]
+ return dict_[self.key]
except KeyError:
# if no history, check for lazy callables, etc.
if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET:
@@ -374,25 +373,25 @@ class AttributeImpl(object):
return PASSIVE_NORESULT
value = callable_()
if value is not ATTR_WAS_SET:
- return self.set_committed_value(state, value)
+ return self.set_committed_value(state, dict_, value)
else:
- if self.key not in state.dict:
+ if self.key not in dict_:
return self.get(state, passive=passive)
- return state.dict[self.key]
+ return dict_[self.key]
# Return a new, empty value
- return self.initialize(state)
+ return self.initialize(state, dict_)
- def append(self, state, value, initiator, passive=PASSIVE_OFF):
- self.set(state, value, initiator)
+ def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(state, dict_, value, initiator)
- def remove(self, state, value, initiator, passive=PASSIVE_OFF):
- self.set(state, None, initiator)
+ def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(state, dict_, None, initiator)
- def set(self, state, value, initiator):
+ def set(self, state, dict_, value, initiator):
raise NotImplementedError()
- def get_committed_value(self, state, passive=PASSIVE_OFF):
+ def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
"""return the unchanged value of this attribute"""
if self.key in state.committed_state:
@@ -401,12 +400,12 @@ class AttributeImpl(object):
else:
return state.committed_state.get(self.key)
else:
- return self.get(state, passive=passive)
+ return self.get(state, dict_, passive=passive)
- def set_committed_value(self, state, value):
+ def set_committed_value(self, state, dict_, value):
"""set an attribute value on the given instance and 'commit' it."""
- state.commit([self.key])
+ state.commit(dict_, [self.key])
state.callables.pop(self.key, None)
state.dict[self.key] = value
@@ -419,45 +418,45 @@ class ScalarAttributeImpl(AttributeImpl):
accepts_scalar_loader = True
uses_objects = False
- def delete(self, state):
+ def delete(self, state, dict_):
# TODO: catch key errors, convert to attributeerror?
if self.active_history or self.extensions:
- old = self.get(state)
+ old = self.get(state, dict_)
else:
- old = state.dict.get(self.key, NO_VALUE)
+ old = dict_.get(self.key, NO_VALUE)
- state.modified_event(self, False, old)
+ state.modified_event(dict_, self, False, old)
if self.extensions:
- self.fire_remove_event(state, old, None)
- del state.dict[self.key]
+ self.fire_remove_event(state, dict_, old, None)
+ del dict_[self.key]
- def get_history(self, state, passive=PASSIVE_OFF):
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
return History.from_attribute(
- self, state, state.dict.get(self.key, NO_VALUE))
+ self, state, dict_.get(self.key, NO_VALUE))
- def set(self, state, value, initiator):
+ def set(self, state, dict_, value, initiator):
if initiator is self:
return
if self.active_history or self.extensions:
- old = self.get(state)
+ old = self.get(state, dict_)
else:
- old = state.dict.get(self.key, NO_VALUE)
+ old = dict_.get(self.key, NO_VALUE)
- state.modified_event(self, False, old)
+ state.modified_event(dict_, self, False, old)
if self.extensions:
- value = self.fire_replace_event(state, value, old, initiator)
- state.dict[self.key] = value
+ value = self.fire_replace_event(state, dict_, value, old, initiator)
+ dict_[self.key] = value
- def fire_replace_event(self, state, value, previous, initiator):
+ def fire_replace_event(self, state, dict_, value, previous, initiator):
for ext in self.extensions:
value = ext.set(state, value, previous, initiator or self)
return value
- def fire_remove_event(self, state, value, initiator):
+ def fire_remove_event(self, state, dict_, value, initiator):
for ext in self.extensions:
ext.remove(state, value, initiator or self)
@@ -483,29 +482,48 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function")
self.copy = copy_function
- def get_history(self, state, passive=PASSIVE_OFF):
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ if not dict_:
+ v = state.committed_state.get(self.key, NO_VALUE)
+ else:
+ v = dict_.get(self.key, NO_VALUE)
+
return History.from_attribute(
- self, state, state.dict.get(self.key, NO_VALUE))
+ self, state, v)
- def commit_to_state(self, state, dest):
- dest[self.key] = self.copy(state.dict[self.key])
+ def commit_to_state(self, state, dict_, dest):
+ dest[self.key] = self.copy(dict_[self.key])
- def check_mutable_modified(self, state):
- (added, unchanged, deleted) = self.get_history(state, passive=PASSIVE_NO_INITIALIZE)
+ def check_mutable_modified(self, state, dict_):
+ (added, unchanged, deleted) = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE)
return bool(added or deleted)
- def set(self, state, value, initiator):
+ def get(self, state, dict_, passive=PASSIVE_OFF):
+ if self.key not in state.mutable_dict:
+ ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive)
+ if ret is not PASSIVE_NORESULT:
+ state.mutable_dict[self.key] = ret
+ return ret
+ else:
+ return state.mutable_dict[self.key]
+
+ def delete(self, state, dict_):
+ ScalarAttributeImpl.delete(self, state, dict_)
+ state.mutable_dict.pop(self.key)
+
+ def set(self, state, dict_, value, initiator):
if initiator is self:
return
- state.modified_event(self, True, NEVER_SET)
-
+ state.modified_event(dict_, self, True, NEVER_SET)
+
if self.extensions:
- old = self.get(state)
- value = self.fire_replace_event(state, value, old, initiator)
- state.dict[self.key] = value
+ old = self.get(state, dict_)
+ value = self.fire_replace_event(state, dict_, value, old, initiator)
+ dict_[self.key] = value
else:
- state.dict[self.key] = value
+ dict_[self.key] = value
+ state.mutable_dict[self.key] = value
class ScalarObjectAttributeImpl(ScalarAttributeImpl):
@@ -526,22 +544,22 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
if compare_function is None:
self.is_equal = identity_equal
- def delete(self, state):
- old = self.get(state)
- self.fire_remove_event(state, old, self)
- del state.dict[self.key]
+ def delete(self, state, dict_):
+ old = self.get(state, dict_)
+ self.fire_remove_event(state, dict_, old, self)
+ del dict_[self.key]
- def get_history(self, state, passive=PASSIVE_OFF):
- if self.key in state.dict:
- return History.from_attribute(self, state, state.dict[self.key])
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ if self.key in dict_:
+ return History.from_attribute(self, state, dict_[self.key])
else:
- current = self.get(state, passive=passive)
+ current = self.get(state, dict_, passive=passive)
if current is PASSIVE_NORESULT:
return HISTORY_BLANK
else:
return History.from_attribute(self, state, current)
- def set(self, state, value, initiator):
+ def set(self, state, dict_, value, initiator):
"""Set a value on the given InstanceState.
`initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -553,12 +571,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
return
# may want to add options to allow the get() here to be passive
- old = self.get(state)
- value = self.fire_replace_event(state, value, old, initiator)
- state.dict[self.key] = value
+ old = self.get(state, dict_)
+ value = self.fire_replace_event(state, dict_, value, old, initiator)
+ dict_[self.key] = value
- def fire_remove_event(self, state, value, initiator):
- state.modified_event(self, False, value)
+ def fire_remove_event(self, state, dict_, value, initiator):
+ state.modified_event(dict_, self, False, value)
if self.trackparent and value is not None:
self.sethasparent(instance_state(value), False)
@@ -566,8 +584,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
for ext in self.extensions:
ext.remove(state, value, initiator or self)
- def fire_replace_event(self, state, value, previous, initiator):
- state.modified_event(self, False, previous)
+ def fire_replace_event(self, state, dict_, value, previous, initiator):
+ state.modified_event(dict_, self, False, previous)
if self.trackparent:
if previous is not value and previous is not None:
@@ -615,15 +633,15 @@ class CollectionAttributeImpl(AttributeImpl):
def __copy(self, item):
return [y for y in list(collections.collection_adapter(item))]
- def get_history(self, state, passive=PASSIVE_OFF):
- current = self.get(state, passive=passive)
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ current = self.get(state, dict_, passive=passive)
if current is PASSIVE_NORESULT:
return HISTORY_BLANK
else:
return History.from_attribute(self, state, current)
- def fire_append_event(self, state, value, initiator):
- state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+ def fire_append_event(self, state, dict_, value, initiator):
+ state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
for ext in self.extensions:
value = ext.append(state, value, initiator or self)
@@ -633,11 +651,11 @@ class CollectionAttributeImpl(AttributeImpl):
return value
- def fire_pre_remove_event(self, state, initiator):
- state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+ def fire_pre_remove_event(self, state, dict_, initiator):
+ state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
- def fire_remove_event(self, state, value, initiator):
- state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+ def fire_remove_event(self, state, dict_, value, initiator):
+ state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
if self.trackparent and value is not None:
self.sethasparent(instance_state(value), False)
@@ -645,51 +663,51 @@ class CollectionAttributeImpl(AttributeImpl):
for ext in self.extensions:
ext.remove(state, value, initiator or self)
- def delete(self, state):
- if self.key not in state.dict:
+ def delete(self, state, dict_):
+ if self.key not in dict_:
return
- state.modified_event(self, True, NEVER_SET)
+ state.modified_event(dict_, self, True, NEVER_SET)
- collection = self.get_collection(state)
+ collection = self.get_collection(state, state.dict)
collection.clear_with_event()
# TODO: catch key errors, convert to attributeerror?
- del state.dict[self.key]
+ del dict_[self.key]
- def initialize(self, state):
+ def initialize(self, state, dict_):
"""Initialize this attribute with an empty collection."""
_, user_data = self._initialize_collection(state)
- state.dict[self.key] = user_data
+ dict_[self.key] = user_data
return user_data
def _initialize_collection(self, state):
return state.manager.initialize_collection(
self.key, state, self.collection_factory)
- def append(self, state, value, initiator, passive=PASSIVE_OFF):
+ def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
if initiator is self:
return
- collection = self.get_collection(state, passive=passive)
+ collection = self.get_collection(state, dict_, passive=passive)
if collection is PASSIVE_NORESULT:
- value = self.fire_append_event(state, value, initiator)
+ value = self.fire_append_event(state, dict_, value, initiator)
state.get_pending(self.key).append(value)
else:
collection.append_with_event(value, initiator)
- def remove(self, state, value, initiator, passive=PASSIVE_OFF):
+ def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
if initiator is self:
return
- collection = self.get_collection(state, passive=passive)
+ collection = self.get_collection(state, state.dict, passive=passive)
if collection is PASSIVE_NORESULT:
- self.fire_remove_event(state, value, initiator)
+ self.fire_remove_event(state, dict_, value, initiator)
state.get_pending(self.key).remove(value)
else:
collection.remove_with_event(value, initiator)
- def set(self, state, value, initiator):
+ def set(self, state, dict_, value, initiator):
"""Set a value on the given object.
`initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -701,10 +719,10 @@ class CollectionAttributeImpl(AttributeImpl):
return
self._set_iterable(
- state, value,
+ state, dict_, value,
lambda adapter, i: adapter.adapt_like_to_iterable(i))
- def _set_iterable(self, state, iterable, adapter=None):
+ def _set_iterable(self, state, dict_, iterable, adapter=None):
"""Set a collection value from an iterable of state-bearers.
``adapter`` is an optional callable invoked with a CollectionAdapter
@@ -722,24 +740,24 @@ class CollectionAttributeImpl(AttributeImpl):
else:
new_values = list(iterable)
- old = self.get(state)
+ old = self.get(state, dict_)
# ignore re-assignment of the current collection, as happens
# implicitly with in-place operators (foo.collection |= other)
if old is iterable:
return
- state.modified_event(self, True, old)
+ state.modified_event(dict_, self, True, old)
- old_collection = self.get_collection(state, old)
+ old_collection = self.get_collection(state, dict_, old)
- state.dict[self.key] = user_data
+ dict_[self.key] = user_data
collections.bulk_replace(new_values, old_collection, new_collection)
old_collection.unlink(old)
- def set_committed_value(self, state, value):
+ def set_committed_value(self, state, dict_, value):
"""Set an attribute value on the given instance and 'commit' it."""
collection, user_data = self._initialize_collection(state)
@@ -751,13 +769,13 @@ class CollectionAttributeImpl(AttributeImpl):
state.callables.pop(self.key, None)
state.dict[self.key] = user_data
- state.commit([self.key])
+ state.commit(dict_, [self.key])
if self.key in state.pending:
# pending items exist. issue a modified event,
# add/remove new items.
- state.modified_event(self, True, user_data)
+ state.modified_event(dict_, self, True, user_data)
pending = state.pending.pop(self.key)
added = pending.added_items
@@ -769,14 +787,14 @@ class CollectionAttributeImpl(AttributeImpl):
return user_data
- def get_collection(self, state, user_data=None, passive=PASSIVE_OFF):
+ def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF):
"""Retrieve the CollectionAdapter associated with the given state.
Creates a new CollectionAdapter if one does not exist.
"""
if user_data is None:
- user_data = self.get(state, passive=passive)
+ user_data = self.get(state, dict_, passive=passive)
if user_data is PASSIVE_NORESULT:
return user_data
@@ -799,320 +817,26 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
if oldchild is not None:
# With lazy=None, there's no guarantee that the full collection is
# present when updating via a backref.
- old_state = instance_state(oldchild)
+ old_state, old_dict = instance_state(oldchild), instance_dict(oldchild)
impl = old_state.get_impl(self.key)
try:
- impl.remove(old_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+ impl.remove(old_state, old_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
except (ValueError, KeyError, IndexError):
pass
if child is not None:
- new_state = instance_state(child)
- new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+ new_state, new_dict = instance_state(child), instance_dict(child)
+ new_state.get_impl(self.key).append(new_state, new_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
return child
def append(self, state, child, initiator):
- child_state = instance_state(child)
- child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+ child_state, child_dict = instance_state(child), instance_dict(child)
+ child_state.get_impl(self.key).append(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
return child
def remove(self, state, child, initiator):
if child is not None:
- child_state = instance_state(child)
- child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
-
-
-class InstanceState(object):
- """tracks state information at the instance level."""
-
- session_id = None
- key = None
- runid = None
- expired_attributes = EMPTY_SET
- load_options = EMPTY_SET
- load_path = ()
- insert_order = None
-
- def __init__(self, obj, manager):
- self.class_ = obj.__class__
- self.manager = manager
- self.obj = weakref.ref(obj, self._cleanup)
- self.dict = obj.__dict__
- self.modified = False
- self.callables = {}
- self.expired = False
- self.committed_state = {}
- self.pending = {}
- self.parents = {}
-
- def detach(self):
- if self.session_id:
- del self.session_id
-
- def dispose(self):
- if self.session_id:
- del self.session_id
- del self.obj
- del self.dict
-
- def _cleanup(self, ref):
- self.dispose()
-
- def obj(self):
- return None
-
- @util.memoized_property
- def dict(self):
- # return a blank dict
- # if none is available, so that asynchronous gc
- # doesn't blow up expiration operations in progress
- # (usually expire_attributes)
- return {}
-
- @property
- def sort_key(self):
- return self.key and self.key[1] or (self.insert_order, )
-
- def check_modified(self):
- if self.modified:
- return True
- else:
- for key in self.manager.mutable_attributes:
- if self.manager[key].impl.check_mutable_modified(self):
- return True
- else:
- return False
-
- def initialize_instance(*mixed, **kwargs):
- self, instance, args = mixed[0], mixed[1], mixed[2:]
- manager = self.manager
-
- for fn in manager.events.on_init:
- fn(self, instance, args, kwargs)
- try:
- return manager.events.original_init(*mixed[1:], **kwargs)
- except:
- for fn in manager.events.on_init_failure:
- fn(self, instance, args, kwargs)
- raise
-
- def get_history(self, key, **kwargs):
- return self.manager.get_impl(key).get_history(self, **kwargs)
-
- def get_impl(self, key):
- return self.manager.get_impl(key)
-
- def get_pending(self, key):
- if key not in self.pending:
- self.pending[key] = PendingCollection()
- return self.pending[key]
-
- def value_as_iterable(self, key, passive=PASSIVE_OFF):
- """return an InstanceState attribute as a list,
- regardless of it being a scalar or collection-based
- attribute.
-
- returns None if passive is not PASSIVE_OFF and the getter returns
- PASSIVE_NORESULT.
- """
-
- impl = self.get_impl(key)
- x = impl.get(self, passive=passive)
- if x is PASSIVE_NORESULT:
-
- return None
- elif hasattr(impl, 'get_collection'):
- return impl.get_collection(self, x, passive=passive)
- elif isinstance(x, list):
- return x
- else:
- return [x]
-
- def _run_on_load(self, instance=None):
- if instance is None:
- instance = self.obj()
- self.manager.events.run('on_load', instance)
-
- def __getstate__(self):
- return {'key': self.key,
- 'committed_state': self.committed_state,
- 'pending': self.pending,
- 'parents': self.parents,
- 'modified': self.modified,
- 'expired':self.expired,
- 'load_options':self.load_options,
- 'load_path':interfaces.serialize_path(self.load_path),
- 'instance': self.obj(),
- 'expired_attributes':self.expired_attributes,
- 'callables': self.callables}
-
- def __setstate__(self, state):
- self.committed_state = state['committed_state']
- self.parents = state['parents']
- self.key = state['key']
- self.session_id = None
- self.pending = state['pending']
- self.modified = state['modified']
- self.obj = weakref.ref(state['instance'])
- self.load_options = state['load_options'] or EMPTY_SET
- self.load_path = interfaces.deserialize_path(state['load_path'])
- self.class_ = self.obj().__class__
- self.manager = manager_of_class(self.class_)
- self.dict = self.obj().__dict__
- self.callables = state['callables']
- self.runid = None
- self.expired = state['expired']
- self.expired_attributes = state['expired_attributes']
-
- def initialize(self, key):
- self.manager.get_impl(key).initialize(self)
-
- def set_callable(self, key, callable_):
- self.dict.pop(key, None)
- self.callables[key] = callable_
-
- def __call__(self):
- """__call__ allows the InstanceState to act as a deferred
- callable for loading expired attributes, which is also
- serializable (picklable).
-
- """
- unmodified = self.unmodified
- class_manager = self.manager
- class_manager.deferred_scalar_loader(self, [
- attr.impl.key for attr in class_manager.attributes if
- attr.impl.accepts_scalar_loader and
- attr.impl.key in self.expired_attributes and
- attr.impl.key in unmodified
- ])
- for k in self.expired_attributes:
- self.callables.pop(k, None)
- del self.expired_attributes
- return ATTR_WAS_SET
-
- @property
- def unmodified(self):
- """a set of keys which have no uncommitted changes"""
-
- return set(
- key for key in self.manager.iterkeys()
- if (key not in self.committed_state or
- (key in self.manager.mutable_attributes and
- not self.manager[key].impl.check_mutable_modified(self))))
-
- @property
- def unloaded(self):
- """a set of keys which do not have a loaded value.
-
- This includes expired attributes and any other attribute that
- was never populated or modified.
-
- """
- return set(
- key for key in self.manager.iterkeys()
- if key not in self.committed_state and key not in self.dict)
-
- def expire_attributes(self, attribute_names):
- self.expired_attributes = set(self.expired_attributes)
-
- if attribute_names is None:
- attribute_names = self.manager.keys()
- self.expired = True
- self.modified = False
- filter_deferred = True
- else:
- filter_deferred = False
- for key in attribute_names:
- impl = self.manager[key].impl
- if not filter_deferred or \
- not impl.dont_expire_missing or \
- key in self.dict:
- self.expired_attributes.add(key)
- if impl.accepts_scalar_loader:
- self.callables[key] = self
- self.dict.pop(key, None)
- self.pending.pop(key, None)
- self.committed_state.pop(key, None)
-
- def reset(self, key):
- """remove the given attribute and any callables associated with it."""
-
- self.dict.pop(key, None)
- self.callables.pop(key, None)
-
- def modified_event(self, attr, should_copy, previous, passive=PASSIVE_OFF):
- needs_committed = attr.key not in self.committed_state
-
- if needs_committed:
- if previous is NEVER_SET:
- if passive:
- if attr.key in self.dict:
- previous = self.dict[attr.key]
- else:
- previous = attr.get(self)
-
- if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
- previous = attr.copy(previous)
-
- if needs_committed:
- self.committed_state[attr.key] = previous
-
- self.modified = True
-
- def commit(self, keys):
- """Commit attributes.
-
- This is used by a partial-attribute load operation to mark committed
- those attributes which were refreshed from the database.
-
- Attributes marked as "expired" can potentially remain "expired" after
- this step if a value was not populated in state.dict.
-
- """
- class_manager = self.manager
- for key in keys:
- if key in self.dict and key in class_manager.mutable_attributes:
- class_manager[key].impl.commit_to_state(self, self.committed_state)
- else:
- self.committed_state.pop(key, None)
-
- self.expired = False
- # unexpire attributes which have loaded
- for key in self.expired_attributes.intersection(keys):
- if key in self.dict:
- self.expired_attributes.remove(key)
- self.callables.pop(key, None)
-
- def commit_all(self):
- """commit all attributes unconditionally.
-
- This is used after a flush() or a full load/refresh
- to remove all pending state from the instance.
-
- - all attributes are marked as "committed"
- - the "strong dirty reference" is removed
- - the "modified" flag is set to False
- - any "expired" markers/callables are removed.
-
- Attributes marked as "expired" can potentially remain "expired" after this step
- if a value was not populated in state.dict.
-
- """
-
- self.committed_state = {}
- self.pending = {}
-
- # unexpire attributes which have loaded
- if self.expired_attributes:
- for key in self.expired_attributes.intersection(self.dict):
- self.callables.pop(key, None)
- self.expired_attributes.difference_update(self.dict)
-
- for key in self.manager.mutable_attributes:
- if key in self.dict:
- self.manager[key].impl.commit_to_state(self, self.committed_state)
-
- self.modified = self.expired = False
- self._strong_obj = None
+ child_state, child_dict = instance_state(child), instance_dict(child)
+ child_state.get_impl(self.key).remove(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
class Events(object):
@@ -1121,6 +845,7 @@ class Events(object):
self.on_init = ()
self.on_init_failure = ()
self.on_load = ()
+ self.on_resurrect = ()
def run(self, event, *args, **kwargs):
for fn in getattr(self, event):
@@ -1146,7 +871,6 @@ class ClassManager(dict):
STATE_ATTR = '_sa_instance_state'
event_registry_factory = Events
- instance_state_factory = InstanceState
deferred_scalar_loader = None
def __init__(self, class_):
@@ -1170,7 +894,6 @@ class ClassManager(dict):
def _configure_create_arguments(self,
_source=None,
- instance_state_factory=None,
deferred_scalar_loader=None):
"""Accept extra **kw arguments passed to create_manager_for_cls.
@@ -1185,11 +908,8 @@ class ClassManager(dict):
"""
if _source:
- instance_state_factory = _source.instance_state_factory
deferred_scalar_loader = _source.deferred_scalar_loader
- if instance_state_factory:
- self.instance_state_factory = instance_state_factory
if deferred_scalar_loader:
self.deferred_scalar_loader = deferred_scalar_loader
@@ -1222,7 +942,16 @@ class ClassManager(dict):
if self.new_init:
self.uninstall_member('__init__')
self.new_init = None
-
+
+ def _create_instance_state(self, instance):
+ global state
+ if state is None:
+ from sqlalchemy.orm import state
+ if self.mutable_attributes:
+ return state.MutableAttrInstanceState(instance, self)
+ else:
+ return state.InstanceState(instance, self)
+
def manage(self):
"""Mark this instance as the manager for its class."""
@@ -1330,11 +1059,11 @@ class ClassManager(dict):
def new_instance(self, state=None):
instance = self.class_.__new__(self.class_)
- setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self))
+ setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
return instance
def setup_instance(self, instance, state=None):
- setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self))
+ setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
def teardown_instance(self, instance):
delattr(instance, self.STATE_ATTR)
@@ -1348,13 +1077,10 @@ class ClassManager(dict):
if hasattr(instance, self.STATE_ATTR):
return False
else:
- state = self.instance_state_factory(instance, self)
+ state = self._create_instance_state(instance)
setattr(instance, self.STATE_ATTR, state)
return state
- def state_of(self, instance):
- return getattr(instance, self.STATE_ATTR)
-
def state_getter(self):
"""Return a (instance) -> InstanceState callable.
@@ -1365,6 +1091,9 @@ class ClassManager(dict):
return attrgetter(self.STATE_ATTR)
+ def dict_getter(self):
+ return attrgetter('__dict__')
+
def has_state(self, instance):
return hasattr(instance, self.STATE_ATTR)
@@ -1385,6 +1114,9 @@ class _ClassInstrumentationAdapter(ClassManager):
def __init__(self, class_, override, **kw):
self._adapted = override
+ self._get_state = self._adapted.state_getter(class_)
+ self._get_dict = self._adapted.dict_getter(class_)
+
ClassManager.__init__(self, class_, **kw)
def manage(self):
@@ -1446,36 +1178,27 @@ class _ClassInstrumentationAdapter(ClassManager):
self._adapted.initialize_instance_dict(self.class_, instance)
if state is None:
- state = self.instance_state_factory(instance, self)
+ state = self._create_instance_state(instance)
# the given instance is assumed to have no state
self._adapted.install_state(self.class_, instance, state)
- state.dict = self._adapted.get_instance_dict(self.class_, instance)
return state
def teardown_instance(self, instance):
self._adapted.remove_state(self.class_, instance)
- def state_of(self, instance):
- if hasattr(self._adapted, 'state_of'):
- return self._adapted.state_of(self.class_, instance)
- else:
- getter = self._adapted.state_getter(self.class_)
- return getter(instance)
-
def has_state(self, instance):
- if hasattr(self._adapted, 'has_state'):
- return self._adapted.has_state(self.class_, instance)
- else:
- try:
- state = self.state_of(instance)
- return True
- except exc.NO_STATE:
- return False
+ try:
+ state = self._get_state(instance)
+ return True
+ except exc.NO_STATE:
+ return False
def state_getter(self):
- return self._adapted.state_getter(self.class_)
+ return self._get_state
+ def dict_getter(self):
+ return self._get_dict
class History(tuple):
"""A 3-tuple of added, unchanged and deleted values.
@@ -1520,7 +1243,7 @@ class History(tuple):
original = state.committed_state.get(attribute.key, NEVER_SET)
if hasattr(attribute, 'get_collection'):
- current = attribute.get_collection(state, current)
+ current = attribute.get_collection(state, state.dict, current)
if original is NO_VALUE:
return cls(list(current), (), ())
elif original is NEVER_SET:
@@ -1557,30 +1280,8 @@ class History(tuple):
HISTORY_BLANK = History(None, None, None)
-class PendingCollection(object):
- """A writable placeholder for an unloaded collection.
-
- Stores items appended to and removed from a collection that has not yet
- been loaded. When the collection is loaded, the changes stored in
- PendingCollection are applied to it to produce the final result.
-
- """
- def __init__(self):
- self.deleted_items = util.IdentitySet()
- self.added_items = util.OrderedIdentitySet()
-
- def append(self, value):
- if value in self.deleted_items:
- self.deleted_items.remove(value)
- self.added_items.add(value)
-
- def remove(self, value):
- if value in self.added_items:
- self.added_items.remove(value)
- self.deleted_items.add(value)
-
def _conditional_instance_state(obj):
- if not isinstance(obj, InstanceState):
+ if not isinstance(obj, state.InstanceState):
obj = instance_state(obj)
return obj
@@ -1690,15 +1391,16 @@ def init_collection(obj, key):
this usage is deprecated.
"""
-
- return init_state_collection(_conditional_instance_state(obj), key)
+ state = _conditional_instance_state(obj)
+ dict_ = state.dict
+ return init_state_collection(state, dict_, key)
-def init_state_collection(state, key):
+def init_state_collection(state, dict_, key):
"""Initialize a collection attribute and return the collection adapter."""
attr = state.get_impl(key)
- user_data = attr.initialize(state)
- return attr.get_collection(state, user_data)
+ user_data = attr.initialize(state, dict_)
+ return attr.get_collection(state, dict_, user_data)
def set_committed_value(instance, key, value):
"""Set the value of an attribute with no history events.
@@ -1715,8 +1417,8 @@ def set_committed_value(instance, key, value):
as though it were part of its original loaded state.
"""
- state = instance_state(instance)
- state.get_impl(key).set_committed_value(instance, key, value)
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.get_impl(key).set_committed_value(state, dict_, key, value)
def set_attribute(instance, key, value):
"""Set the value of an attribute, firing history events.
@@ -1728,8 +1430,8 @@ def set_attribute(instance, key, value):
by SQLAlchemy.
"""
- state = instance_state(instance)
- state.get_impl(key).set(state, value, None)
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.get_impl(key).set(state, dict_, value, None)
def get_attribute(instance, key):
"""Get the value of an attribute, firing any callables required.
@@ -1741,8 +1443,8 @@ def get_attribute(instance, key):
by SQLAlchemy.
"""
- state = instance_state(instance)
- return state.get_impl(key).get(state)
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ return state.get_impl(key).get(state, dict_)
def del_attribute(instance, key):
"""Delete the value of an attribute, firing history events.
@@ -1754,8 +1456,8 @@ def del_attribute(instance, key):
by SQLAlchemy.
"""
- state = instance_state(instance)
- state.get_impl(key).delete(state)
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.get_impl(key).delete(state, dict_)
def is_instrumented(instance, key):
"""Return True if the given attribute on the given instance is instrumented
@@ -1772,6 +1474,7 @@ class InstrumentationRegistry(object):
_manager_finders = weakref.WeakKeyDictionary()
_state_finders = util.WeakIdentityMapping()
+ _dict_finders = util.WeakIdentityMapping()
_extended = False
def create_manager_for_cls(self, class_, **kw):
@@ -1806,6 +1509,7 @@ class InstrumentationRegistry(object):
manager.factory = factory
self._manager_finders[class_] = manager.manager_getter()
self._state_finders[class_] = manager.state_getter()
+ self._dict_finders[class_] = manager.dict_getter()
return manager
def _collect_management_factories_for(self, cls):
@@ -1845,6 +1549,7 @@ class InstrumentationRegistry(object):
return finder(cls)
def state_of(self, instance):
+ # this is only called when alternate instrumentation has been established
if instance is None:
raise AttributeError("None has no persistent state.")
try:
@@ -1852,21 +1557,15 @@ class InstrumentationRegistry(object):
except KeyError:
raise AttributeError("%r is not instrumented" % instance.__class__)
- def state_or_default(self, instance, default=None):
+ def dict_of(self, instance):
+ # this is only called when alternate instrumentation has been established
if instance is None:
- return default
+ raise AttributeError("None has no persistent state.")
try:
- finder = self._state_finders[instance.__class__]
+ return self._dict_finders[instance.__class__](instance)
except KeyError:
- return default
- else:
- try:
- return finder(instance)
- except exc.NO_STATE:
- return default
- except:
- raise
-
+ raise AttributeError("%r is not instrumented" % instance.__class__)
+
def unregister(self, class_):
if class_ in self._manager_finders:
manager = self.manager_of_class(class_)
@@ -1874,6 +1573,7 @@ class InstrumentationRegistry(object):
manager.dispose()
del self._manager_finders[class_]
del self._state_finders[class_]
+ del self._dict_finders[class_]
instrumentation_registry = InstrumentationRegistry()
@@ -1887,12 +1587,14 @@ def _install_lookup_strategy(implementation):
and unit tests specific to this behavior.
"""
- global instance_state
+ global instance_state, instance_dict
if implementation is util.symbol('native'):
instance_state = attrgetter(ClassManager.STATE_ATTR)
+ instance_dict = attrgetter("__dict__")
else:
instance_state = instrumentation_registry.state_of
-
+ instance_dict = instrumentation_registry.dict_of
+
manager_of_class = instrumentation_registry.manager_of_class
_create_manager_for_cls = instrumentation_registry.create_manager_for_cls
_install_lookup_strategy(util.symbol('native'))
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index 5638a7e4a..4ca4c5719 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -472,6 +472,7 @@ class CollectionAdapter(object):
"""
def __init__(self, attr, owner_state, data):
self.attr = attr
+ # TODO: figure out what this being a weakref buys us
self._data = weakref.ref(data)
self.owner_state = owner_state
self.link_to_self(data)
@@ -578,7 +579,7 @@ class CollectionAdapter(object):
"""
if initiator is not False and item is not None:
- return self.attr.fire_append_event(self.owner_state, item, initiator)
+ return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator)
else:
return item
@@ -591,7 +592,7 @@ class CollectionAdapter(object):
"""
if initiator is not False and item is not None:
- self.attr.fire_remove_event(self.owner_state, item, initiator)
+ self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator)
def fire_pre_remove_event(self, initiator=None):
"""Notify that an entity is about to be removed from the collection.
@@ -600,7 +601,7 @@ class CollectionAdapter(object):
fire_remove_event().
"""
- self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator)
+ self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator)
def __getstate__(self):
return {'key': self.attr.key,
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index a80727b7f..151c557d7 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -64,17 +64,21 @@ class DependencyProcessor(object):
def register_dependencies(self, uowcommit):
"""Tell a ``UOWTransaction`` what mappers are dependent on
which, with regards to the two or three mappers handled by
- this ``PropertyLoader``.
+ this ``DependencyProcessor``.
- Also register itself as a *processor* for one of its mappers,
- which will be executed after that mapper's objects have been
- saved or before they've been deleted. The process operation
- manages attributes and dependent operations upon the objects
- of one of the involved mappers.
"""
raise NotImplementedError()
+ def register_processors(self, uowcommit):
+ """Tell a ``UOWTransaction`` about this object as a processor,
+ which will be executed after that mapper's objects have been
+ saved or before they've been deleted. The process operation
+ manages attributes and dependent operations between two mappers.
+
+ """
+ raise NotImplementedError()
+
def whose_dependent_on_who(self, state1, state2):
"""Given an object pair assuming `obj2` is a child of `obj1`,
return a tuple with the dependent object second, or None if
@@ -181,9 +185,13 @@ class OneToManyDP(DependencyProcessor):
if self.post_update:
uowcommit.register_dependency(self.mapper, self.dependency_marker)
uowcommit.register_dependency(self.parent, self.dependency_marker)
- uowcommit.register_processor(self.dependency_marker, self, self.parent)
else:
uowcommit.register_dependency(self.parent, self.mapper)
+
+ def register_processors(self, uowcommit):
+ if self.post_update:
+ uowcommit.register_processor(self.dependency_marker, self, self.parent)
+ else:
uowcommit.register_processor(self.parent, self, self.parent)
def process_dependencies(self, task, deplist, uowcommit, delete = False):
@@ -285,6 +293,9 @@ class DetectKeySwitch(DependencyProcessor):
no_dependencies = True
def register_dependencies(self, uowcommit):
+ pass
+
+ def register_processors(self, uowcommit):
uowcommit.register_processor(self.parent, self, self.mapper)
def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
@@ -330,12 +341,15 @@ class ManyToOneDP(DependencyProcessor):
if self.post_update:
uowcommit.register_dependency(self.mapper, self.dependency_marker)
uowcommit.register_dependency(self.parent, self.dependency_marker)
- uowcommit.register_processor(self.dependency_marker, self, self.parent)
else:
uowcommit.register_dependency(self.mapper, self.parent)
+
+ def register_processors(self, uowcommit):
+ if self.post_update:
+ uowcommit.register_processor(self.dependency_marker, self, self.parent)
+ else:
uowcommit.register_processor(self.mapper, self, self.parent)
-
def process_dependencies(self, task, deplist, uowcommit, delete=False):
if delete:
if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all':
@@ -408,8 +422,10 @@ class ManyToManyDP(DependencyProcessor):
uowcommit.register_dependency(self.parent, self.dependency_marker)
uowcommit.register_dependency(self.mapper, self.dependency_marker)
- uowcommit.register_processor(self.dependency_marker, self, self.parent)
+ def register_processors(self, uowcommit):
+ uowcommit.register_processor(self.dependency_marker, self, self.parent)
+
def process_dependencies(self, task, deplist, uowcommit, delete = False):
connection = uowcommit.transaction.connection(self.mapper)
secondary_delete = []
@@ -527,6 +543,9 @@ class MapperStub(object):
def _register_dependencies(self, uowcommit):
pass
+ def _register_procesors(self, uowcommit):
+ pass
+
def _save_obj(self, *args, **kwargs):
pass
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
index 3d31a686a..70243291d 100644
--- a/lib/sqlalchemy/orm/dynamic.py
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -55,21 +55,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
else:
self.query_class = mixin_user_query(query_class)
- def get(self, state, passive=False):
+ def get(self, state, dict_, passive=False):
if passive:
return self._get_collection_history(state, passive=True).added_items
else:
return self.query_class(self, state)
- def get_collection(self, state, user_data=None, passive=True):
+ def get_collection(self, state, dict_, user_data=None, passive=True):
if passive:
return self._get_collection_history(state, passive=passive).added_items
else:
history = self._get_collection_history(state, passive=passive)
return history.added_items + history.unchanged_items
- def fire_append_event(self, state, value, initiator):
- collection_history = self._modified_event(state)
+ def fire_append_event(self, state, dict_, value, initiator):
+ collection_history = self._modified_event(state, dict_)
collection_history.added_items.append(value)
for ext in self.extensions:
@@ -78,8 +78,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), True)
- def fire_remove_event(self, state, value, initiator):
- collection_history = self._modified_event(state)
+ def fire_remove_event(self, state, dict_, value, initiator):
+ collection_history = self._modified_event(state, dict_)
collection_history.deleted_items.append(value)
if self.trackparent and value is not None:
@@ -88,31 +88,31 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
for ext in self.extensions:
ext.remove(state, value, initiator or self)
- def _modified_event(self, state):
+ def _modified_event(self, state, dict_):
if self.key not in state.committed_state:
state.committed_state[self.key] = CollectionHistory(self, state)
- state.modified_event(self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE)
+ state.modified_event(dict_, self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE)
# this is a hack to allow the _base.ComparableEntity fixture
# to work
- state.dict[self.key] = True
+ dict_[self.key] = True
return state.committed_state[self.key]
- def set(self, state, value, initiator):
+ def set(self, state, dict_, value, initiator):
if initiator is self:
return
- self._set_iterable(state, value)
+ self._set_iterable(state, dict_, value)
- def _set_iterable(self, state, iterable, adapter=None):
+ def _set_iterable(self, state, dict_, iterable, adapter=None):
- collection_history = self._modified_event(state)
+ collection_history = self._modified_event(state, dict_)
new_values = list(iterable)
if _state_has_identity(state):
- old_collection = list(self.get(state))
+ old_collection = list(self.get(state, dict_))
else:
old_collection = []
@@ -121,7 +121,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
def delete(self, *args, **kwargs):
raise NotImplementedError()
- def get_history(self, state, passive=False):
+ def get_history(self, state, dict_, passive=False):
c = self._get_collection_history(state, passive)
return attributes.History(c.added_items, c.unchanged_items, c.deleted_items)
@@ -136,13 +136,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
else:
return c
- def append(self, state, value, initiator, passive=False):
+ def append(self, state, dict_, value, initiator, passive=False):
if initiator is not self:
- self.fire_append_event(state, value, initiator)
+ self.fire_append_event(state, dict_, value, initiator)
- def remove(self, state, value, initiator, passive=False):
+ def remove(self, state, dict_, value, initiator, passive=False):
if initiator is not self:
- self.fire_remove_event(state, value, initiator)
+ self.fire_remove_event(state, dict_, value, initiator)
class DynCollectionAdapter(object):
"""the dynamic analogue to orm.collections.CollectionAdapter"""
@@ -156,10 +156,10 @@ class DynCollectionAdapter(object):
return iter(self.data)
def append_with_event(self, item, initiator=None):
- self.attr.append(self.state, item, initiator)
+ self.attr.append(self.state, self.state.dict, item, initiator)
def remove_with_event(self, item, initiator=None):
- self.attr.remove(self.state, item, initiator)
+ self.attr.remove(self.state, self.state.dict, item, initiator)
def append_without_event(self, item):
pass
@@ -240,10 +240,10 @@ class AppenderMixin(object):
return query
def append(self, item):
- self.attr.append(attributes.instance_state(self.instance), item, None)
+ self.attr.append(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None)
def remove(self, item):
- self.attr.remove(attributes.instance_state(self.instance), item, None)
+ self.attr.remove(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None)
class AppenderQuery(AppenderMixin, Query):
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
index 0753ea991..aa041a585 100644
--- a/lib/sqlalchemy/orm/identity.py
+++ b/lib/sqlalchemy/orm/identity.py
@@ -15,6 +15,9 @@ class IdentityMap(dict):
self._mutable_attrs = {}
self.modified = False
self._wr = weakref.ref(self)
+
+ def replace(self, state):
+ raise NotImplementedError()
def add(self, state):
raise NotImplementedError()
@@ -102,6 +105,17 @@ class WeakInstanceDict(IdentityMap):
def contains_state(self, state):
return dict.get(self, state.key) is state
+ def replace(self, state):
+ if dict.__contains__(self, state.key):
+ existing = dict.__getitem__(self, state.key)
+ if existing is not state:
+ self._manage_removed_state(existing)
+ else:
+ return
+
+ dict.__setitem__(self, state.key, state)
+ self._manage_incoming_state(state)
+
def add(self, state):
if state.key in self:
if dict.__getitem__(self, state.key) is not state:
@@ -161,12 +175,24 @@ class StrongInstanceDict(IdentityMap):
def contains_state(self, state):
return state.key in self and attributes.instance_state(self[state.key]) is state
+ def replace(self, state):
+ if dict.__contains__(self, state.key):
+ existing = dict.__getitem__(self, state.key)
+ existing = attributes.instance_state(existing)
+ if existing is not state:
+ self._manage_removed_state(existing)
+ else:
+ return
+
+ dict.__setitem__(self, state.key, state.obj())
+ self._manage_incoming_state(state)
+
def add(self, state):
dict.__setitem__(self, state.key, state.obj())
self._manage_incoming_state(state)
def remove(self, state):
- if dict.pop(self, state.key) is not state:
+ if attributes.instance_state(dict.pop(self, state.key)) is not state:
raise AssertionError("State %s is not present in this identity map" % state)
self._manage_removed_state(state)
@@ -176,7 +202,7 @@ class StrongInstanceDict(IdentityMap):
self._manage_removed_state(state)
def remove_key(self, key):
- state = dict.__getitem__(self, key)
+ state = attributes.instance_state(dict.__getitem__(self, key))
self.remove(state)
def prune(self):
@@ -190,62 +216,3 @@ class StrongInstanceDict(IdentityMap):
self.modified = bool(dirty)
return ref_count - len(self)
-class IdentityManagedState(attributes.InstanceState):
- def _instance_dict(self):
- return None
-
- def modified_event(self, attr, should_copy, previous, passive=False):
- attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive)
-
- instance_dict = self._instance_dict()
- if instance_dict:
- instance_dict.modified = True
-
- def _is_really_none(self):
- """do a check modified/resurrect.
-
- This would be called in the extremely rare
- race condition that the weakref returned None but
- the cleanup handler had not yet established the
- __resurrect callable as its replacement.
-
- """
- if self.check_modified():
- self.obj = self.__resurrect
- return self.obj()
- else:
- return None
-
- def _cleanup(self, ref):
- """weakref callback.
-
- This method may be called by an asynchronous
- gc.
-
- If the state shows pending changes, the weakref
- is replaced by the __resurrect callable which will
- re-establish an object reference on next access,
- else removes this InstanceState from the owning
- identity map, if any.
-
- """
- if self.check_modified():
- self.obj = self.__resurrect
- else:
- instance_dict = self._instance_dict()
- if instance_dict:
- instance_dict.remove(self)
- self.dispose()
-
- def __resurrect(self):
- """A substitute for the obj() weakref function which resurrects."""
-
- # store strong ref'ed version of the object; will revert
- # to weakref when changes are persisted
- obj = self.manager.new_instance(state=self)
- self.obj = weakref.ref(obj, self._cleanup)
- self._strong_obj = obj
- obj.__dict__.update(self.dict)
- self.dict = obj.__dict__
- self._run_on_load(obj)
- return obj
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index d36f51194..0ac771305 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -359,7 +359,7 @@ class MapperProperty(object):
Callables are of the following form::
- def new_execute(state, row, **flags):
+ def new_execute(state, dict_, row, **flags):
# process incoming instance state and given row. the instance is
# "new" and was just created upon receipt of this row.
# flags is a dictionary containing at least the following
@@ -368,7 +368,7 @@ class MapperProperty(object):
# result of reading this row
# instancekey - identity key of the instance
- def existing_execute(state, row, **flags):
+ def existing_execute(state, dict_, row, **flags):
# process incoming instance state and given row. the instance is
# "existing" and was created based on a previous row.
@@ -427,13 +427,23 @@ class MapperProperty(object):
def register_dependencies(self, *args, **kwargs):
"""Called by the ``Mapper`` in response to the UnitOfWork
calling the ``Mapper``'s register_dependencies operation.
- Should register with the UnitOfWork all inter-mapper
- dependencies as well as dependency processors (see UOW docs
- for more details).
+ Establishes a topological dependency between two mappers
+ which will affect the order in which mappers persist data.
+
"""
pass
+ def register_processors(self, *args, **kwargs):
+ """Called by the ``Mapper`` in response to the UnitOfWork
+ calling the ``Mapper``'s register_processors operation.
+ Establishes a processor object between two mappers which
+ will link data and state between parent/child objects.
+
+ """
+
+ pass
+
def is_primary(self):
"""Return True if this ``MapperProperty``'s mapper is the
primary mapper for its class.
@@ -939,3 +949,7 @@ class InstrumentationManager(object):
def state_getter(self, class_):
return lambda instance: getattr(instance, '_default_state')
+
+ def dict_getter(self, class_):
+ return lambda inst: self.get_instance_dict(class_, inst)
+ \ No newline at end of file
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 8af6153d6..87c4c8100 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -23,7 +23,6 @@ deque = __import__('collections').deque
from sqlalchemy import sql, util, log, exc as sa_exc
from sqlalchemy.sql import expression, visitors, operators, util as sqlutil
from sqlalchemy.orm import attributes, exc, sync
-from sqlalchemy.orm.identity import IdentityManagedState
from sqlalchemy.orm.interfaces import (
MapperProperty, EXT_CONTINUE, PropComparator
)
@@ -255,7 +254,8 @@ class Mapper(object):
for mapper in self.iterate_to_root():
util.reset_memoized(mapper, '_equivalent_columns')
-
+ util.reset_memoized(mapper, '_sorted_tables')
+
if self.order_by is False and not self.concrete and self.inherits.order_by is not False:
self.order_by = self.inherits.order_by
@@ -357,7 +357,6 @@ class Mapper(object):
if manager is None:
manager = attributes.register_class(self.class_,
- instance_state_factory = IdentityManagedState,
deferred_scalar_loader = _load_scalar_attributes
)
@@ -372,6 +371,8 @@ class Mapper(object):
event_registry = manager.events
event_registry.add_listener('on_init', _event_on_init)
event_registry.add_listener('on_init_failure', _event_on_init_failure)
+ event_registry.add_listener('on_resurrect', _event_on_resurrect)
+
for key, method in util.iterate_attributes(self.class_):
if isinstance(method, types.FunctionType):
if hasattr(method, '__sa_reconstructor__'):
@@ -1173,6 +1174,19 @@ class Mapper(object):
# persistence
+ @util.memoized_property
+ def _sorted_tables(self):
+ table_to_mapper = {}
+ for mapper in self.base_mapper.polymorphic_iterator():
+ for t in mapper.tables:
+ table_to_mapper[t] = mapper
+
+ sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys())
+ ret = util.OrderedDict()
+ for t in sorted_:
+ ret[t] = table_to_mapper[t]
+ return ret
+
def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
@@ -1198,16 +1212,37 @@ class Mapper(object):
# if session has a connection callable,
# organize individual states with the connection to use for insert/update
+ tups = []
if 'connection_callable' in uowtransaction.mapper_flush_opts:
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
- tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)]
+ for state in _sort_states(states):
+ m = _state_mapper(state)
+ tups.append(
+ (
+ state,
+ m,
+ connection_callable(self, state.obj()),
+ _state_has_identity(state),
+ state.key or m._identity_key_from_state(state)
+ )
+ )
else:
connection = uowtransaction.transaction.connection(self)
- tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)]
+ for state in _sort_states(states):
+ m = _state_mapper(state)
+ tups.append(
+ (
+ state,
+ m,
+ connection,
+ _state_has_identity(state),
+ state.key or m._identity_key_from_state(state)
+ )
+ )
if not postupdate:
# call before_XXX extensions
- for state, mapper, connection, has_identity in tups:
+ for state, mapper, connection, has_identity, instance_key in tups:
if not has_identity:
if 'before_insert' in mapper.extension:
mapper.extension.before_insert(mapper, connection, state.obj())
@@ -1215,39 +1250,44 @@ class Mapper(object):
if 'before_update' in mapper.extension:
mapper.extension.before_update(mapper, connection, state.obj())
- for state, mapper, connection, has_identity in tups:
- # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
- # and another instance with the same identity key already exists as persistent. convert to an
- # UPDATE if so.
- instance_key = mapper._identity_key_from_state(state)
- if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map:
- instance = uowtransaction.session.identity_map[instance_key]
- existing = attributes.instance_state(instance)
- if not uowtransaction.is_deleted(existing):
- raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
- if self._should_log_debug:
- self._log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing)))
- uowtransaction.set_row_switch(existing)
-
- table_to_mapper = {}
- for mapper in self.base_mapper.polymorphic_iterator():
- for t in mapper.tables:
- table_to_mapper[t] = mapper
+ row_switches = set()
+ if not postupdate:
+ for state, mapper, connection, has_identity, instance_key in tups:
+ # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
+ # and another instance with the same identity key already exists as persistent. convert to an
+ # UPDATE if so.
+ if not has_identity and instance_key in uowtransaction.session.identity_map:
+ instance = uowtransaction.session.identity_map[instance_key]
+ existing = attributes.instance_state(instance)
+ if not uowtransaction.is_deleted(existing):
+ raise exc.FlushError(
+ "New instance %s with identity key %s conflicts with persistent instance %s" %
+ (state_str(state), instance_key, state_str(existing)))
+ if self._should_log_debug:
+ self._log_debug(
+ "detected row switch for identity %s. will update %s, remove %s from transaction",
+ instance_key, state_str(state), state_str(existing))
+
+ # remove the "delete" flag from the existing element
+ uowtransaction.set_row_switch(existing)
+ row_switches.add(state)
+
+ table_to_mapper = self._sorted_tables
- for table in sqlutil.sort_tables(table_to_mapper.iterkeys()):
+ for table in table_to_mapper.iterkeys():
insert = []
update = []
- for state, mapper, connection, has_identity in tups:
+ for state, mapper, connection, has_identity, instance_key in tups:
if table not in mapper._pks_by_table:
continue
+
pks = mapper._pks_by_table[table]
- instance_key = mapper._identity_key_from_state(state)
-
+
if self._should_log_debug:
self._log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key)))
- isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity
+ isinsert = not has_identity and not postupdate and state not in row_switches
params = {}
value_params = {}
@@ -1364,7 +1404,7 @@ class Mapper(object):
sync.populate(state, m, state, m, m._inherits_equated_pairs)
if not postupdate:
- for state, mapper, connection, has_identity in tups:
+ for state, mapper, connection, has_identity, instance_key in tups:
# expire readonly attributes
readonly = state.unmodified.intersection(
@@ -1434,12 +1474,9 @@ class Mapper(object):
if 'before_delete' in mapper.extension:
mapper.extension.before_delete(mapper, connection, state.obj())
- table_to_mapper = {}
- for mapper in self.base_mapper.polymorphic_iterator():
- for t in mapper.tables:
- table_to_mapper[t] = mapper
+ table_to_mapper = self._sorted_tables
- for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())):
+ for table in reversed(table_to_mapper.keys()):
delete = {}
for state, mapper, connection in tups:
if table not in mapper._pks_by_table:
@@ -1485,6 +1522,10 @@ class Mapper(object):
for dep in self._props.values() + self._dependency_processors:
dep.register_dependencies(uowcommit)
+ def _register_processors(self, uowcommit):
+ for dep in self._props.values() + self._dependency_processors:
+ dep.register_processors(uowcommit)
+
# result set conversion
def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None):
@@ -1514,7 +1555,7 @@ class Mapper(object):
new_populators = []
existing_populators = []
- def populate_state(state, row, isnew, only_load_props, **flags):
+ def populate_state(state, dict_, row, isnew, only_load_props, **flags):
if isnew:
if context.options:
state.load_options = context.options
@@ -1533,7 +1574,7 @@ class Mapper(object):
populators = [p for p in populators if p[0] in only_load_props]
for key, populator in populators:
- populator(state, row, isnew=isnew, **flags)
+ populator(state, dict_, row, isnew=isnew, **flags)
session_identity_map = context.session.identity_map
@@ -1573,9 +1614,11 @@ class Mapper(object):
if identitykey in session_identity_map:
instance = session_identity_map[identitykey]
state = attributes.instance_state(instance)
+ dict_ = attributes.instance_dict(instance)
if self._should_log_debug:
- self._log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), identitykey))
+ self._log_debug("_instance(): using existing instance %s identity %s",
+ instance_str(instance), identitykey)
isnew = state.runid != context.runid
currentload = not isnew
@@ -1592,12 +1635,13 @@ class Mapper(object):
# when eager_defaults is True.
state = refresh_state
instance = state.obj()
+ dict_ = attributes.instance_dict(instance)
isnew = state.runid != context.runid
currentload = True
loaded_instance = False
else:
if self._should_log_debug:
- self._log_debug("_instance(): identity key %s not in session" % str(identitykey))
+ self._log_debug("_instance(): identity key %s not in session", identitykey)
if self.allow_null_pks:
for x in identitykey[1]:
@@ -1625,8 +1669,10 @@ class Mapper(object):
instance = self.class_manager.new_instance()
if self._should_log_debug:
- self._log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
+ self._log_debug("_instance(): created new instance %s identity %s",
+ instance_str(instance), identitykey)
+ dict_ = attributes.instance_dict(instance)
state = attributes.instance_state(instance)
state.key = identitykey
@@ -1638,12 +1684,12 @@ class Mapper(object):
if currentload or populate_existing:
if isnew:
state.runid = context.runid
- context.progress.add(state)
+ context.progress[state] = dict_
if not populate_instance or \
populate_instance(self, context, row, instance,
only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
- populate_state(state, row, isnew, only_load_props)
+ populate_state(state, dict_, row, isnew, only_load_props)
else:
# populate attributes on non-loading instances which have been expired
@@ -1652,16 +1698,16 @@ class Mapper(object):
if state in context.partials:
isnew = False
- attrs = context.partials[state]
+ (d_, attrs) = context.partials[state]
else:
isnew = True
attrs = state.unloaded
- context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs
+ context.partials[state] = (dict_, attrs) #<-- allow query.instances to commit the subset of attrs
if not populate_instance or \
populate_instance(self, context, row, instance,
only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
- populate_state(state, row, isnew, attrs, instancekey=identitykey)
+ populate_state(state, dict_, row, isnew, attrs, instancekey=identitykey)
if loaded_instance:
state._run_on_load(instance)
@@ -1759,6 +1805,14 @@ def _event_on_init_failure(state, instance, args, kwargs):
instrumenting_mapper, instrumenting_mapper.class_,
state.manager.events.original_init, instance, args, kwargs)
+def _event_on_resurrect(state, instance):
+ # re-populate the primary key elements
+ # of the dict based on the mapping.
+ instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+ for col, val in zip(instrumenting_mapper.primary_key, state.key[1]):
+ instrumenting_mapper._set_state_attr_by_column(state, col, val)
+
+
def _sort_states(states):
return sorted(states, key=operator.attrgetter('sort_key'))
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index d0cca2dc1..5605cdcd1 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -96,13 +96,13 @@ class ColumnProperty(StrategizedProperty):
return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
def getattr(self, state, column):
- return state.get_impl(self.key).get(state)
+ return state.get_impl(self.key).get(state, state.dict)
def getcommitted(self, state, column, passive=False):
- return state.get_impl(self.key).get_committed_value(state, passive=passive)
+ return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive)
def setattr(self, state, value, column):
- state.get_impl(self.key).set(state, value, None)
+ state.get_impl(self.key).set(state, state.dict, value, None)
def merge(self, session, source, dest, dont_load, _recursive):
value = attributes.instance_state(source).value_as_iterable(
@@ -159,7 +159,7 @@ class CompositeProperty(ColumnProperty):
super(ColumnProperty, self).do_init()
def getattr(self, state, column):
- obj = state.get_impl(self.key).get(state)
+ obj = state.get_impl(self.key).get(state, state.dict)
return self.get_col_value(column, obj)
def getcommitted(self, state, column, passive=False):
@@ -168,7 +168,7 @@ class CompositeProperty(ColumnProperty):
def setattr(self, state, value, column):
- obj = state.get_impl(self.key).get(state)
+ obj = state.get_impl(self.key).get(state, state.dict)
if obj is None:
obj = self.composite_class(*[None for c in self.columns])
state.get_impl(self.key).set(state, obj, None)
@@ -635,7 +635,7 @@ class RelationProperty(StrategizedProperty):
return
source_state = attributes.instance_state(source)
- dest_state = attributes.instance_state(dest)
+ dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest)
if not "merge" in self.cascade:
dest_state.expire_attributes([self.key])
@@ -658,7 +658,7 @@ class RelationProperty(StrategizedProperty):
for c in dest_list:
coll.append_without_event(c)
else:
- getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list)
+ getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list)
else:
current = instances[0]
if current is not None:
@@ -1119,6 +1119,10 @@ class RelationProperty(StrategizedProperty):
if not self.viewonly:
self._dependency_processor.register_dependencies(uowcommit)
+ def register_processors(self, uowcommit):
+ if not self.viewonly:
+ self._dependency_processor.register_processors(uowcommit)
+
PropertyLoader = RelationProperty
log.class_logger(RelationProperty)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 28ddcc5ea..e3cc3c756 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1330,7 +1330,7 @@ class Query(object):
rowtuple.keys = labels.keys
while True:
- context.progress = set()
+ context.progress = {}
context.partials = {}
if self._yield_per:
@@ -1354,13 +1354,13 @@ class Query(object):
rows = filter(rows)
if context.refresh_state and self._only_load_props and context.refresh_state in context.progress:
- context.refresh_state.commit(self._only_load_props)
- context.progress.remove(context.refresh_state)
+ context.refresh_state.commit(context.refresh_state.dict, self._only_load_props)
+ context.progress.pop(context.refresh_state)
session._finalize_loaded(context.progress)
- for ii, attrs in context.partials.items():
- ii.commit(attrs)
+ for ii, (dict_, attrs) in context.partials.items():
+ ii.commit(dict_, attrs)
for row in rows:
yield row
@@ -1683,14 +1683,14 @@ class Query(object):
evaluated_keys = value_evaluators.keys()
if issubclass(cls, target_cls) and eval_condition(obj):
- state = attributes.instance_state(obj)
+ state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj)
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
- state.dict[key] = value_evaluators[key](obj)
+ dict_[key] = value_evaluators[key](obj)
- state.commit(list(to_evaluate))
+ state.commit(dict_, list(to_evaluate))
# expire attributes with pending changes (there was no autoflush, so they are overwritten)
state.expire_attributes(set(evaluated_keys).difference(to_evaluate))
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 1e3a750d9..00a7d55e5 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc
from sqlalchemy import util, sql, engine, log
from sqlalchemy.sql import util as sql_util, expression
from sqlalchemy.orm import (
- SessionExtension, attributes, exc, query, unitofwork, util as mapperutil,
+ SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state
)
from sqlalchemy.orm.util import object_mapper as _object_mapper
from sqlalchemy.orm.util import class_mapper as _class_mapper
@@ -899,8 +899,8 @@ class Session(object):
self.flush()
def _finalize_loaded(self, states):
- for state in states:
- state.commit_all()
+ for state, dict_ in states.items():
+ state.commit_all(dict_)
def refresh(self, instance, attribute_names=None):
"""Refresh the attributes on the given instance.
@@ -1020,11 +1020,9 @@ class Session(object):
# primary key switch
self.identity_map.remove(state)
state.key = instance_key
-
- if state.key in self.identity_map and not self.identity_map.contains_state(state):
- self.identity_map.remove_key(state.key)
- self.identity_map.add(state)
- state.commit_all()
+
+ self.identity_map.replace(state)
+ state.commit_all(state.dict)
# remove from new last, might be the last strong ref
if state in self._new:
@@ -1213,7 +1211,7 @@ class Session(object):
prop.merge(self, instance, merged, dont_load, _recursive)
if dont_load:
- attributes.instance_state(merged).commit_all() # remove any history
+ attributes.instance_state(merged).commit_all(attributes.instance_dict(merged)) # remove any history
if new_instance:
merged_state._run_on_load(merged)
@@ -1368,7 +1366,7 @@ class Session(object):
self.identity_map.modified = False
return
- flush_context = UOWTransaction(self)
+ flush_context = UOWTransaction(self)
if self.extensions:
for ext in self.extensions:
@@ -1489,7 +1487,7 @@ class Session(object):
return util.IdentitySet(
[state
for state in self.identity_map.all_states()
- if state.check_modified()])
+ if state.modified])
@property
def dirty(self):
@@ -1528,7 +1526,7 @@ class Session(object):
return util.IdentitySet(self._new.values())
-_expire_state = attributes.InstanceState.expire_attributes
+_expire_state = state.InstanceState.expire_attributes
UOWEventHandler = unitofwork.UOWEventHandler
@@ -1548,16 +1546,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs):
yield _state_for_unknown_persistence_instance(o), m
def _state_for_unsaved_instance(instance, create=False):
- manager = attributes.manager_of_class(instance.__class__)
- if manager is None:
+ try:
+ state = attributes.instance_state(instance)
+ except AttributeError:
raise exc.UnmappedInstanceError(instance)
- if manager.has_state(instance):
- state = manager.state_of(instance)
+ if state:
if state.key is not None:
raise sa_exc.InvalidRequestError(
"Instance '%s' is already persistent" %
mapperutil.state_str(state))
elif create:
+ manager = attributes.manager_of_class(instance.__class__)
+ if manager is None:
+ raise exc.UnmappedInstanceError(instance)
state = manager.setup_instance(instance)
else:
raise exc.UnmappedInstanceError(instance)
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
new file mode 100644
index 000000000..c99dfe73c
--- /dev/null
+++ b/lib/sqlalchemy/orm/state.py
@@ -0,0 +1,429 @@
+from sqlalchemy.util import EMPTY_SET
+import weakref
+from sqlalchemy import util
+from sqlalchemy.orm.attributes import PASSIVE_NORESULT, PASSIVE_OFF, NEVER_SET, NO_VALUE, manager_of_class, ATTR_WAS_SET
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import interfaces
+
+class InstanceState(object):
+ """tracks state information at the instance level."""
+
+ session_id = None
+ key = None
+ runid = None
+ expired_attributes = EMPTY_SET
+ load_options = EMPTY_SET
+ load_path = ()
+ insert_order = None
+ mutable_dict = None
+
+ def __init__(self, obj, manager):
+ self.class_ = obj.__class__
+ self.manager = manager
+ self.obj = weakref.ref(obj, self._cleanup)
+ self.modified = False
+ self.callables = {}
+ self.expired = False
+ self.committed_state = {}
+ self.pending = {}
+ self.parents = {}
+
+ def detach(self):
+ if self.session_id:
+ del self.session_id
+
+ def dispose(self):
+ if self.session_id:
+ del self.session_id
+ del self.obj
+
+ def _cleanup(self, ref):
+ instance_dict = self._instance_dict()
+ if instance_dict:
+ instance_dict.remove(self)
+ self.dispose()
+
+ def obj(self):
+ return None
+
+ @property
+ def dict(self):
+ o = self.obj()
+ if o is not None:
+ return attributes.instance_dict(o)
+ else:
+ return {}
+
+ @property
+ def sort_key(self):
+ return self.key and self.key[1] or (self.insert_order, )
+
+ def check_modified(self):
+ # TODO: deprecate
+ return self.modified
+
+ def initialize_instance(*mixed, **kwargs):
+ self, instance, args = mixed[0], mixed[1], mixed[2:]
+ manager = self.manager
+
+ for fn in manager.events.on_init:
+ fn(self, instance, args, kwargs)
+
+ # LESSTHANIDEAL:
+ # adjust for the case where the InstanceState was created before
+ # mapper compilation, and this actually needs to be a MutableAttrInstanceState
+ if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState:
+ self.__class__ = MutableAttrInstanceState
+ self.obj = weakref.ref(self.obj(), self._cleanup)
+ self.mutable_dict = {}
+
+ try:
+ return manager.events.original_init(*mixed[1:], **kwargs)
+ except:
+ for fn in manager.events.on_init_failure:
+ fn(self, instance, args, kwargs)
+ raise
+
+ def get_history(self, key, **kwargs):
+ return self.manager.get_impl(key).get_history(self, self.dict, **kwargs)
+
+ def get_impl(self, key):
+ return self.manager.get_impl(key)
+
+ def get_pending(self, key):
+ if key not in self.pending:
+ self.pending[key] = PendingCollection()
+ return self.pending[key]
+
+ def value_as_iterable(self, key, passive=PASSIVE_OFF):
+ """return an InstanceState attribute as a list,
+ regardless of it being a scalar or collection-based
+ attribute.
+
+ returns None if passive is not PASSIVE_OFF and the getter returns
+ PASSIVE_NORESULT.
+ """
+
+ impl = self.get_impl(key)
+ dict_ = self.dict
+ x = impl.get(self, dict_, passive=passive)
+ if x is PASSIVE_NORESULT:
+ return None
+ elif hasattr(impl, 'get_collection'):
+ return impl.get_collection(self, dict_, x, passive=passive)
+ elif isinstance(x, list):
+ return x
+ else:
+ return [x]
+
+ def _run_on_load(self, instance):
+ self.manager.events.run('on_load', instance)
+
+ def __getstate__(self):
+ return {'key': self.key,
+ 'committed_state': self.committed_state,
+ 'pending': self.pending,
+ 'parents': self.parents,
+ 'modified': self.modified,
+ 'expired':self.expired,
+ 'load_options':self.load_options,
+ 'load_path':interfaces.serialize_path(self.load_path),
+ 'instance': self.obj(),
+ 'expired_attributes':self.expired_attributes,
+ 'callables': self.callables}
+
+ def __setstate__(self, state):
+ self.committed_state = state['committed_state']
+ self.parents = state['parents']
+ self.key = state['key']
+ self.session_id = None
+ self.pending = state['pending']
+ self.modified = state['modified']
+ self.obj = weakref.ref(state['instance'])
+ self.load_options = state['load_options'] or EMPTY_SET
+ self.load_path = interfaces.deserialize_path(state['load_path'])
+ self.class_ = self.obj().__class__
+ self.manager = manager_of_class(self.class_)
+ self.callables = state['callables']
+ self.runid = None
+ self.expired = state['expired']
+ self.expired_attributes = state['expired_attributes']
+
+ def initialize(self, key):
+ self.manager.get_impl(key).initialize(self, self.dict)
+
+ def set_callable(self, key, callable_):
+ self.dict.pop(key, None)
+ self.callables[key] = callable_
+
+ def __call__(self):
+ """__call__ allows the InstanceState to act as a deferred
+ callable for loading expired attributes, which is also
+ serializable (picklable).
+
+ """
+ unmodified = self.unmodified
+ class_manager = self.manager
+ class_manager.deferred_scalar_loader(self, [
+ attr.impl.key for attr in class_manager.attributes if
+ attr.impl.accepts_scalar_loader and
+ attr.impl.key in self.expired_attributes and
+ attr.impl.key in unmodified
+ ])
+ for k in self.expired_attributes:
+ self.callables.pop(k, None)
+ del self.expired_attributes
+ return ATTR_WAS_SET
+
+ @property
+ def unmodified(self):
+ """a set of keys which have no uncommitted changes"""
+
+ return set(self.manager).difference(self.committed_state)
+
+ @property
+ def unloaded(self):
+ """a set of keys which do not have a loaded value.
+
+ This includes expired attributes and any other attribute that
+ was never populated or modified.
+
+ """
+ return set(
+ key for key in self.manager.iterkeys()
+ if key not in self.committed_state and key not in self.dict)
+
+ def expire_attributes(self, attribute_names):
+ self.expired_attributes = set(self.expired_attributes)
+
+ if attribute_names is None:
+ attribute_names = self.manager.keys()
+ self.expired = True
+ self.modified = False
+ filter_deferred = True
+ else:
+ filter_deferred = False
+ dict_ = self.dict
+
+ for key in attribute_names:
+ impl = self.manager[key].impl
+ if not filter_deferred or \
+ not impl.dont_expire_missing or \
+ key in dict_:
+ self.expired_attributes.add(key)
+ if impl.accepts_scalar_loader:
+ self.callables[key] = self
+ dict_.pop(key, None)
+ self.pending.pop(key, None)
+ self.committed_state.pop(key, None)
+ if self.mutable_dict:
+ self.mutable_dict.pop(key, None)
+
+ def reset(self, key, dict_):
+ """remove the given attribute and any callables associated with it."""
+
+ dict_.pop(key, None)
+ self.callables.pop(key, None)
+
+ def _instance_dict(self):
+ return None
+
+ def _is_really_none(self):
+ return self.obj()
+
+ def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF):
+ needs_committed = attr.key not in self.committed_state
+
+ if needs_committed:
+ if previous is NEVER_SET:
+ if passive:
+ if attr.key in dict_:
+ previous = dict_[attr.key]
+ else:
+ previous = attr.get(self, dict_)
+
+ if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
+ previous = attr.copy(previous)
+
+ if needs_committed:
+ self.committed_state[attr.key] = previous
+
+ self.modified = True
+ self._strong_obj = self.obj()
+
+ instance_dict = self._instance_dict()
+ if instance_dict:
+ instance_dict.modified = True
+
+ def commit(self, dict_, keys):
+ """Commit attributes.
+
+ This is used by a partial-attribute load operation to mark committed
+ those attributes which were refreshed from the database.
+
+ Attributes marked as "expired" can potentially remain "expired" after
+ this step if a value was not populated in state.dict.
+
+ """
+ class_manager = self.manager
+ for key in keys:
+ if key in dict_ and key in class_manager.mutable_attributes:
+ class_manager[key].impl.commit_to_state(self, dict_, self.committed_state)
+ else:
+ self.committed_state.pop(key, None)
+
+ self.expired = False
+ # unexpire attributes which have loaded
+ for key in self.expired_attributes.intersection(keys):
+ if key in dict_:
+ self.expired_attributes.remove(key)
+ self.callables.pop(key, None)
+
+ def commit_all(self, dict_):
+ """commit all attributes unconditionally.
+
+ This is used after a flush() or a full load/refresh
+ to remove all pending state from the instance.
+
+ - all attributes are marked as "committed"
+ - the "strong dirty reference" is removed
+ - the "modified" flag is set to False
+ - any "expired" markers/callables are removed.
+
+ Attributes marked as "expired" can potentially remain "expired" after this step
+ if a value was not populated in state.dict.
+
+ """
+
+ self.committed_state = {}
+ self.pending = {}
+
+ # unexpire attributes which have loaded
+ if self.expired_attributes:
+ for key in self.expired_attributes.intersection(dict_):
+ self.callables.pop(key, None)
+ self.expired_attributes.difference_update(dict_)
+
+ for key in self.manager.mutable_attributes:
+ if key in dict_:
+ self.manager[key].impl.commit_to_state(self, dict_, self.committed_state)
+
+ self.modified = self.expired = False
+ self._strong_obj = None
+
+class MutableAttrInstanceState(InstanceState):
+ def __init__(self, obj, manager):
+ self.mutable_dict = {}
+ InstanceState.__init__(self, obj, manager)
+
+ def _get_modified(self, dict_=None):
+ if self.__dict__.get('modified', False):
+ return True
+ else:
+ if dict_ is None:
+ dict_ = self.dict
+ for key in self.manager.mutable_attributes:
+ if self.manager[key].impl.check_mutable_modified(self, dict_):
+ return True
+ else:
+ return False
+
+ def _set_modified(self, value):
+ self.__dict__['modified'] = value
+
+ modified = property(_get_modified, _set_modified)
+
+ @property
+ def unmodified(self):
+ """a set of keys which have no uncommitted changes"""
+
+ dict_ = self.dict
+ return set(
+ key for key in self.manager.iterkeys()
+ if (key not in self.committed_state or
+ (key in self.manager.mutable_attributes and
+ not self.manager[key].impl.check_mutable_modified(self, dict_))))
+
+ def _is_really_none(self):
+ """do a check modified/resurrect.
+
+ This would be called in the extremely rare
+ race condition that the weakref returned None but
+ the cleanup handler had not yet established the
+ __resurrect callable as its replacement.
+
+ """
+ if self.modified:
+ self.obj = self.__resurrect
+ return self.obj()
+ else:
+ return None
+
+ def reset(self, key, dict_):
+ self.mutable_dict.pop(key, None)
+ InstanceState.reset(self, key, dict_)
+
+ def _cleanup(self, ref):
+ """weakref callback.
+
+ This method may be called by an asynchronous
+ gc.
+
+ If the state shows pending changes, the weakref
+ is replaced by the __resurrect callable which will
+ re-establish an object reference on next access,
+ else removes this InstanceState from the owning
+ identity map, if any.
+
+ """
+ if self._get_modified(self.mutable_dict):
+ self.obj = self.__resurrect
+ else:
+ instance_dict = self._instance_dict()
+ if instance_dict:
+ instance_dict.remove(self)
+ self.dispose()
+
+ def __resurrect(self):
+ """A substitute for the obj() weakref function which resurrects."""
+
+ # store strong ref'ed version of the object; will revert
+ # to weakref when changes are persisted
+
+ obj = self.manager.new_instance(state=self)
+ self.obj = weakref.ref(obj, self._cleanup)
+ self._strong_obj = obj
+ obj.__dict__.update(self.mutable_dict)
+
+ # re-establishes identity attributes from the key
+ self.manager.events.run('on_resurrect', self, obj)
+
+ # TODO: don't really think we should run this here.
+ # resurrect is only meant to preserve the minimal state needed to
+ # do an UPDATE, not to produce a fully usable object
+ self._run_on_load(obj)
+
+ return obj
+
+class PendingCollection(object):
+ """A writable placeholder for an unloaded collection.
+
+ Stores items appended to and removed from a collection that has not yet
+ been loaded. When the collection is loaded, the changes stored in
+ PendingCollection are applied to it to produce the final result.
+
+ """
+ def __init__(self):
+ self.deleted_items = util.IdentitySet()
+ self.added_items = util.OrderedIdentitySet()
+
+ def append(self, value):
+ if value in self.deleted_items:
+ self.deleted_items.remove(value)
+ self.added_items.add(value)
+
+ def remove(self, value):
+ if value in self.added_items:
+ self.added_items.remove(value)
+ self.deleted_items.add(value)
+
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 1aeb311e1..20cbb8f4d 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -115,8 +115,8 @@ class ColumnLoader(LoaderStrategy):
if adapter:
col = adapter.columns[col]
if col in row:
- def new_execute(state, row, **flags):
- state.dict[key] = row[col]
+ def new_execute(state, dict_, row, **flags):
+ dict_[key] = row[col]
if self._should_log_debug:
new_execute = self.debug_callable(new_execute, self.logger,
@@ -125,7 +125,7 @@ class ColumnLoader(LoaderStrategy):
)
return (new_execute, None)
else:
- def new_execute(state, row, isnew, **flags):
+ def new_execute(state, dict_, row, isnew, **flags):
if isnew:
state.expire_attributes([key])
if self._should_log_debug:
@@ -171,15 +171,15 @@ class CompositeColumnLoader(ColumnLoader):
columns = [adapter.columns[c] for c in columns]
for c in columns:
if c not in row:
- def new_execute(state, row, isnew, **flags):
+ def new_execute(state, dict_, row, isnew, **flags):
if isnew:
state.expire_attributes([key])
if self._should_log_debug:
self.logger.debug("%s deferring load" % self)
return (new_execute, None)
else:
- def new_execute(state, row, **flags):
- state.dict[key] = composite_class(*[row[c] for c in columns])
+ def new_execute(state, dict_, row, **flags):
+ dict_[key] = composite_class(*[row[c] for c in columns])
if self._should_log_debug:
new_execute = self.debug_callable(new_execute, self.logger,
@@ -202,13 +202,13 @@ class DeferredColumnLoader(LoaderStrategy):
return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter)
elif not self.is_class_level:
- def new_execute(state, row, **flags):
+ def new_execute(state, dict_, row, **flags):
state.set_callable(self.key, LoadDeferredColumns(state, self.key))
else:
- def new_execute(state, row, **flags):
+ def new_execute(state, dict_, row, **flags):
# reset state on the key so that deferred callables
# fire off on next access.
- state.reset(self.key)
+ state.reset(self.key, dict_)
if self._should_log_debug:
new_execute = self.debug_callable(new_execute, self.logger, None,
@@ -340,7 +340,7 @@ class NoLoader(AbstractRelationLoader):
)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
- def new_execute(state, row, **flags):
+ def new_execute(state, dict_, row, **flags):
self._init_instance_attribute(state)
if self._should_log_debug:
@@ -437,7 +437,7 @@ class LazyLoader(AbstractRelationLoader):
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
if not self.is_class_level:
- def new_execute(state, row, **flags):
+ def new_execute(state, dict_, row, **flags):
# we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
# which will override the class-level behavior.
# this currently only happens when using a "lazyload" option on a "no load" attribute -
@@ -451,11 +451,11 @@ class LazyLoader(AbstractRelationLoader):
return (new_execute, None)
else:
- def new_execute(state, row, **flags):
+ def new_execute(state, dict_, row, **flags):
# we are the primary manager for this attribute on this class - reset its per-instance attribute state,
# so that the class-level lazy loader is executed when next referenced on this instance.
# this is needed in populate_existing() types of scenarios to reset any existing state.
- state.reset(self.key)
+ state.reset(self.key, dict_)
if self._should_log_debug:
new_execute = self.debug_callable(new_execute, self.logger, None,
@@ -735,24 +735,24 @@ class EagerLoader(AbstractRelationLoader):
_instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter)
if not self.uselist:
- def execute(state, row, isnew, **flags):
+ def execute(state, dict_, row, isnew, **flags):
if isnew:
# set a scalar object instance directly on the
# parent object, bypassing InstrumentedAttribute
# event handlers.
- state.dict[key] = _instance(row, None)
+ dict_[key] = _instance(row, None)
else:
# call _instance on the row, even though the object has been created,
# so that we further descend into properties
_instance(row, None)
else:
- def execute(state, row, isnew, **flags):
+ def execute(state, dict_, row, isnew, **flags):
if isnew or (state, key) not in context.attributes:
# appender_key can be absent from context.attributes with isnew=False
# when self-referential eager loading is used; the same instance may be present
# in two distinct sets of result columns
- collection = attributes.init_state_collection(state, key)
+ collection = attributes.init_state_collection(state, dict_, key)
appender = util.UniqueAppender(collection, 'append_without_event')
context.attributes[(state, key)] = appender
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 4ac9c765e..407b702a8 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -96,6 +96,8 @@ class UOWTransaction(object):
# information.
self.attributes = {}
+ self.processors = set()
+
def get_attribute_history(self, state, key, passive=True):
hashkey = ("history", state, key)
@@ -136,6 +138,16 @@ class UOWTransaction(object):
else:
task.append(state, listonly=listonly, isdelete=isdelete)
+ # ensure the mapper for this object has had its
+ # DependencyProcessors added.
+ if mapper not in self.processors:
+ mapper._register_processors(self)
+ self.processors.add(mapper)
+
+ if mapper.base_mapper not in self.processors:
+ mapper.base_mapper._register_processors(self)
+ self.processors.add(mapper.base_mapper)
+
def set_row_switch(self, state):
"""mark a deleted object as a 'row switch'.
@@ -147,7 +159,7 @@ class UOWTransaction(object):
task = self.get_task_by_mapper(mapper)
taskelement = task._objects[state]
taskelement.isdelete = "rowswitch"
-
+
def is_deleted(self, state):
"""return true if the given state is marked as deleted within this UOWTransaction."""
@@ -201,9 +213,9 @@ class UOWTransaction(object):
self.dependencies.add((mapper, dependency))
def register_processor(self, mapper, processor, mapperfrom):
- """register a dependency processor, corresponding to dependencies between
- the two given mappers.
-
+ """register a dependency processor, corresponding to
+ operations which occur between two mappers.
+
"""
# correct for primary mapper
mapper = mapper.primary_mapper()