author    Mike Bayer <mike_mp@zzzcomputing.com>  2006-10-22 00:24:26 +0000
committer Mike Bayer <mike_mp@zzzcomputing.com>  2006-10-22 00:24:26 +0000
commit    bc240be3f87b41232671e4da7f59679744959154 (patch)
tree      4e978300d6e796f8a476b2e6ccf543ba2e9d881f /lib/sqlalchemy/orm
parent    97feb4dbeee3ef5bc50de667ec25a43d44a5ff2c (diff)
download  sqlalchemy-bc240be3f87b41232671e4da7f59679744959154.tar.gz
- attributes module and test suite moves underneath 'orm' package
- fixed table comparison example in metadata.txt
- docstrings all over the place
- renamed mapper _getattrbycolumn/_setattrbycolumn to get_attr_by_column/set_attr_by_column
- removed frommapper parameter from populate_instance(); the two operations can be performed separately
- fix to examples/adjacencytree/byroot_tree.py to fire off lazy loaders upon load, to reduce query calling
- added get(), get_by(), load() to MapperExtension
- re-implemented ExtensionOption (called by the extension() function)
- redid _ExtensionCarrier to function dynamically based on __getattribute__
- added logging to attributes package, indicating the execution of a lazy callable
- going to close [ticket:329]
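The new get(), get_by() and load() hooks on MapperExtension behave like the existing select()/select_by() hooks: Query consults the extension chain first and only runs its default logic when the hook returns EXT_PASS. A minimal sketch of one way such a hook could be used (the CachingExtension class and its cache dict are hypothetical, invented for illustration; only MapperExtension, EXT_PASS and the hook signatures come from this changeset):

    from sqlalchemy.orm.mapper import MapperExtension, EXT_PASS

    class CachingExtension(MapperExtension):
        """hypothetical extension: answer Query.get() from a local dict."""
        def __init__(self):
            self.cache = {}

        def get(self, query, ident, **kwargs):
            # anything other than EXT_PASS becomes the result of query.get()
            if ident in self.cache:
                return self.cache[ident]
            # EXT_PASS tells the Query to fall through to its normal get() logic
            return EXT_PASS

An instance of such an extension would typically be attached via the mapper's extension argument, or per-query through the extension() option re-implemented in this changeset.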
Diffstat (limited to 'lib/sqlalchemy/orm')
-rw-r--r--  lib/sqlalchemy/orm/attributes.py   757
-rw-r--r--  lib/sqlalchemy/orm/mapper.py       154
-rw-r--r--  lib/sqlalchemy/orm/properties.py     3
-rw-r--r--  lib/sqlalchemy/orm/query.py          9
-rw-r--r--  lib/sqlalchemy/orm/strategies.py     7
-rw-r--r--  lib/sqlalchemy/orm/sync.py           4
-rw-r--r--  lib/sqlalchemy/orm/unitofwork.py     3
7 files changed, 855 insertions, 82 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
new file mode 100644
index 000000000..f5eb669f3
--- /dev/null
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -0,0 +1,757 @@
+# attributes.py - manages object attributes
+# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from sqlalchemy import util
+import util as orm_util
+import weakref
+from sqlalchemy import logging
+
+class InstrumentedAttribute(object):
+ """a property object that instruments attribute access on object instances. All methods correspond to
+ a single attribute on a particular class."""
+
+ PASSIVE_NORESULT = object()
+
+ def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
+ self.manager = manager
+ self.key = key
+ self.uselist = uselist
+ self.callable_ = callable_
+ self.typecallable= typecallable
+ self.trackparent = trackparent
+ self.mutable_scalars = mutable_scalars
+ if copy_function is None:
+ if uselist:
+ self._copyfunc = lambda x: [y for y in x]
+ else:
+ # scalar values are assumed to be immutable unless a copy function
+ # is passed
+ self._copyfunc = lambda x: x
+ else:
+ self._copyfunc = copy_function
+ if compare_function is None:
+ self._compare_function = lambda x,y: x == y
+ else:
+ self._compare_function = compare_function
+ self.extensions = util.to_list(extension or [])
+
+ def __set__(self, obj, value):
+ self.set(None, obj, value)
+ def __delete__(self, obj):
+ self.delete(None, obj)
+ def __get__(self, obj, owner):
+ if obj is None:
+ return self
+ return self.get(obj)
+
+ def is_equal(self, x, y):
+ return self._compare_function(x, y)
+ def copy(self, value):
+ return self._copyfunc(value)
+
+ def check_mutable_modified(self, obj):
+ if self.mutable_scalars:
+ h = self.get_history(obj, passive=True)
+ if h is not None and h.is_modified():
+ obj._state['modified'] = True
+ return True
+ else:
+ return False
+ else:
+ return False
+
+
+ def hasparent(self, item, optimistic=False):
+ """return the boolean value of a "hasparent" flag attached to the given item.
+
+ the 'optimistic' flag determines what the default return value should be if
+ no "hasparent" flag can be located. as this function is used to determine if
+ an instance is an "orphan", instances that were loaded from storage should be assumed
+ to not be orphans, until a True/False value for this flag is set. an instance attribute
+ that is loaded by a callable function will also not have a "hasparent" flag.
+ """
+ return item._state.get(('hasparent', id(self)), optimistic)
+
+ def sethasparent(self, item, value):
+ """sets a boolean flag on the given item corresponding to whether or not it is
+ attached to a parent object via the attribute represented by this InstrumentedAttribute."""
+ item._state[('hasparent', id(self))] = value
+
+ def get_history(self, obj, passive=False):
+ """return a new AttributeHistory object for the given object/this attribute's key.
+
+ if passive is True, then dont execute any callables; if the attribute's value
+ can only be achieved via executing a callable, then return None."""
+ # get the current state. this may trigger a lazy load if
+ # passive is False.
+ current = self.get(obj, passive=passive, raiseerr=False)
+ if current is InstrumentedAttribute.PASSIVE_NORESULT:
+ return None
+ return AttributeHistory(self, obj, current, passive=passive)
+
+ def set_callable(self, obj, callable_):
+ """set a callable function for this attribute on the given object.
+
+ this callable will be executed when the attribute is next accessed,
+ and is assumed to construct part of the instance's previously stored state. When
+ its value or values are loaded, they will be established as part of the
+ instance's "committed state". while "trackparent" information will be assembled
+ for these instances, attribute-level event handlers will not be fired.
+
+ the callable overrides the class level callable set in the InstrumentedAttribute
+ constructor.
+ """
+ if callable_ is None:
+ self.initialize(obj)
+ else:
+ obj._state[('callable', self)] = callable_
+
+ def reset(self, obj):
+ """removes any per-instance callable functions corresponding to this InstrumentedAttribute's attribute
+ from the given object, and removes this InstrumentedAttribute's
+ attribute from the given object's dictionary."""
+ try:
+ del obj._state[('callable', self)]
+ except KeyError:
+ pass
+ self.clear(obj)
+
+ def clear(self, obj):
+ """removes this InstrumentedAttribute's attribute from the given object's dictionary. subsequent calls to
+ getattr(obj, key) will raise an AttributeError by default."""
+ try:
+ del obj.__dict__[self.key]
+ except KeyError:
+ pass
+
+ def _get_callable(self, obj):
+ if obj._state.has_key(('callable', self)):
+ return obj._state[('callable', self)]
+ elif self.callable_ is not None:
+ return self.callable_(obj)
+ else:
+ return None
+
+ def _blank_list(self):
+ if self.typecallable is not None:
+ return self.typecallable()
+ else:
+ return []
+
+ def _adapt_list(self, data):
+ if self.typecallable is not None:
+ t = self.typecallable()
+ if data is not None:
+ [t.append(x) for x in data]
+ return t
+ else:
+ return data
+
+ def initialize(self, obj):
+ """initialize this attribute on the given object instance.
+
+ if this is a list-based attribute, a new, blank list will be created.
+ if a scalar attribute, the value will be initialized to None."""
+ if self.uselist:
+ l = InstrumentedList(self, obj, self._blank_list())
+ obj.__dict__[self.key] = l
+ return l
+ else:
+ obj.__dict__[self.key] = None
+ return None
+
+ def get(self, obj, passive=False, raiseerr=True):
+ """retrieves a value from the given object. if a callable is assembled
+ on this object's attribute, and passive is False, the callable will be executed
+ and the resulting value will be set as the new value for this attribute."""
+ try:
+ return obj.__dict__[self.key]
+ except KeyError:
+ state = obj._state
+ # if an instance-wide "trigger" was set, call that
+ # and start again
+ if state.has_key('trigger'):
+ trig = state['trigger']
+ del state['trigger']
+ trig()
+ return self.get(obj, passive=passive, raiseerr=raiseerr)
+
+ if self.uselist:
+ callable_ = self._get_callable(obj)
+ if callable_ is not None:
+ if passive:
+ return InstrumentedAttribute.PASSIVE_NORESULT
+ self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
+ values = callable_()
+ l = InstrumentedList(self, obj, self._adapt_list(values), init=False)
+
+ # if a callable was executed, then its part of the "committed state"
+ # if any, so commit the newly loaded data
+ orig = state.get('original', None)
+ if orig is not None:
+ orig.commit_attribute(self, obj, l)
+
+ else:
+ # note that we arent raising AttributeErrors, just creating a new
+ # blank list and setting it.
+ # this might be a good thing to be changeable by options.
+ l = InstrumentedList(self, obj, self._blank_list(), init=False)
+ obj.__dict__[self.key] = l
+ return l
+ else:
+ callable_ = self._get_callable(obj)
+ if callable_ is not None:
+ if passive:
+ return InstrumentedAttribute.PASSIVE_NORESULT
+ self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
+ value = callable_()
+ obj.__dict__[self.key] = value
+
+ # if a callable was executed, then its part of the "committed state"
+ # if any, so commit the newly loaded data
+ orig = state.get('original', None)
+ if orig is not None:
+ orig.commit_attribute(self, obj)
+ return obj.__dict__[self.key]
+ else:
+ # note that we arent raising AttributeErrors, just returning None.
+ # this might be a good thing to be changeable by options.
+ return None
+
+ def set(self, event, obj, value):
+ """sets a value on the given object. 'event' is the InstrumentedAttribute that
+ initiated the set() operation and is used to control the depth of a circular setter
+ operation."""
+ if event is not self:
+ state = obj._state
+ # if an instance-wide "trigger" was set, call that
+ if state.has_key('trigger'):
+ trig = state['trigger']
+ del state['trigger']
+ trig()
+ if self.uselist:
+ value = InstrumentedList(self, obj, value)
+ old = self.get(obj)
+ obj.__dict__[self.key] = value
+ state['modified'] = True
+ if not self.uselist:
+ if self.trackparent:
+ if value is not None:
+ self.sethasparent(value, True)
+ if old is not None:
+ self.sethasparent(old, False)
+ for ext in self.extensions:
+ ext.set(event or self, obj, value, old)
+ else:
+ # mark all the old elements as detached from the parent
+ old.list_replaced()
+
+ def delete(self, event, obj):
+ """deletes a value from the given object. 'event' is the InstrumentedAttribute that
+ initiated the delete() operation and is used to control the depth of a circular delete
+ operation."""
+ if event is not self:
+ try:
+ if not self.uselist and (self.trackparent or len(self.extensions)):
+ old = self.get(obj)
+ del obj.__dict__[self.key]
+ except KeyError:
+ # TODO: raise this? not consistent with get() ?
+ raise AttributeError(self.key)
+ obj._state['modified'] = True
+ if not self.uselist:
+ if self.trackparent:
+ if old is not None:
+ self.sethasparent(old, False)
+ for ext in self.extensions:
+ ext.delete(event or self, obj, old)
+
+ def append(self, event, obj, value):
+ """appends an element to a list based element or sets a scalar based element to the given value.
+ Used by GenericBackrefExtension to "append" an item independent of list/scalar semantics.
+ 'event' is the InstrumentedAttribute that initiated the append() operation and is used to control
+ the depth of a circular append operation."""
+ if self.uselist:
+ if event is not self:
+ self.get(obj).append_with_event(value, event)
+ else:
+ self.set(event, obj, value)
+
+ def remove(self, event, obj, value):
+ """removes an element from a list based element or sets a scalar based element to None.
+ Used by GenericBackrefExtension to "remove" an item independent of list/scalar semantics.
+ 'event' is the InstrumentedAttribute that initiated the remove() operation and is used to control
+ the depth of a circular remove operation."""
+ if self.uselist:
+ if event is not self:
+ self.get(obj).remove_with_event(value, event)
+ else:
+ self.set(event, obj, None)
+
+ def append_event(self, event, obj, value):
+ """called by InstrumentedList when an item is appended"""
+ obj._state['modified'] = True
+ if self.trackparent and value is not None:
+ self.sethasparent(value, True)
+ for ext in self.extensions:
+ ext.append(event or self, obj, value)
+
+ def remove_event(self, event, obj, value):
+ """called by InstrumentedList when an item is removed"""
+ obj._state['modified'] = True
+ if self.trackparent and value is not None:
+ self.sethasparent(value, False)
+ for ext in self.extensions:
+ ext.delete(event or self, obj, value)
+InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute)
+
+class InstrumentedList(object):
+ """instruments a list-based attribute. all mutator operations (i.e. append, remove, etc.) will fire off events to the
+ InstrumentedAttribute that manages the object's attribute. those events in turn trigger things like
+ backref operations and whatever is implemented by do_list_value_changed on InstrumentedAttribute.
+
+ note that this list does a lot less than earlier versions of SA list-based attributes, which used HistoryArraySet.
+ this list wrapper does *not* maintain setlike semantics, meaning you can add as many duplicates as
+ you want (which can break a lot of SQL), and also does not do anything related to history tracking.
+
+ Please see ticket #213 for information on the future of this class, where it will be broken out into more
+ collection-specific subtypes."""
+ def __init__(self, attr, obj, data, init=True):
+ self.attr = attr
+ # this weakref is to prevent circular references between the parent object
+ # and the list attribute, which interferes with immediate garbage collection.
+ self.__obj = weakref.ref(obj)
+ self.key = attr.key
+ self.data = data or attr._blank_list()
+
+ # adapt to lists or sets
+ # TODO: make three subclasses of InstrumentedList that come off from a
+ # metaclass, based on the type of data sent in
+ if hasattr(self.data, 'append'):
+ self._data_appender = self.data.append
+ self._clear_data = self._clear_list
+ elif hasattr(self.data, 'add'):
+ self._data_appender = self.data.add
+ self._clear_data = self._clear_set
+ if isinstance(self.data, dict):
+ self._clear_data = self._clear_dict
+
+ if init:
+ for x in self.data:
+ self.__setrecord(x)
+
+ def list_replaced(self):
+ """fires off delete event handlers for each item in the list but
+ doesnt affect the original data list"""
+ [self.__delrecord(x) for x in self.data]
+
+ def clear(self):
+ """clears all items in this InstrumentedList and fires off delete event handlers for each item"""
+ self._clear_data()
+ def _clear_dict(self):
+ [self.__delrecord(x) for x in self.data.values()]
+ self.data.clear()
+ def _clear_set(self):
+ [self.__delrecord(x) for x in self.data]
+ self.data.clear()
+ def _clear_list(self):
+ self[:] = []
+
+ def __getstate__(self):
+ """implemented to allow pickling, since __obj is a weakref, also the InstrumentedAttribute has callables
+ attached to it"""
+ return {'key':self.key, 'obj':self.obj, 'data':self.data}
+ def __setstate__(self, d):
+ """implemented to allow pickling, since __obj is a weakref, also the InstrumentedAttribute has callables
+ attached to it"""
+ self.key = d['key']
+ self.__obj = weakref.ref(d['obj'])
+ self.data = d['data']
+ self.attr = getattr(d['obj'].__class__, self.key)
+
+ obj = property(lambda s:s.__obj())
+
+ def unchanged_items(self):
+ """deprecated"""
+ return self.attr.get_history(self.obj).unchanged_items
+ def added_items(self):
+ """deprecated"""
+ return self.attr.get_history(self.obj).added_items
+ def deleted_items(self):
+ """deprecated"""
+ return self.attr.get_history(self.obj).deleted_items
+
+ def __iter__(self):
+ return iter(self.data)
+ def __repr__(self):
+ return repr(self.data)
+
+ def __getattr__(self, attr):
+ """proxies unknown methods and attributes to the underlying
+ data array. this allows custom list classes to be used."""
+ return getattr(self.data, attr)
+
+ def __setrecord(self, item, event=None):
+ self.attr.append_event(event, self.obj, item)
+ return True
+
+ def __delrecord(self, item, event=None):
+ self.attr.remove_event(event, self.obj, item)
+ return True
+
+ def append_with_event(self, item, event):
+ self.__setrecord(item, event)
+ self._data_appender(item)
+
+ def append_without_event(self, item):
+ self._data_appender(item)
+
+ def remove_with_event(self, item, event):
+ self.__delrecord(item, event)
+ self.data.remove(item)
+
+ def append(self, item, _mapper_nohistory=False):
+ """fires off dependent events, and appends the given item to the underlying list.
+ _mapper_nohistory is a backwards compatibility hack; call append_without_event instead."""
+ if _mapper_nohistory:
+ self.append_without_event(item)
+ else:
+ self.__setrecord(item)
+ self._data_appender(item)
+
+
+ def __getitem__(self, i):
+ return self.data[i]
+ def __setitem__(self, i, item):
+ if isinstance(i, slice):
+ self.__setslice__(i.start, i.stop, item)
+ else:
+ self.__setrecord(item)
+ self.data[i] = item
+ def __delitem__(self, i):
+ if isinstance(i, slice):
+ self.__delslice__(i.start, i.stop)
+ else:
+ self.__delrecord(self.data[i], None)
+ del self.data[i]
+
+ def __lt__(self, other): return self.data < self.__cast(other)
+ def __le__(self, other): return self.data <= self.__cast(other)
+ def __eq__(self, other): return self.data == self.__cast(other)
+ def __ne__(self, other): return self.data != self.__cast(other)
+ def __gt__(self, other): return self.data > self.__cast(other)
+ def __ge__(self, other): return self.data >= self.__cast(other)
+ def __cast(self, other):
+ if isinstance(other, InstrumentedList): return other.data
+ else: return other
+ def __cmp__(self, other):
+ return cmp(self.data, self.__cast(other))
+ def __contains__(self, item): return item in self.data
+ def __len__(self): return len(self.data)
+ def __setslice__(self, i, j, other):
+ i = max(i, 0); j = max(j, 0)
+ [self.__delrecord(x) for x in self.data[i:]]
+ g = [a for a in list(other) if self.__setrecord(a)]
+ self.data[i:] = g
+ def __delslice__(self, i, j):
+ i = max(i, 0); j = max(j, 0)
+ for a in self.data[i:j]:
+ self.__delrecord(a)
+ del self.data[i:j]
+ def insert(self, i, item):
+ if self.__setrecord(item):
+ self.data.insert(i, item)
+ def pop(self, i=-1):
+ item = self.data[i]
+ self.__delrecord(item)
+ return self.data.pop(i)
+ def remove(self, item):
+ self.__delrecord(item)
+ self.data.remove(item)
+ def extend(self, item_list):
+ for item in item_list:
+ self.append(item)
+ def __add__(self, other):
+ raise NotImplementedError()
+ def __radd__(self, other):
+ raise NotImplementedError()
+ def __iadd__(self, other):
+ raise NotImplementedError()
+
+class AttributeExtension(object):
+ """an abstract class which specifies "append", "delete", and "set"
+ event handlers to be attached to an object property."""
+ def append(self, event, obj, child):
+ pass
+ def delete(self, event, obj, child):
+ pass
+ def set(self, event, obj, child, oldchild):
+ pass
+
+class GenericBackrefExtension(AttributeExtension):
+ """an extension which synchronizes a two-way relationship. A typical two-way
+ relationship is a parent object containing a list of child objects, where each
+ child object references the parent. The other form is two objects which contain
+ scalar references to each other."""
+ def __init__(self, key):
+ self.key = key
+ def set(self, event, obj, child, oldchild):
+ if oldchild is child:
+ return
+ if oldchild is not None:
+ getattr(oldchild.__class__, self.key).remove(event, oldchild, obj)
+ if child is not None:
+ getattr(child.__class__, self.key).append(event, child, obj)
+ def append(self, event, obj, child):
+ getattr(child.__class__, self.key).append(event, child, obj)
+ def delete(self, event, obj, child):
+ getattr(child.__class__, self.key).remove(event, child, obj)
+
+class CommittedState(object):
+ """stores the original state of an object when the commit() method on the attribute manager
+ is called."""
+ NO_VALUE = object()
+
+ def __init__(self, manager, obj):
+ self.data = {}
+ for attr in manager.managed_attributes(obj.__class__):
+ self.commit_attribute(attr, obj)
+
+ def commit_attribute(self, attr, obj, value=NO_VALUE):
+ """establish the value of attribute 'attr' on instance 'obj' as "committed".
+
+ this corresponds to a previously saved state being restored. """
+ if value is CommittedState.NO_VALUE:
+ if obj.__dict__.has_key(attr.key):
+ value = obj.__dict__[attr.key]
+ if value is not CommittedState.NO_VALUE:
+ self.data[attr.key] = attr.copy(value)
+
+ # not tracking parent on lazy-loaded instances at the moment.
+ # its not needed since they will be "optimistically" tested
+ #if attr.uselist:
+ #if attr.trackparent:
+ # [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None]
+ #else:
+ #if attr.trackparent and value is not None:
+ # attr.sethasparent(value, True)
+
+ def rollback(self, manager, obj):
+ for attr in manager.managed_attributes(obj.__class__):
+ if self.data.has_key(attr.key):
+ if attr.uselist:
+ obj.__dict__[attr.key][:] = self.data[attr.key]
+ else:
+ obj.__dict__[attr.key] = self.data[attr.key]
+ else:
+ del obj.__dict__[attr.key]
+
+ def __repr__(self):
+ return "CommittedState: %s" % repr(self.data)
+
+class AttributeHistory(object):
+ """calculates the "history" of a particular attribute on a particular instance, based on the CommittedState
+ associated with the instance, if any."""
+ def __init__(self, attr, obj, current, passive=False):
+ self.attr = attr
+
+ # get the "original" value. if a lazy load was fired when we got
+ # the 'current' value, this "original" was also populated just
+ # now as well (therefore we have to get it second)
+ orig = obj._state.get('original', None)
+ if orig is not None:
+ original = orig.data.get(attr.key)
+ else:
+ original = None
+
+ if attr.uselist:
+ self._current = current
+ else:
+ self._current = [current]
+ if attr.uselist:
+ s = util.Set(original or [])
+ self._added_items = []
+ self._unchanged_items = []
+ self._deleted_items = []
+ if current:
+ for a in current:
+ if a in s:
+ self._unchanged_items.append(a)
+ else:
+ self._added_items.append(a)
+ for a in s:
+ if a not in self._unchanged_items:
+ self._deleted_items.append(a)
+ else:
+ if attr.is_equal(current, original):
+ self._unchanged_items = [current]
+ self._added_items = []
+ self._deleted_items = []
+ else:
+ self._added_items = [current]
+ if original is not None:
+ self._deleted_items = [original]
+ else:
+ self._deleted_items = []
+ self._unchanged_items = []
+ #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
+ def __iter__(self):
+ return iter(self._current)
+ def is_modified(self):
+ return len(self._deleted_items) > 0 or len(self._added_items) > 0
+ def added_items(self):
+ return self._added_items
+ def unchanged_items(self):
+ return self._unchanged_items
+ def deleted_items(self):
+ return self._deleted_items
+ def hasparent(self, obj):
+ """deprecated. this should be called directly from the appropriate InstrumentedAttribute object."""
+ return self.attr.hasparent(obj)
+
+class AttributeManager(object):
+ """allows the instrumentation of object attributes. AttributeManager is stateless, but can be
+ overridden by subclasses to redefine some of its factory operations."""
+
+ def rollback(self, *obj):
+ """retrieves the committed history for each object in the given list, and rolls back the attributes
+ each instance to their original value."""
+ for o in obj:
+ orig = o._state.get('original')
+ if orig is not None:
+ orig.rollback(self, o)
+ else:
+ self._clear(o)
+
+ def _clear(self, obj):
+ for attr in self.managed_attributes(obj.__class__):
+ try:
+ del obj.__dict__[attr.key]
+ except KeyError:
+ pass
+
+ def commit(self, *obj):
+ """creates a CommittedState instance for each object in the given list, representing
+ its "unchanged" state, and associates it with the instance. AttributeHistory objects
+ will indicate the modified state of instance attributes as compared to its value in this
+ CommittedState object."""
+ for o in obj:
+ o._state['original'] = CommittedState(self, o)
+ o._state['modified'] = False
+
+ def managed_attributes(self, class_):
+ """returns an iterator of all InstrumentedAttribute objects associated with the given class."""
+ if not isinstance(class_, type):
+ raise repr(class_) + " is not a type"
+ for key in dir(class_):
+ value = getattr(class_, key, None)
+ if isinstance(value, InstrumentedAttribute):
+ yield value
+
+ def noninherited_managed_attributes(self, class_):
+ if not isinstance(class_, type):
+ raise repr(class_) + " is not a type"
+ for key in list(class_.__dict__):
+ value = getattr(class_, key, None)
+ if isinstance(value, InstrumentedAttribute):
+ yield value
+
+ def is_modified(self, object):
+ for attr in self.managed_attributes(object.__class__):
+ if attr.check_mutable_modified(object):
+ return True
+ return object._state.get('modified', False)
+
+ def init_attr(self, obj):
+ """sets up the __sa_attr_state dictionary on the given instance. This dictionary is
+ automatically created when the '_state' attribute of the class is first accessed, but calling
+ it here will save a single throw of an AttributeError that occurs in that creation step."""
+ setattr(obj, '_%s__sa_attr_state' % obj.__class__.__name__, {})
+
+ def get_history(self, obj, key, **kwargs):
+ """returns a new AttributeHistory object for the given attribute on the given object."""
+ return getattr(obj.__class__, key).get_history(obj, **kwargs)
+
+ def get_as_list(self, obj, key, passive=False):
+ """returns an attribute of the given name from the given object. if the attribute
+ is a scalar, returns it as a single-item list, otherwise returns the list based attribute.
+ if the attribute's value is to be produced by an unexecuted callable,
+ the callable will only be executed if the given 'passive' flag is False.
+ """
+ attr = getattr(obj.__class__, key)
+ x = attr.get(obj, passive=passive)
+ if x is InstrumentedAttribute.PASSIVE_NORESULT:
+ return []
+ elif attr.uselist:
+ return x
+ else:
+ return [x]
+
+ def trigger_history(self, obj, callable):
+ """clears all managed object attributes and places the given callable
+ as an attribute-wide "trigger", which will execute upon the next attribute access, after
+ which the trigger is removed."""
+ self._clear(obj)
+ try:
+ del obj._state['original']
+ except KeyError:
+ pass
+ obj._state['trigger'] = callable
+
+ def untrigger_history(self, obj):
+ """removes a trigger function set by trigger_history. does not restore the previous state of the object."""
+ del obj._state['trigger']
+
+ def has_trigger(self, obj):
+ """returns True if the given object has a trigger function set by trigger_history()."""
+ return obj._state.has_key('trigger')
+
+ def reset_instance_attribute(self, obj, key):
+ """removes any per-instance callable functions corresponding to given attribute key
+ from the given object, and removes this attribute from the given object's dictionary."""
+ attr = getattr(obj.__class__, key)
+ attr.reset(obj)
+
+ def reset_class_managed(self, class_):
+ """removes all InstrumentedAttribute property objects from the given class."""
+ for attr in self.noninherited_managed_attributes(class_):
+ delattr(class_, attr.key)
+
+ def is_class_managed(self, class_, key):
+ """returns True if the given key correponds to an instrumented property on the given class."""
+ return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute)
+
+ def init_instance_attribute(self, obj, key, uselist, callable_=None, **kwargs):
+ """initializes an attribute on an instance to either a blank value, cancelling
+ out any class- or instance-level callables that were present, or if a callable
+ is supplied sets the callable to be invoked when the attribute is next accessed."""
+ getattr(obj.__class__, key).set_callable(obj, callable_)
+
+ def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs):
+ """creates a scalar property object, defaulting to InstrumentedAttribute, which
+ will communicate change events back to this AttributeManager."""
+ return InstrumentedAttribute(self, key, uselist, callable_, typecallable, **kwargs)
+
+ def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
+ """registers an attribute at the class level to be instrumented for all instances
+ of the class."""
+ #print self, "register attribute", key, "for class", class_
+ if not hasattr(class_, '_state'):
+ def _get_state(self):
+ try:
+ return self.__sa_attr_state
+ except AttributeError:
+ self.__sa_attr_state = {}
+ return self.__sa_attr_state
+ class_._state = property(_get_state)
+
+ typecallable = kwargs.pop('typecallable', None)
+ if typecallable is None:
+ typecallable = getattr(class_, key, None)
+ if isinstance(typecallable, InstrumentedAttribute):
+ typecallable = None
+ setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs))
+
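For reference, a minimal sketch exercising the attribute instrumentation added above directly, outside of any mapper. The User class, attribute names and values are placeholders; register_attribute(), commit(), is_modified(), get_history() and init_instance_attribute() are the calls defined in this module:

    from sqlalchemy.orm import attributes

    class User(object):
        pass

    manager = attributes.AttributeManager()
    manager.register_attribute(User, 'name', uselist=False)
    manager.register_attribute(User, 'addresses', uselist=True)

    u = User()
    u.name = 'ed'
    u.addresses.append('ed@example.com')

    manager.commit(u)                  # snapshot the "committed" state
    u.name = 'edward'

    assert manager.is_modified(u)
    hist = manager.get_history(u, 'name')
    assert hist.added_items() == ['edward']
    assert hist.deleted_items() == ['ed']

    # per-instance lazy callable: executed on the next attribute access,
    # with the new debug-level log line announcing the execution
    u2 = User()
    manager.init_instance_attribute(u2, 'name', uselist=False,
                                    callable_=lambda: 'lazily loaded')
    assert u2.name == 'lazily loaded'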
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index b3dbe7a92..1327644b2 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -21,7 +21,7 @@ mapper_registry = weakref.WeakKeyDictionary()
# a list of MapperExtensions that will be installed in all mappers by default
global_extensions = []
-# a constant returned by _getattrbycolumn to indicate
+# a constant returned by get_attr_by_column to indicate
# this mapper is not handling an attribute for a particular
# column
NO_ATTRIBUTE = object()
@@ -289,10 +289,6 @@ class Mapper(object):
def _compile_extensions(self):
"""goes through the global_extensions list as well as the list of MapperExtensions
specified for this Mapper and creates a linked list of those extensions."""
- # uber-pendantic style of making mapper chain, as various testbase/
- # threadlocal/assignmapper combinations keep putting dupes etc. in the list
- # TODO: do something that isnt 21 lines....
-
extlist = util.Set()
for ext_class in global_extensions:
if isinstance(ext_class, MapperExtension):
@@ -307,7 +303,7 @@ class Mapper(object):
self.extension = _ExtensionCarrier()
for ext in extlist:
- self.extension.elements.append(ext)
+ self.extension.append(ext)
def _compile_inheritance(self):
"""determines if this Mapper inherits from another mapper, and if so calculates the mapped_table
@@ -744,7 +740,7 @@ class Mapper(object):
def primary_key_from_instance(self, instance):
"""return the list of primary key values for the given instance."""
- return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.mapped_table]]
+ return [self.get_attr_by_column(instance, column) for column in self.pks_by_table[self.mapped_table]]
def instance_key(self, instance):
"""deprecated. a synonym for identity_key_from_instance."""
@@ -773,21 +769,27 @@ class Mapper(object):
raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
return prop[0]
- def _getattrbycolumn(self, obj, column, raiseerror=True):
+ def get_attr_by_column(self, obj, column, raiseerror=True):
+ """return an instance attribute using a Column as the key."""
prop = self._getpropbycolumn(column, raiseerror)
if prop is None:
return NO_ATTRIBUTE
#self.__log_debug("get column attribute '%s' from instance %s" % (column.key, mapperutil.instance_str(obj)))
return prop.getattr(obj)
- def _setattrbycolumn(self, obj, column, value):
+ def set_attr_by_column(self, obj, column, value):
+ """set the value of an instance attribute using a Column as the key."""
self.columntoproperty[column][0].setattr(obj, value)
def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
- """save a list of objects.
+ """issue INSERT and/or UPDATE statements for a list of objects.
+
+ this is called within the context of a UOWTransaction during a flush operation.
- this method is called within a unit of work flush() process. It saves objects that are mapped not just
- by this mapper, but inherited mappers as well, so that insert ordering of polymorphic objects is maintained."""
+ save_obj issues SQL statements not just for instances mapped directly by this mapper, but
+ for instances mapped by all inheriting mappers as well. This is to maintain proper insert
+ ordering among a polymorphic chain of instances. Therefore save_obj is typically
+ called only on a "base mapper", or a mapper which does not inherit from any other mapper."""
self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched"))
@@ -848,7 +850,7 @@ class Mapper(object):
for col in table.columns:
if col is mapper.version_id_col:
if not isinsert:
- params[col._label] = mapper._getattrbycolumn(obj, col)
+ params[col._label] = mapper.get_attr_by_column(obj, col)
params[col.key] = params[col._label] + 1
else:
params[col.key] = 1
@@ -857,14 +859,14 @@ class Mapper(object):
if not isinsert:
# doing an UPDATE? put primary key values as "WHERE" parameters
# matching the bindparam we are creating below, i.e. "<tablename>_<colname>"
- params[col._label] = mapper._getattrbycolumn(obj, col)
+ params[col._label] = mapper.get_attr_by_column(obj, col)
else:
# doing an INSERT, primary key col ?
# if the primary key values are not populated,
# leave them out of the INSERT altogether, since PostGres doesn't want
# them to be present for SERIAL to take effect. A SQLEngine that uses
# explicit sequences will put them back in if they are needed
- value = mapper._getattrbycolumn(obj, col)
+ value = mapper.get_attr_by_column(obj, col)
if value is not None:
params[col.key] = value
elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
@@ -882,7 +884,7 @@ class Mapper(object):
if post_update_cols is not None and col not in post_update_cols:
continue
elif is_row_switch:
- params[col.key] = self._getattrbycolumn(obj, col)
+ params[col.key] = self.get_attr_by_column(obj, col)
hasdata = True
continue
prop = mapper._getpropbycolumn(col, False)
@@ -901,7 +903,7 @@ class Mapper(object):
# default. if its None and theres no default, we still might
# not want to put it in the col list but SQLIte doesnt seem to like that
# if theres no columns at all
- value = mapper._getattrbycolumn(obj, col, False)
+ value = mapper.get_attr_by_column(obj, col, False)
if value is NO_ATTRIBUTE:
continue
if col.default is None or value is not None:
@@ -955,8 +957,8 @@ class Mapper(object):
if primary_key is not None:
i = 0
for col in mapper.pks_by_table[table]:
- if mapper._getattrbycolumn(obj, col) is None and len(primary_key) > i:
- mapper._setattrbycolumn(obj, col, primary_key[i])
+ if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i:
+ mapper.set_attr_by_column(obj, col, primary_key[i])
i+=1
mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
@@ -987,26 +989,29 @@ class Mapper(object):
if resultproxy.lastrow_has_defaults():
clause = sql.and_()
for p in self.pks_by_table[table]:
- clause.clauses.append(p == self._getattrbycolumn(obj, p))
+ clause.clauses.append(p == self.get_attr_by_column(obj, p))
row = connection.execute(table.select(clause), None).fetchone()
for c in table.c:
- if self._getattrbycolumn(obj, c, False) is None:
- self._setattrbycolumn(obj, c, row[c])
+ if self.get_attr_by_column(obj, c, False) is None:
+ self.set_attr_by_column(obj, c, row[c])
else:
for c in table.c:
if c.primary_key or not params.has_key(c.name):
continue
- v = self._getattrbycolumn(obj, c, False)
+ v = self.get_attr_by_column(obj, c, False)
if v is NO_ATTRIBUTE:
continue
elif v != params.get_original(c.name):
- self._setattrbycolumn(obj, c, params.get_original(c.name))
+ self.set_attr_by_column(obj, c, params.get_original(c.name))
def delete_obj(self, objects, uowtransaction):
- """called by a UnitOfWork object to delete objects, which involves a
- DELETE statement for each table used by this mapper, for each object in the list."""
+ """issue DELETE statements for a list of objects.
+
+ this is called within the context of a UOWTransaction during a flush operation."""
+
+ self.__log_debug("delete_obj() start")
+
connection = uowtransaction.transaction.connection(self)
- #print "DELETE_OBJ MAPPER", self.class_.__name__, objects
[self.extension.before_delete(self, connection, obj) for obj in objects]
deleted_objects = util.Set()
@@ -1021,9 +1026,9 @@ class Mapper(object):
else:
delete.append(params)
for col in self.pks_by_table[table]:
- params[col.key] = self._getattrbycolumn(obj, col)
+ params[col.key] = self.get_attr_by_column(obj, col)
if self.version_id_col is not None:
- params[self.version_id_col.key] = self._getattrbycolumn(obj, self.version_id_col)
+ params[self.version_id_col.key] = self.get_attr_by_column(obj, self.version_id_col)
deleted_objects.add(obj)
if len(delete):
def comparator(a, b):
@@ -1123,8 +1128,8 @@ class Mapper(object):
instance = context.session._get(identitykey)
self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
isnew = False
- if context.version_check and self.version_id_col is not None and self._getattrbycolumn(instance, self.version_id_col) != row[self.version_id_col]:
- raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._getattrbycolumn(instance, self.version_id_col), row[self.version_id_col]))
+ if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
+ raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
if populate_existing or context.session.is_expired(instance, unexpire=True):
if not context.identity_map.has_key(identitykey):
@@ -1186,8 +1191,11 @@ class Mapper(object):
return obj
def translate_row(self, tomapper, row):
- """attempts to take a row and translate its values to a row that can
- be understood by another mapper."""
+ """translate the column keys of a row into a new or proxied row that
+ can be understood by another mapper.
+
+ This can be used in conjunction with populate_instance to populate
+ an instance using an alternate mapper."""
newrow = util.DictDecorator(row)
for c in tomapper.mapped_table.c:
c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True)
@@ -1195,9 +1203,11 @@ class Mapper(object):
newrow[c] = row[c2]
return newrow
- def populate_instance(self, selectcontext, instance, row, identitykey, isnew, frommapper=None):
- if frommapper is not None:
- row = frommapper.translate_row(self, row)
+ def populate_instance(self, selectcontext, instance, row, identitykey, isnew):
+ """populate an instance from a result row.
+
+ This method iterates through the list of MapperProperty objects attached to this Mapper
+ and calls each property's execute() method."""
for prop in self.__props.values():
prop.execute(selectcontext, instance, row, identitykey, isnew)
@@ -1213,6 +1223,24 @@ class MapperExtension(object):
Note: this is not called if a session is provided with the __init__ params (i.e. _sa_session)"""
return EXT_PASS
+ def load(self, query, *args, **kwargs):
+ """override the load method of the Query object.
+
+ the return value of this method is used as the result of query.load() if the
+ value is anything other than EXT_PASS."""
+ return EXT_PASS
+ def get(self, query, *args, **kwargs):
+ """override the get method of the Query object.
+
+ the return value of this method is used as the result of query.get() if the
+ value is anything other than EXT_PASS."""
+ return EXT_PASS
+ def get_by(self, query, *args, **kwargs):
+ """override the get_by method of the Query object.
+
+ the return value of this method is used as the result of query.get_by() if the
+ value is anything other than EXT_PASS."""
+ return EXT_PASS
def select_by(self, query, *args, **kwargs):
"""override the select_by method of the Query object.
@@ -1271,14 +1299,7 @@ class MapperExtension(object):
as relationships to other classes). If this method returns EXT_PASS, instance population
will proceed normally. If any other value or None is returned, instance population
will not proceed, giving this extension an opportunity to populate the instance itself,
- if desired..
-
- A common usage of this method is to have population performed by an alternate mapper. This can
- be acheived via the populate_instance() call on Mapper.
-
- def populate_instance(self, mapper, selectcontext, instance, row, identitykey, isnew):
- othermapper.populate_instance(selectcontext, instance, row, identitykey, isnew, frommapper=mapper)
- return None
+ if desired.
"""
return EXT_PASS
def before_insert(self, mapper, connection, instance):
@@ -1304,41 +1325,24 @@ class MapperExtension(object):
class _ExtensionCarrier(MapperExtension):
def __init__(self):
- self.elements = []
+ self.__elements = []
+ self.__callables = {}
def insert(self, extension):
"""insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
- self.elements.insert(0, extension)
+ self.__elements.insert(0, extension)
def append(self, extension):
"""append a MapperExtension at the end of this ExtensionCarrier's list."""
- self.elements.append(extension)
- # TODO: shrink down this approach using __getattribute__ or similar
- def get_session(self):
- return self._do('get_session')
- def select_by(self, *args, **kwargs):
- return self._do('select_by', *args, **kwargs)
- def select(self, *args, **kwargs):
- return self._do('select', *args, **kwargs)
- def create_instance(self, *args, **kwargs):
- return self._do('create_instance', *args, **kwargs)
- def append_result(self, *args, **kwargs):
- return self._do('append_result', *args, **kwargs)
- def populate_instance(self, *args, **kwargs):
- return self._do('populate_instance', *args, **kwargs)
- def before_insert(self, *args, **kwargs):
- return self._do('before_insert', *args, **kwargs)
- def before_update(self, *args, **kwargs):
- return self._do('before_update', *args, **kwargs)
- def after_update(self, *args, **kwargs):
- return self._do('after_update', *args, **kwargs)
- def after_insert(self, *args, **kwargs):
- return self._do('after_insert', *args, **kwargs)
- def before_delete(self, *args, **kwargs):
- return self._do('before_delete', *args, **kwargs)
- def after_delete(self, *args, **kwargs):
- return self._do('after_delete', *args, **kwargs)
-
+ self.__elements.append(extension)
+ def __getattribute__(self, key):
+ if key in MapperExtension.__dict__:
+ try:
+ return self.__callables[key]
+ except KeyError:
+ return self.__callables.setdefault(key, lambda *args, **kwargs:self._do(key, *args, **kwargs))
+ else:
+ return super(_ExtensionCarrier, self).__getattribute__(key)
def _do(self, funcname, *args, **kwargs):
- for elem in self.elements:
+ for elem in self.__elements:
if elem is self:
raise exceptions.AssertionError("ExtensionCarrier set to itself")
ret = getattr(elem, funcname)(*args, **kwargs)
@@ -1346,7 +1350,7 @@ class _ExtensionCarrier(MapperExtension):
return ret
else:
return EXT_PASS
-
+
class ExtensionOption(MapperExtension):
def __init__(self, ext):
self.ext = ext
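_ExtensionCarrier previously hand-wrote one forwarding method per hook; it now resolves any MapperExtension method name dynamically in __getattribute__, caching a small dispatcher per name, while _do() walks the extension list until something other than EXT_PASS comes back. A stripped-down, standalone sketch of the same pattern (the class and hook names here are invented for illustration, not SQLAlchemy's):

    EXT_PASS = object()

    class Hooks(object):
        """defines the hook names that the carrier proxies."""
        def before_insert(self, *args, **kwargs):
            return EXT_PASS
        def after_insert(self, *args, **kwargs):
            return EXT_PASS

    class Carrier(Hooks):
        """chain-of-responsibility proxy: any attribute naming a hook resolves
        to a cached callable which tries each element in turn."""
        def __init__(self, elements):
            self.__elements = elements
            self.__callables = {}

        def __getattribute__(self, key):
            if key in Hooks.__dict__ and not key.startswith('__'):
                try:
                    return self.__callables[key]
                except KeyError:
                    return self.__callables.setdefault(
                        key, lambda *args, **kwargs: self._do(key, *args, **kwargs))
            return super(Carrier, self).__getattribute__(key)

        def _do(self, funcname, *args, **kwargs):
            # the first extension returning something other than EXT_PASS wins
            for elem in self.__elements:
                ret = getattr(elem, funcname)(*args, **kwargs)
                if ret is not EXT_PASS:
                    return ret
            return EXT_PASS

    class Audit(Hooks):
        def before_insert(self, *args, **kwargs):
            return "audited"

    carrier = Carrier([Hooks(), Audit()])
    assert carrier.before_insert() == "audited"   # Audit answered
    assert carrier.after_insert() is EXT_PASS     # nobody answered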
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 768b51959..501903983 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -9,10 +9,11 @@ well as relationships. the objects rely upon the LoaderStrategy objects in the
module to handle load operations. PropertyLoader also relies upon the dependency.py module
to handle flush-time dependency sorting and processing."""
-from sqlalchemy import sql, schema, util, attributes, exceptions, sql_util, logging
+from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
import mapper
import sync
import strategies
+import attributes
import session as sessionlib
import dependency
import util as mapperutil
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 82ed4e1a0..607b37fd9 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -54,6 +54,9 @@ class Query(object):
The ident argument is a scalar or tuple of primary key column values
in the order of the table def's primary key columns."""
+ ret = self.extension.get(self, ident, **kwargs)
+ if ret is not mapper.EXT_PASS:
+ return ret
key = self.mapper.identity_key(ident)
return self._get(key, ident, **kwargs)
@@ -63,6 +66,9 @@ class Query(object):
If not found, raises an exception. The method will *remove all pending changes* to the object
already existing in the Session. The ident argument is a scalar or tuple of primary
key column values in the order of the table def's primary key columns."""
+ ret = self.extension.load(self, ident, **kwargs)
+ if ret is not mapper.EXT_PASS:
+ return ret
key = self.mapper.identity_key(ident)
instance = self._get(key, ident, reload=True, **kwargs)
if instance is None:
@@ -83,6 +89,9 @@ class Query(object):
e.g. u = usermapper.get_by(user_name = 'fred')
"""
+ ret = self.extension.get_by(self, *args, **params)
+ if ret is not mapper.EXT_PASS:
+ return ret
x = self.select_whereclause(self.join_by(*args, **params), limit=1)
if x:
return x[0]
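All three new Query entry points follow the same guard: hand the call to the extension chain first and fall through to the default implementation only when EXT_PASS comes back. A self-contained sketch of that pattern (QueryLike and StubExtension are illustrative stand-ins, not the real classes):

    from sqlalchemy.orm.mapper import MapperExtension, EXT_PASS

    class StubExtension(MapperExtension):
        """answers get() only for ident 5, passes on everything else."""
        def get(self, query, ident, **kwargs):
            if ident == 5:
                return 'stubbed instance'
            return EXT_PASS

    class QueryLike(object):
        """simplified stand-in for Query showing the hook-then-default guard."""
        def __init__(self, extension):
            self.extension = extension

        def get(self, ident, **kwargs):
            ret = self.extension.get(self, ident, **kwargs)
            if ret is not EXT_PASS:
                return ret                     # an extension answered
            return self._default_get(ident, **kwargs)

        def _default_get(self, ident, **kwargs):
            # stands in for the identity map lookup / SELECT in the real Query.get()
            return None

    q = QueryLike(StubExtension())
    assert q.get(5) == 'stubbed instance'      # served by the extension
    assert q.get(6) is None                    # EXT_PASS -> default path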
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 916a60d25..ac6b37e33 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -4,6 +4,8 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
+
from sqlalchemy import sql, schema, util, attributes, exceptions, sql_util, logging
import mapper, query
from interfaces import *
@@ -11,7 +13,6 @@ import session as sessionlib
import util as mapperutil
import sets, random
-"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
class ColumnLoader(LoaderStrategy):
def init(self):
@@ -73,7 +74,7 @@ class DeferredColumnLoader(LoaderStrategy):
clause = sql.and_()
for primary_key in pk:
- attr = self.parent._getattrbycolumn(instance, primary_key)
+ attr = self.parent.get_attr_by_column(instance, primary_key)
if not attr:
return None
clause.clauses.append(primary_key == attr)
@@ -178,7 +179,7 @@ class LazyLoader(AbstractRelationLoader):
return None
#print "setting up loader, lazywhere", str(self.lazywhere), "binds", self.lazybinds
for col, bind in self.lazybinds.iteritems():
- params[bind.key] = self.parent._getattrbycolumn(instance, col)
+ params[bind.key] = self.parent.get_attr_by_column(instance, col)
if params[bind.key] is None:
allparams = False
break
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 5f0331e16..6fe848a6f 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -132,7 +132,7 @@ class SyncRule(object):
if clearkeys or source is None:
value = None
else:
- value = self.source_mapper._getattrbycolumn(source, self.source_column)
+ value = self.source_mapper.get_attr_by_column(source, self.source_column)
if isinstance(dest, dict):
dest[self.dest_column.key] = value
else:
@@ -141,7 +141,7 @@ class SyncRule(object):
if logging.is_debug_enabled(self.logger):
self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value))
- self.dest_mapper._setattrbycolumn(dest, self.dest_column, value)
+ self.dest_mapper.set_attr_by_column(dest, self.dest_column, value)
SyncRule.logger = logging.class_logger(SyncRule)
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index c4fd92e36..815ffe029 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -14,8 +14,9 @@ an "identity map" pattern. The Unit of Work then maintains lists of objects tha
dirty, or deleted and provides the capability to flush all those changes at once.
"""
-from sqlalchemy import attributes, util, logging, topological
+from sqlalchemy import util, logging, topological
import sqlalchemy
+import attributes
from sqlalchemy.exceptions import *
import StringIO
import weakref