diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-09-22 16:55:36 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-09-22 16:55:36 +0000 |
| commit | 6b0a907fbdd33b9d9333ec1b72287580a2568d07 (patch) | |
| tree | 89303fa82c0e239ea855ff7f19f0f7a1c8a27e03 /lib/sqlalchemy | |
| parent | 7f6bf93da869a5b59c53d0d10a50da3c23c4b738 (diff) | |
| download | sqlalchemy-6b0a907fbdd33b9d9333ec1b72287580a2568d07.tar.gz | |
- merged sa_entity branch. the big change here is the attributes system
deals primarily with the InstanceState and almost never with the instrumented object
directly. This reduces lookups and complexity since we need the state for just about
everything, now its the one place for everything internally.
we also merged the new weak referencing identity map, which will go out in beta6 and
we'll see how that goes !
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 612 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/collections.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/dependency.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/dynamic.py | 79 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 18 |
9 files changed, 495 insertions, 309 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 559b97e5a..3f08d8871 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -4,8 +4,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import weakref - +import weakref, threading +import UserDict from sqlalchemy import util from sqlalchemy.orm import util as orm_util, interfaces, collections from sqlalchemy.orm.mapper import class_mapper @@ -17,10 +17,51 @@ ATTR_WAS_SET = object() NO_VALUE = object() class InstrumentedAttribute(interfaces.PropComparator): - """attribute access for instrumented classes.""" + """public-facing instrumented attribute.""" - def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs): + def __init__(self, impl, comparator=None): """Construct an InstrumentedAttribute. + comparator + a sql.Comparator to which class-level compare/math events will be sent + """ + + self.impl = impl + self.comparator = comparator + + def __set__(self, obj, value): + self.impl.set(obj._state, value, None) + + def __delete__(self, obj): + self.impl.delete(obj._state) + + def __get__(self, obj, owner): + if obj is None: + return self + return self.impl.get(obj._state) + + def clause_element(self): + return self.comparator.clause_element() + + def expression_element(self): + return self.comparator.expression_element() + + def operate(self, op, *other, **kwargs): + return op(self.comparator, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return op(other, self.comparator, **kwargs) + + def hasparent(self, instance, optimistic=False): + return self.impl.hasparent(instance._state, optimistic=optimistic) + + property = property(lambda s: class_mapper(s.impl.class_).get_property(s.impl.key), + doc="the MapperProperty object associated with this attribute") + +class AttributeImpl(object): + """internal implementation for instrumented attributes.""" + + def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, **kwargs): + """Construct an AttributeImpl. class_ the class to be instrumented. @@ -53,8 +94,6 @@ class InstrumentedAttribute(interfaces.PropComparator): and need to be compared against a copy of their original contents in order to detect changes on the parent instance - comparator - a sql.Comparator to which class-level compare/math events will be sent """ @@ -64,47 +103,23 @@ class InstrumentedAttribute(interfaces.PropComparator): self.callable_ = callable_ self.trackparent = trackparent self.mutable_scalars = mutable_scalars - self.comparator = comparator self.copy = None if compare_function is None: self.is_equal = lambda x,y: x == y else: self.is_equal = compare_function self.extensions = util.to_list(extension or []) - - def __set__(self, obj, value): - self.set(obj, value, None) - - def __delete__(self, obj): - self.delete(None, obj) - - def __get__(self, obj, owner): - if obj is None: - return self - return self.get(obj) - - def commit_to_state(self, state, obj, value=NO_VALUE): + + def commit_to_state(self, state, value=NO_VALUE): """commit the object's current state to its 'committed' state.""" if value is NO_VALUE: - if self.key in obj.__dict__: - value = obj.__dict__[self.key] + if self.key in state.dict: + value = state.dict[self.key] if value is not NO_VALUE: state.committed_state[self.key] = self.copy(value) - def clause_element(self): - return self.comparator.clause_element() - - def expression_element(self): - return self.comparator.expression_element() - - def operate(self, op, *other, **kwargs): - return op(self.comparator, *other, **kwargs) - - def reverse_operate(self, op, other, **kwargs): - return op(other, self.comparator, **kwargs) - - def hasparent(self, item, optimistic=False): + def hasparent(self, state, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. The `optimistic` flag determines what the default return value @@ -119,32 +134,23 @@ class InstrumentedAttribute(interfaces.PropComparator): will also not have a `hasparent` flag. """ - return item._state.parents.get(id(self), optimistic) + return state.parents.get(id(self), optimistic) - def sethasparent(self, item, value): + def sethasparent(self, state, value): """Set a boolean flag on the given item corresponding to whether or not it is attached to a parent object via the attribute represented by this ``InstrumentedAttribute``. """ - item._state.parents[id(self)] = value - - def get_history(self, obj, passive=False): - """Return a new ``AttributeHistory`` object for the given object/this attribute's key. - - If `passive` is True, then don't execute any callables; if the - attribute's value can only be achieved via executing a - callable, then return None. - """ + state.parents[id(self)] = value - # get the current state. this may trigger a lazy load if - # passive is False. - current = self.get(obj, passive=passive) + def get_history(self, state, passive=False): + current = self.get(state, passive=passive) if current is PASSIVE_NORESULT: return None - return AttributeHistory(self, obj, current, passive=passive) - - def set_callable(self, obj, callable_, clear=False): + return AttributeHistory(self, state, current, passive=passive) + + def set_callable(self, state, callable_, clear=False): """Set a callable function for this attribute on the given object. This callable will be executed when the attribute is next @@ -160,22 +166,22 @@ class InstrumentedAttribute(interfaces.PropComparator): """ if clear: - self.clear(obj) + self.clear(state) if callable_ is None: - self.initialize(obj) + self.initialize(state) else: - obj._state.callables[self] = callable_ + state.callables[self] = callable_ - def _get_callable(self, obj): - if self in obj._state.callables: - return obj._state.callables[self] + def _get_callable(self, state): + if self in state.callables: + return state.callables[self] elif self.callable_ is not None: - return self.callable_(obj) + return self.callable_(state.obj()) else: return None - def reset(self, obj): + def reset(self, state): """Remove any per-instance callable functions corresponding to this ``InstrumentedAttribute``'s attribute from the given object, and remove this ``InstrumentedAttribute``'s attribute @@ -183,12 +189,12 @@ class InstrumentedAttribute(interfaces.PropComparator): """ try: - del obj._state.callables[self] + del state.callables[self] except KeyError: pass - self.clear(obj) + self.clear(state) - def clear(self, obj): + def clear(self, state): """Remove this ``InstrumentedAttribute``'s attribute from the given object's dictionary. Subsequent calls to ``getattr(obj, key)`` will raise an @@ -196,20 +202,20 @@ class InstrumentedAttribute(interfaces.PropComparator): """ try: - del obj.__dict__[self.key] + del state.dict[self.key] except KeyError: pass - def check_mutable_modified(self, obj): + def check_mutable_modified(self, state): return False - def initialize(self, obj): + def initialize(self, state): """Initialize this attribute on the given object instance with an empty value.""" - obj.__dict__[self.key] = None + state.dict[self.key] = None return None - def get(self, obj, passive=False): + def get(self, state, passive=False): """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and @@ -218,40 +224,37 @@ class InstrumentedAttribute(interfaces.PropComparator): """ try: - return obj.__dict__[self.key] + return state.dict[self.key] except KeyError: - state = obj._state # if an instance-wide "trigger" was set, call that # and start again if state.trigger: state.call_trigger() - return self.get(obj, passive=passive) + return self.get(state, passive=passive) - callable_ = self._get_callable(obj) + callable_ = self._get_callable(state) if callable_ is not None: if passive: return PASSIVE_NORESULT - self.logger.debug("Executing lazy callable on %s.%s" % - (orm_util.instance_str(obj), self.key)) value = callable_() if value is not ATTR_WAS_SET: - return self.set_committed_value(obj, value) + return self.set_committed_value(state, value) else: - return obj.__dict__[self.key] + return state.dict[self.key] else: # Return a new, empty value - return self.initialize(obj) + return self.initialize(state) - def append(self, obj, value, initiator): - self.set(obj, value, initiator) + def append(self, state, value, initiator): + self.set(state, value, initiator) - def remove(self, obj, value, initiator): - self.set(obj, None, initiator) + def remove(self, state, value, initiator): + self.set(state, None, initiator) - def set(self, obj, value, initiator): + def set(self, state, value, initiator): raise NotImplementedError() - def set_committed_value(self, obj, value): + def set_committed_value(self, state, value): """set an attribute value on the given instance and 'commit' it. this indicates that the given value is the "persisted" value, @@ -262,53 +265,51 @@ class InstrumentedAttribute(interfaces.PropComparator): to set object attributes after the initial load. """ - state = obj._state if state.committed_state is not None: - self.commit_to_state(state, obj, value) + self.commit_to_state(state, value) # remove per-instance callable, if any state.callables.pop(self, None) - obj.__dict__[self.key] = value + state.dict[self.key] = value return value - def set_raw_value(self, obj, value): - obj.__dict__[self.key] = value + def set_raw_value(self, state, value): + state.dict[self.key] = value return value - def fire_append_event(self, obj, value, initiator): - obj._state.modified = True + def fire_append_event(self, state, value, initiator): + state.modified = True if self.trackparent and value is not None: - self.sethasparent(value, True) + self.sethasparent(value._state, True) + obj = state.obj() for ext in self.extensions: ext.append(obj, value, initiator or self) - def fire_remove_event(self, obj, value, initiator): - obj._state.modified = True + def fire_remove_event(self, state, value, initiator): + state.modified = True if self.trackparent and value is not None: - self.sethasparent(value, False) + self.sethasparent(value._state, False) + obj = state.obj() for ext in self.extensions: ext.remove(obj, value, initiator or self) - def fire_replace_event(self, obj, value, previous, initiator): - obj._state.modified = True + def fire_replace_event(self, state, value, previous, initiator): + state.modified = True if self.trackparent: if value is not None: - self.sethasparent(value, True) + self.sethasparent(value._state, True) if previous is not None: - self.sethasparent(previous, False) + self.sethasparent(previous._state, False) + obj = state.obj() for ext in self.extensions: ext.set(obj, value, previous, initiator or self) - property = property(lambda s: class_mapper(s.class_).get_property(s.key), - doc="the MapperProperty object associated with this attribute") - -InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) -class InstrumentedScalarAttribute(InstrumentedAttribute): +class ScalarAttributeImpl(AttributeImpl): """represents a scalar-holding InstrumentedAttribute.""" def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): - super(InstrumentedScalarAttribute, self).__init__(class_, manager, key, + super(ScalarAttributeImpl, self).__init__(class_, manager, key, callable_, trackparent=trackparent, extension=extension, compare_function=compare_function, **kwargs) self.mutable_scalars = mutable_scalars @@ -322,23 +323,23 @@ class InstrumentedScalarAttribute(InstrumentedAttribute): # is passed return item - def __delete__(self, obj): - old = self.get(obj) - del obj.__dict__[self.key] - self.fire_remove_event(obj, old, self) + def delete(self, state): + old = self.get(state) + del state.dict[self.key] + self.fire_remove_event(state, old, self) - def check_mutable_modified(self, obj): + def check_mutable_modified(self, state): if self.mutable_scalars: - h = self.get_history(obj, passive=True) + h = self.get_history(state, passive=True) if h is not None and h.is_modified(): - obj._state.modified = True + state.modified = True return True else: return False else: return False - def set(self, obj, value, initiator): + def set(self, state, value, initiator): """Set a value on the given object. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -349,19 +350,18 @@ class InstrumentedScalarAttribute(InstrumentedAttribute): if initiator is self: return - state = obj._state # if an instance-wide "trigger" was set, call that if state.trigger: state.call_trigger() - old = self.get(obj) - obj.__dict__[self.key] = value - self.fire_replace_event(obj, value, old, initiator) + old = self.get(state) + state.dict[self.key] = value + self.fire_replace_event(state, value, old, initiator) type = property(lambda self: self.property.columns[0].type) -class InstrumentedCollectionAttribute(InstrumentedAttribute): +class CollectionAttributeImpl(AttributeImpl): """A collection-holding attribute that instruments changes in membership. InstrumentedCollectionAttribute holds an arbitrary, user-specified @@ -371,7 +371,7 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): """ def __init__(self, class_, manager, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): - super(InstrumentedCollectionAttribute, self).__init__(class_, manager, + super(CollectionAttributeImpl, self).__init__(class_, manager, key, callable_, trackparent=trackparent, extension=extension, compare_function=compare_function, **kwargs) @@ -389,53 +389,36 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): def __copy(self, item): return [y for y in list(collections.collection_adapter(item))] - def __set__(self, obj, value): - """Replace the current collection with a new one.""" - - setting_type = util.duck_type_collection(value) - - if value is None or setting_type != self.collection_interface: - raise exceptions.ArgumentError( - "Incompatible collection type on assignment: %s is not %s-like" % - (type(value).__name__, self.collection_interface.__name__)) - - if hasattr(value, '_sa_adapter'): - self.set(obj, list(getattr(value, '_sa_adapter')), None) - elif setting_type == dict: - self.set(obj, value.values(), None) - else: - self.set(obj, value, None) - - def __delete__(self, obj): - if self.key not in obj.__dict__: + def delete(self, state): + if self.key not in state.dict: return - obj._state.modified = True + state.modified = True - collection = self.get_collection(obj) + collection = self.get_collection(state) collection.clear_with_event() - del obj.__dict__[self.key] + del state.dict[self.key] - def initialize(self, obj): + def initialize(self, state): """Initialize this attribute on the given object instance with an empty collection.""" - _, user_data = self._build_collection(obj) - obj.__dict__[self.key] = user_data + _, user_data = self._build_collection(state) + state.dict[self.key] = user_data return user_data - def append(self, obj, value, initiator): + def append(self, state, value, initiator): if initiator is self: return - collection = self.get_collection(obj) + collection = self.get_collection(state) collection.append_with_event(value, initiator) - def remove(self, obj, value, initiator): + def remove(self, state, value, initiator): if initiator is self: return - collection = self.get_collection(obj) + collection = self.get_collection(state) collection.remove_with_event(value, initiator) - def set(self, obj, value, initiator): + def set(self, state, value, initiator): """Set a value on the given object. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -446,19 +429,30 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): if initiator is self: return - state = obj._state + setting_type = util.duck_type_collection(value) + + if value is None or setting_type != self.collection_interface: + raise exceptions.ArgumentError( + "Incompatible collection type on assignment: %s is not %s-like" % + (type(value).__name__, self.collection_interface.__name__)) + + if hasattr(value, '_sa_adapter'): + value = list(getattr(value, '_sa_adapter')) + elif setting_type == dict: + value = value.values() + # if an instance-wide "trigger" was set, call that if state.trigger: state.call_trigger() - old = self.get(obj) - old_collection = self.get_collection(obj, old) + old = self.get(state) + old_collection = self.get_collection(state, old) - new_collection, user_data = self._build_collection(obj) - self._load_collection(obj, value or [], emit_events=True, + new_collection, user_data = self._build_collection(state) + self._load_collection(state, value or [], emit_events=True, collection=new_collection) - obj.__dict__[self.key] = user_data + state.dict[self.key] = user_data state.modified = True # mark all the old elements as detached from the parent @@ -466,30 +460,28 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): old_collection.clear_with_event() old_collection.unlink(old) - def set_committed_value(self, obj, value): + def set_committed_value(self, state, value): """Set an attribute value on the given instance and 'commit' it.""" - - state = obj._state - collection, user_data = self._build_collection(obj) - self._load_collection(obj, value or [], emit_events=False, + collection, user_data = self._build_collection(state) + self._load_collection(state, value or [], emit_events=False, collection=collection) value = user_data if state.committed_state is not None: - self.commit_to_state(state, obj, value) + self.commit_to_state(state, value) # remove per-instance callable, if any state.callables.pop(self, None) - obj.__dict__[self.key] = value + state.dict[self.key] = value return value - def _build_collection(self, obj): + def _build_collection(self, state): user_data = self.collection_factory() - collection = collections.CollectionAdapter(self, obj, user_data) + collection = collections.CollectionAdapter(self, state, user_data) return collection, user_data - def _load_collection(self, obj, values, emit_events=True, collection=None): - collection = collection or self.get_collection(obj) + def _load_collection(self, state, values, emit_events=True, collection=None): + collection = collection or self.get_collection(state) if values is None: return elif emit_events: @@ -499,13 +491,13 @@ class InstrumentedCollectionAttribute(InstrumentedAttribute): for item in values: collection.append_without_event(item) - def get_collection(self, obj, user_data=None): + def get_collection(self, state, user_data=None): if user_data is None: - user_data = self.get(obj) + user_data = self.get(state) try: return getattr(user_data, '_sa_adapter') except AttributeError: - collections.CollectionAdapter(self, obj, user_data) + collections.CollectionAdapter(self, state, user_data) return getattr(user_data, '_sa_adapter') @@ -525,33 +517,83 @@ class GenericBackrefExtension(interfaces.AttributeExtension): if oldchild is child: return if oldchild is not None: - getattr(oldchild.__class__, self.key).remove(oldchild, obj, initiator) + getattr(oldchild.__class__, self.key).impl.remove(oldchild._state, obj, initiator) if child is not None: - getattr(child.__class__, self.key).append(child, obj, initiator) + getattr(child.__class__, self.key).impl.append(child._state, obj, initiator) def append(self, obj, child, initiator): - getattr(child.__class__, self.key).append(child, obj, initiator) + getattr(child.__class__, self.key).impl.append(child._state, obj, initiator) def remove(self, obj, child, initiator): - getattr(child.__class__, self.key).remove(child, obj, initiator) + getattr(child.__class__, self.key).impl.remove(child._state, obj, initiator) class InstanceState(object): """tracks state information at the instance level.""" + + __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj' def __init__(self, obj): + self.class_ = obj.__class__ + self.obj = weakref.ref(obj, self.__cleanup) + self.dict = obj.__dict__ self.committed_state = None self.modified = False self.trigger = None self.callables = {} self.parents = {} + self.instance_dict = None + + def __cleanup(self, ref): + if self.instance_dict is None or self.instance_dict() is None: + return + + instance_dict = self.instance_dict() + + # the mutexing here is based on the assumption that gc.collect() + # may be firing off cleanup handlers in a different thread than that + # which is normally operating upon the instance dict. + instance_dict._mutex.acquire() + try: + # if instance_dict de-refed us, or it called our + # _resurrect, return + if self.instance_dict is None or self.instance_dict() is None or self.obj() is not None: + return + + self.__resurrect(instance_dict) + finally: + instance_dict._mutex.release() + def _check_resurrect(self, instance_dict): + instance_dict._mutex.acquire() + try: + return self.obj() or self.__resurrect(instance_dict) + finally: + instance_dict._mutex.release() + + def __resurrect(self, instance_dict): + if self.modified or self.class_._sa_attribute_manager._is_modified(self): + # store strong ref'ed version of the object; will revert + # to weakref when changes are persisted + obj = self.class_._sa_attribute_manager.new_instance(self.class_, state=self) + self.obj = weakref.ref(obj, self.__cleanup) + self._strong_obj = obj + obj.__dict__.update(self.dict) + self.dict = obj.__dict__ + return obj + else: + del instance_dict[self.dict['_instance_key']] + return None + def __getstate__(self): - return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified} + return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()} def __setstate__(self, state): self.committed_state = state['committed_state'] self.parents = state['parents'] self.modified = state['modified'] + self.obj = weakref.ref(state['instance']) + self.class_ = self.obj().__class__ + self.dict = self.obj().__dict__ self.callables = {} self.trigger = None @@ -564,38 +606,154 @@ class InstanceState(object): self.committed_state = {} self.modified = False for attr in manager.managed_attributes(obj.__class__): - attr.commit_to_state(self, obj) - + attr.impl.commit_to_state(self) + # remove strong ref + self._strong_obj = None + def rollback(self, manager, obj): if not self.committed_state: manager._clear(obj) else: for attr in manager.managed_attributes(obj.__class__): - if attr.key in self.committed_state: - if not hasattr(attr, 'get_collection'): - obj.__dict__[attr.key] = self.committed_state[attr.key] + if attr.impl.key in self.committed_state: + if not hasattr(attr.impl, 'get_collection'): + obj.__dict__[attr.impl.key] = self.committed_state[attr.impl.key] else: - collection = attr.get_collection(obj) + collection = attr.impl.get_collection(self) collection.clear_without_event() - for item in self.committed_state[attr.key]: + for item in self.committed_state[attr.impl.key]: collection.append_without_event(item) else: - if attr.key in obj.__dict__: - del obj.__dict__[attr.key] + if attr.impl.key in self.dict: + del self.dict[attr.impl.key] + +class InstanceDict(UserDict.UserDict): + """similar to WeakValueDictionary, but wired towards 'state' objects.""" + + def __init__(self, *args, **kw): + self._wr = weakref.ref(self) + # RLock because the mutex is used by a cleanup + # handler, which can be called at any time (including within an already mutexed block) + self._mutex = threading.RLock() + UserDict.UserDict.__init__(self, *args, **kw) + + def __getitem__(self, key): + state = self.data[key] + o = state.obj() or state._check_resurrect(self) + if o is None: + raise KeyError, key + return o + + def __contains__(self, key): + try: + state = self.data[key] + o = state.obj() or state._check_resurrect(self) + except KeyError: + return False + return o is not None + def has_key(self, key): + return key in self + + def __repr__(self): + return "<InstanceDict at %s>" % id(self) + + def __setitem__(self, key, value): + if key in self.data: + self._mutex.acquire() + try: + if key in self.data: + self.data[key].instance_dict = None + finally: + self._mutex.release() + self.data[key] = value._state + value._state.instance_dict = self._wr + + def __delitem__(self, key): + state = self.data[key] + state.instance_dict = None + del self.data[key] + + def get(self, key, default=None): + try: + state = self.data[key] + except KeyError: + return default + else: + o = state.obj() + if o is None: + # This should only happen + return default + else: + return o + + def items(self): + L = [] + for key, state in self.data.items(): + o = state.obj() + if o is not None: + L.append((key, o)) + return L + + def iteritems(self): + for state in self.data.itervalues(): + value = state.obj() + if value is not None: + yield value._instance_key, value + + def iterkeys(self): + return self.data.iterkeys() + + def __iter__(self): + return self.data.iterkeys() + + def __len__(self): + return len(self.values()) + + def itervalues(self): + for state in self.data.itervalues(): + obj = state.obj() + if obj is not None: + yield obj + + def values(self): + L = [] + for state in self.data.values(): + o = state.obj() + if o is not None: + L.append(o) + return L + + def popitem(self): + raise NotImplementedError() + + def pop(self, key, *args): + raise NotImplementedError() + + def setdefault(self, key, default=None): + raise NotImplementedError() + + def update(self, dict=None, **kwargs): + raise NotImplementedError() + + def copy(self): + raise NotImplementedError() + + + class AttributeHistory(object): """Calculate the *history* of a particular attribute on a particular instance. """ - def __init__(self, attr, obj, current, passive=False): + def __init__(self, attr, state, current, passive=False): self.attr = attr # get the "original" value. if a lazy load was fired when we got # the 'current' value, this "original" was also populated just # now as well (therefore we have to get it second) - if obj._state.committed_state: - original = obj._state.committed_state.get(attr.key, None) + if state.committed_state: + original = state.committed_state.get(attr.key, None) else: original = None @@ -606,7 +764,7 @@ class AttributeHistory(object): self._unchanged_items = [] self._deleted_items = [] if current: - collection = attr.get_collection(obj, current) + collection = attr.get_collection(state, current) for a in collection: if a in s: self._unchanged_items.append(a) @@ -667,7 +825,7 @@ class AttributeManager(object): def _clear(self, obj): for attr in self.managed_attributes(obj.__class__): try: - del obj.__dict__[attr.key] + del obj.__dict__[attr.impl.key] except KeyError: pass @@ -683,6 +841,7 @@ class AttributeManager(object): """ try: + # TODO: move this collection onto the class itself? return self._inherited_attribute_cache[class_] except KeyError: if not isinstance(class_, type): @@ -693,6 +852,7 @@ class AttributeManager(object): def noninherited_managed_attributes(self, class_): try: + # TODO: move this collection onto the class itself? return self._noninherited_attribute_cache[class_] except KeyError: if not isinstance(class_, type): @@ -701,23 +861,25 @@ class AttributeManager(object): self._noninherited_attribute_cache[class_] = noninherited return noninherited - def is_modified(self, object): - if object._state.modified: + def is_modified(self, obj): + return self._is_modified(obj._state) + + def _is_modified(self, state): + if state.modified: return True else: - for attr in self.managed_attributes(object.__class__): - if attr.check_mutable_modified(object): + for attr in self.managed_attributes(state.class_): + if attr.impl.check_mutable_modified(state): return True else: return False - - + def get_history(self, obj, key, **kwargs): """Return a new ``AttributeHistory`` object for the given attribute on the given object. """ - return getattr(obj.__class__, key).get_history(obj, **kwargs) + return getattr(obj.__class__, key).impl.get_history(obj._state, **kwargs) def get_as_list(self, obj, key, passive=False): """Return an attribute of the given name from the given object. @@ -729,12 +891,13 @@ class AttributeManager(object): callable, the callable will only be executed if the given `passive` flag is False. """ - attr = getattr(obj.__class__, key) - x = attr.get(obj, passive=passive) + attr = getattr(obj.__class__, key).impl + state = obj._state + x = attr.get(state, passive=passive) if x is PASSIVE_NORESULT: return [] elif hasattr(attr, 'get_collection'): - return list(attr.get_collection(obj, x)) + return list(attr.get_collection(state, x)) elif isinstance(x, list): return x else: @@ -774,7 +937,7 @@ class AttributeManager(object): """ attr = getattr(obj.__class__, key) - attr.reset(obj) + attr.impl.reset(obj._state) def is_class_managed(self, class_, key): """Return True if the given `key` correponds to an @@ -782,6 +945,9 @@ class AttributeManager(object): """ return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute) + def has_parent(self, class_, obj, key, optimistic=False): + return getattr(class_, key).impl.hasparent(obj._state, optimistic=optimistic) + def init_instance_attribute(self, obj, key, callable_=None, clear=False): """Initialize an attribute on an instance to either a blank value, cancelling out any class- or instance-level callables @@ -789,9 +955,9 @@ class AttributeManager(object): callable to be invoked when the attribute is next accessed. """ - getattr(obj.__class__, key).set_callable(obj, callable_, clear=clear) + getattr(obj.__class__, key).impl.set_callable(obj._state, callable_, clear=clear) - def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs): + def _create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs): """Create a scalar property object, defaulting to ``InstrumentedAttribute``, which will communicate change events back to this ``AttributeManager``. @@ -799,35 +965,28 @@ class AttributeManager(object): if kwargs.pop('dynamic', False): from sqlalchemy.orm import dynamic - return dynamic.DynamicCollectionAttribute(class_, self, key, typecallable, **kwargs) + return dynamic.DynamicAttributeImpl(class_, self, key, typecallable, **kwargs) elif uselist: - return InstrumentedCollectionAttribute(class_, self, key, + return CollectionAttributeImpl(class_, self, key, callable_, typecallable, **kwargs) else: - return InstrumentedScalarAttribute(class_, self, key, callable_, + return ScalarAttributeImpl(class_, self, key, callable_, **kwargs) - def get_attribute(self, obj_or_cls, key): - """Register an attribute at the class level to be instrumented - for all instances of the class. - """ - - if isinstance(obj_or_cls, type): - return getattr(obj_or_cls, key) - else: - return getattr(obj_or_cls.__class__, key) - def manage(self, obj): if not hasattr(obj, '_state'): obj._state = InstanceState(obj) - def new_instance(self, class_): + def new_instance(self, class_, state=None): """create a new instance of class_ without its __init__() method being called.""" s = class_.__new__(class_) - s._state = InstanceState(s) + if state: + s._state = state + else: + s._state = InstanceState(s) return s def register_class(self, class_, extra_init=None, on_exception=None): @@ -841,7 +1000,8 @@ class AttributeManager(object): oldinit = None doinit = False - + class_._sa_attribute_manager = self + def init(instance, *args, **kwargs): instance._state = InstanceState(instance) @@ -884,7 +1044,7 @@ class AttributeManager(object): delattr(class_, '__init__') for attr in self.noninherited_managed_attributes(class_): - delattr(class_, attr.key) + delattr(class_, attr.impl.key) self._inherited_attribute_cache.pop(class_,None) self._noninherited_attribute_cache.pop(class_,None) @@ -901,11 +1061,19 @@ class AttributeManager(object): typecallable = kwargs.pop('typecallable', None) if isinstance(typecallable, InstrumentedAttribute): typecallable = None - setattr(class_, key, self.create_prop(class_, key, uselist, callable_, - typecallable=typecallable, **kwargs)) + comparator = kwargs.pop('comparator', None) + setattr(class_, key, InstrumentedAttribute(self._create_prop(class_, key, uselist, callable_, + typecallable=typecallable, **kwargs), comparator=comparator)) + + def set_raw_value(self, instance, key, value): + getattr(instance.__class__, key).impl.set_raw_value(instance._state, value) + + def set_committed_value(self, instance, key, value): + getattr(instance.__class__, key).impl.set_committed_value(instance._state, value) def init_collection(self, instance, key): """Initialize a collection attribute and return the collection adapter.""" - attr = self.get_attribute(instance, key) - user_data = attr.initialize(instance) - return attr.get_collection(instance, user_data) + attr = getattr(instance.__class__, key).impl + state = instance._state + user_data = attr.initialize(state) + return attr.get_collection(state, user_data) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index e08a4b8b7..bf365d267 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -428,14 +428,12 @@ class CollectionAdapter(object): entity collections. """ - def __init__(self, attr, owner, data): + def __init__(self, attr, owner_state, data): self.attr = attr - self._owner = weakref.ref(owner) self._data = weakref.ref(data) + self.owner_state = owner_state self.link_to_self(data) - owner = property(lambda s: s._owner(), - doc="The object that owns the entity collection.") data = property(lambda s: s._data(), doc="The entity collection being adapted.") @@ -507,7 +505,7 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - self.attr.fire_append_event(self._owner(), item, initiator) + self.attr.fire_append_event(self.owner_state, item, initiator) def fire_remove_event(self, item, initiator=None): """Notify that a entity has entered the collection. @@ -518,16 +516,16 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - self.attr.fire_remove_event(self._owner(), item, initiator) + self.attr.fire_remove_event(self.owner_state, item, initiator) def __getstate__(self): return { 'key': self.attr.key, - 'owner': self.owner, + 'owner_state': self.owner_state, 'data': self.data } def __setstate__(self, d): - self.attr = getattr(d['owner'].__class__, d['key']) - self._owner = weakref.ref(d['owner']) + self.attr = getattr(d['owner_state'].obj().__class__, d['key']).impl + self.owner_state = d['owner_state'] self._data = weakref.ref(d['data']) @@ -787,7 +785,7 @@ def _list_decorators(): if _sa_initiator is not False and item is not None: executor = getattr(self, '_sa_adapter', None) if executor: - executor.attr.fire_append_event(executor._owner(), + executor.attr.fire_append_event(executor.owner_state, item, _sa_initiator) fn(self, item) _tidy(append) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 1e461e6bf..a1669e32f 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -55,7 +55,8 @@ class DependencyProcessor(object): """return True if the given object instance has a parent, according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.""" - return self._get_instrumented_attribute().hasparent(obj) + # TODO: use correct API for this + return self._get_instrumented_attribute().impl.hasparent(obj._state) def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index aa5105150..1b91bd977 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -6,32 +6,30 @@ from sqlalchemy.orm import attributes, object_session from sqlalchemy.orm.query import Query from sqlalchemy.orm.mapper import has_identity, object_mapper -class DynamicCollectionAttribute(attributes.InstrumentedAttribute): +class DynamicAttributeImpl(attributes.AttributeImpl): def __init__(self, class_, attribute_manager, key, typecallable, target_mapper, **kwargs): - super(DynamicCollectionAttribute, self).__init__(class_, attribute_manager, key, typecallable, **kwargs) + super(DynamicAttributeImpl, self).__init__(class_, attribute_manager, key, typecallable, **kwargs) self.target_mapper = target_mapper - def get(self, obj, passive=False): + def get(self, state, passive=False): if passive: - return self.get_history(obj, passive=True).added_items() + return self.get_history(state, passive=True).added_items() else: - return AppenderQuery(self, obj) + return AppenderQuery(self, state) - def commit_to_state(self, state, obj, value=attributes.NO_VALUE): + def commit_to_state(self, state, value=attributes.NO_VALUE): # we have our own AttributeHistory therefore dont need CommittedState # instead, we reset the history stored on the attribute - obj.__dict__[self.key] = CollectionHistory(self, obj) + state.dict[self.key] = CollectionHistory(self, state) - def get_collection(self, obj, user_data=None): - return self.get_history(obj)._added_items + def get_collection(self, state, user_data=None): + return self.get_history(state)._added_items - def set(self, obj, value, initiator): + def set(self, state, value, initiator): if initiator is self: return - state = obj._state - - old_collection = self.get(obj).assign(value) + old_collection = self.get(state).assign(value) # TODO: emit events ??? state.modified = True @@ -39,35 +37,36 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute): def delete(self, *args, **kwargs): raise NotImplementedError() - def get_history(self, obj, passive=False): + def get_history(self, state, passive=False): try: - return obj.__dict__[self.key] + return state.dict[self.key] except KeyError: - obj.__dict__[self.key] = c = CollectionHistory(self, obj) + state.dict[self.key] = c = CollectionHistory(self, state) return c - def append(self, obj, value, initiator): + def append(self, state, value, initiator): if initiator is not self: - self.get_history(obj)._added_items.append(value) - self.fire_append_event(obj, value, self) + self.get_history(state)._added_items.append(value) + self.fire_append_event(state, value, self) - def remove(self, obj, value, initiator): + def remove(self, state, value, initiator): if initiator is not self: - self.get_history(obj)._deleted_items.append(value) - self.fire_remove_event(obj, value, self) + self.get_history(state)._deleted_items.append(value) + self.fire_remove_event(state, value, self) class AppenderQuery(Query): - def __init__(self, attr, instance): + def __init__(self, attr, state): super(AppenderQuery, self).__init__(attr.target_mapper, None) - self.instance = instance + self.state = state self.attr = attr def __session(self): - sess = object_session(self.instance) - if sess is not None and self.instance in sess and sess.autoflush: + instance = self.state.obj() + sess = object_session(instance) + if sess is not None and instance in sess and sess.autoflush: sess.flush() - if not has_identity(self.instance): + if not has_identity(instance): return None else: return sess @@ -75,21 +74,21 @@ class AppenderQuery(Query): def __len__(self): sess = self.__session() if sess is None: - return len(self.attr.get_history(self.instance)._added_items) + return len(self.attr.get_history(self.state)._added_items) else: return self._clone(sess).count() def __iter__(self): sess = self.__session() if sess is None: - return iter(self.attr.get_history(self.instance)._added_items) + return iter(self.attr.get_history(self.state)._added_items) else: return iter(self._clone(sess)) def __getitem__(self, index): sess = self.__session() if sess is None: - return self.attr.get_history(self.instance)._added_items.__getitem__(index) + return self.attr.get_history(self.state)._added_items.__getitem__(index) else: return self._clone(sess).__getitem__(index) @@ -97,39 +96,41 @@ class AppenderQuery(Query): # note we're returning an entirely new Query class instance here # without any assignment capabilities; # the class of this query is determined by the session. + instance = self.state.obj() if sess is None: - sess = object_session(self.instance) + sess = object_session(instance) if sess is None: try: - sess = object_mapper(self.instance).get_session() + sess = object_mapper(instance).get_session() except exceptions.InvalidRequestError: raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key)) - return sess.query(self.attr.target_mapper).with_parent(self.instance) + return sess.query(self.attr.target_mapper).with_parent(instance) def assign(self, collection): - if has_identity(self.instance): + instance = self.state.obj() + if has_identity(instance): oldlist = list(self) else: oldlist = [] - self.attr.get_history(self.instance).replace(oldlist, collection) + self.attr.get_history(self.state).replace(oldlist, collection) return oldlist def append(self, item): - self.attr.append(self.instance, item, None) + self.attr.append(self.state, item, None) def remove(self, item): - self.attr.remove(self.instance, item, None) + self.attr.remove(self.state, item, None) class CollectionHistory(attributes.AttributeHistory): """Overrides AttributeHistory to receive append/remove events directly.""" - def __init__(self, attr, obj): + def __init__(self, attr, state): self._deleted_items = [] self._added_items = [] self._unchanged_items = [] - self._obj = obj + self._state = state def replace(self, olditems, newitems): self._added_items = newitems diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c1933d3ff..e0f7799b7 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -466,9 +466,11 @@ class LoaderStack(object): def push_property(self, key): self.__stack.append(key) + return tuple(self.__stack) def push_mapper(self, mapper): self.__stack.append(mapper.base_mapper) + return tuple(self.__stack) def pop(self): self.__stack.pop() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5d495d7a9..9764a0ae6 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -166,7 +166,7 @@ class Mapper(object): def _is_orphan(self, obj): optimistic = has_identity(obj) for (key,klass) in self.delete_orphans: - if getattr(klass, key).hasparent(obj, optimistic=optimistic): + if attribute_manager.has_parent(klass, obj, key, optimistic=optimistic): return False else: if self.delete_orphans: @@ -531,7 +531,9 @@ class Mapper(object): cls = object.__getattribute__(self, 'class_') clskey = object.__getattribute__(self, 'key') - # get the class' mapper; will compile all mappers + if key.startswith('__'): + return object.__getattribute__(self, key) + class_mapper(cls) if cls.__dict__.get(clskey) is self: @@ -1369,21 +1371,27 @@ class Mapper(object): # been exposed to being modified by the application. identitykey = self.identity_key_from_row(row) - populate_existing = context.populate_existing or self.always_refresh - if identitykey in context.session.identity_map: - instance = context.session.identity_map[identitykey] + (session_identity_map, local_identity_map) = (context.session.identity_map, context.identity_map) + + if identitykey in session_identity_map: + instance = session_identity_map[identitykey] + if self.__should_log_debug: self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) + isnew = False + if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) - if populate_existing or context.session.is_expired(instance, unexpire=True): - if identitykey not in context.identity_map: - context.identity_map[identitykey] = instance + if context.populate_existing or self.always_refresh or instance._state.trigger is not None: + instance._state.trigger = None + if identitykey not in local_identity_map: + local_identity_map[identitykey] = instance isnew = True if extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew) + if extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: if result is not None: result.append(instance) @@ -1391,9 +1399,9 @@ class Mapper(object): else: if self.__should_log_debug: self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) + # look in result-local identitymap for it. - exists = identitykey in context.identity_map - if not exists: + if identitykey not in local_identity_map: if self.allow_null_pks: # check if *all* primary key cols in the result are None - this indicates # an instance of the object is not present in the row. @@ -1415,10 +1423,10 @@ class Mapper(object): instance._entity_name = self.entity_name if self.__should_log_debug: self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) - context.identity_map[identitykey] = instance + local_identity_map[identitykey] = instance isnew = True else: - instance = context.identity_map[identitykey] + instance = local_identity_map[identitykey] isnew = False # call further mapper properties on the row, to pull further @@ -1470,13 +1478,12 @@ class Mapper(object): def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, **flags): """populate an instance from a result row.""" - selectcontext.stack.push_mapper(self) + snapshot = selectcontext.stack.push_mapper(self) # retrieve a set of "row population" functions derived from the MapperProperties attached # to this Mapper. These are keyed in the select context based primarily off the # "snapshot" of the stack, which represents a path from the lead mapper in the query to this one, # including relation() names. the key also includes "self", and allows us to distinguish between # other mappers within our inheritance hierarchy - snapshot = selectcontext.stack.snapshot() populators = selectcontext.attributes.get(((isnew or ispostselect) and 'new_populators' or 'existing_populators', self, snapshot, ispostselect), None) if populators is None: # no populators; therefore this is the first time we are receiving a row for diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 6f06474b7..25f8bacab 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -318,7 +318,7 @@ class Session(object): a thread-managed Session adapter, provided by the [sqlalchemy.orm#scoped_session()] function. """ - def __init__(self, bind=None, autoflush=True, transactional=False, twophase=False, echo_uow=False, weak_identity_map=False, binds=None, extension=None): + def __init__(self, bind=None, autoflush=True, transactional=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None): """Construct a new Session. autoflush @@ -383,20 +383,23 @@ class Session(object): committed. weak_identity_map - when ``True``, use a ``WeakValueDictionary`` instead of a regular ``dict`` - for this ``Session`` object's identity map. This will allow objects which - fall out of scope to be automatically removed from the ``Session``. However, - objects who have been marked as "dirty" will also be garbage collected, and - those changes will not be persisted. - + When set to the default value of ``False``, a weak-referencing map is used; + instances which are not externally referenced will be garbage collected + immediately. For dereferenced instances which have pending changes present, + the attribute management system will create a temporary strong-reference to + the object which lasts until the changes are flushed to the database, at which + point it's again dereferenced. Alternatively, when using the value ``True``, + the identity map uses a regular Python dictionary to store instances. The + session will maintain all instances present until they are removed using + expunge(), clear(), or purge(). """ self.echo_uow = echo_uow - self.uow = unitofwork.UnitOfWork(self, weak_identity_map=weak_identity_map) + self.weak_identity_map = weak_identity_map + self.uow = unitofwork.UnitOfWork(self) self.identity_map = self.uow.identity_map self.bind = bind self.__binds = {} - self.weak_identity_map = weak_identity_map self.transaction = None self.hash_key = id(self) self.autoflush = autoflush @@ -565,7 +568,7 @@ class Session(object): for instance in self: self._unattach(instance) - self.uow = unitofwork.UnitOfWork(self, weak_identity_map=self.weak_identity_map) + self.uow = unitofwork.UnitOfWork(self) self.identity_map = self.uow.identity_map def bind_mapper(self, mapper, bind, entity_name=None): @@ -736,11 +739,14 @@ class Session(object): def prune(self): """Removes unreferenced instances cached in the identity map. - Removes any object in this Session'sidentity map that is not + Note that this method is only meaningful if "weak_identity_map" + is set to False. + + Removes any object in this Session's identity map that is not referenced in user code, modified, new or scheduled for deletion. Returns the number of objects pruned. """ - + return self.uow.prune_identity_map() def _expire_impl(self, obj): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 1ece80616..17bb10ca2 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -202,7 +202,7 @@ class DeferredColumnLoader(LoaderStrategy): try: row = result.fetchone() for prop in group: - sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) + sessionlib.attribute_manager.set_committed_value(instance, prop.key, row[prop.columns[0]]) return attributes.ATTR_WAS_SET finally: result.close() @@ -572,8 +572,7 @@ class EagerLoader(AbstractRelationLoader): return None def create_row_processor(self, selectcontext, mapper, row): - selectcontext.stack.push_property(self.key) - path = selectcontext.stack.snapshot() + path = selectcontext.stack.push_property(self.key) row_decorator = self._create_row_decorator(selectcontext, row, path) if row_decorator is not None: @@ -591,7 +590,7 @@ class EagerLoader(AbstractRelationLoader): # event handlers. # # FIXME: instead of... - sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.select_mapper._instance(selectcontext, decorated_row, None)) + sessionlib.attribute_manager.set_raw_value(instance, self.key, self.select_mapper._instance(selectcontext, decorated_row, None)) # bypass and set directly: #instance.__dict__[self.key] = self.select_mapper._instance(selectcontext, decorated_row, None) else: diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 74ece20c3..39387b4cf 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -69,12 +69,12 @@ class UOWAttributeManager(attributes.AttributeManager): instance for all ``InstrumentedAttributes``. """ - def create_prop(self, class_, key, uselist, callable_, typecallable, + def _create_prop(self, class_, key, uselist, callable_, typecallable, cascade=None, extension=None, **kwargs): extension = util.to_list(extension or []) extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) - return super(UOWAttributeManager, self).create_prop( + return super(UOWAttributeManager, self)._create_prop( class_, key, uselist, callable_, typecallable, extension=extension, **kwargs) @@ -87,9 +87,9 @@ class UnitOfWork(object): operation. """ - def __init__(self, session, weak_identity_map=False): - if weak_identity_map: - self.identity_map = weakref.WeakValueDictionary() + def __init__(self, session): + if session.weak_identity_map: + self.identity_map = attributes.InstanceDict() else: self.identity_map = {} @@ -215,14 +215,18 @@ class UnitOfWork(object): session.extension.after_flush_postexec(session, flush_context) def prune_identity_map(self): - """Removes unreferenced instances cached in the identity map. + """Removes unreferenced instances cached in a strong-referencing identity map. + Note that this method is only meaningful if "weak_identity_map" + on the parent Session is set to False and therefore this UnitOfWork's + identity map is a regular dictionary + Removes any object in the identity map that is not referenced in user code or scheduled for a unit of work operation. Returns the number of objects pruned. """ - if isinstance(self.identity_map, weakref.WeakValueDictionary): + if isinstance(self.identity_map, attributes.InstanceDict): return 0 ref_count = len(self.identity_map) dirty = self.locate_dirty() |
