diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-10-22 00:24:26 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-10-22 00:24:26 +0000 |
| commit | bc240be3f87b41232671e4da7f59679744959154 (patch) | |
| tree | 4e978300d6e796f8a476b2e6ccf543ba2e9d881f /lib/sqlalchemy/orm | |
| parent | 97feb4dbeee3ef5bc50de667ec25a43d44a5ff2c (diff) | |
| download | sqlalchemy-bc240be3f87b41232671e4da7f59679744959154.tar.gz | |
- attributes module and test suite moves underneath 'orm' package
- fixed table comparison example in metadata.txt
- docstrings all over the place
- renamed mapper _getattrbycolumn/_setattrbycolumn to get_attr_by_column,set_attr_by_column
- removed frommapper parameter from populate_instance(). the two operations can be performed separately
- fix to examples/adjacencytree/byroot_tree.py to fire off lazy loaders upon load, to reduce query calling
- added get(), get_by(), load() to MapperExtension
- re-implemented ExtensionOption (called by extension() function)
- redid _ExtensionCarrier to function dynamically based on __getattribute__
- added logging to attributes package, indicating the execution of a lazy callable
- going to close [ticket:329]
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 757 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 154 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/sync.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 3 |
7 files changed, 855 insertions, 82 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py new file mode 100644 index 000000000..f5eb669f3 --- /dev/null +++ b/lib/sqlalchemy/orm/attributes.py @@ -0,0 +1,757 @@ +# attributes.py - manages object attributes +# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from sqlalchemy import util +import util as orm_util +import weakref +from sqlalchemy import logging + +class InstrumentedAttribute(object): + """a property object that instruments attribute access on object instances. All methods correspond to + a single attribute on a particular class.""" + + PASSIVE_NORESULT = object() + + def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): + self.manager = manager + self.key = key + self.uselist = uselist + self.callable_ = callable_ + self.typecallable= typecallable + self.trackparent = trackparent + self.mutable_scalars = mutable_scalars + if copy_function is None: + if uselist: + self._copyfunc = lambda x: [y for y in x] + else: + # scalar values are assumed to be immutable unless a copy function + # is passed + self._copyfunc = lambda x: x + else: + self._copyfunc = copy_function + if compare_function is None: + self._compare_function = lambda x,y: x == y + else: + self._compare_function = compare_function + self.extensions = util.to_list(extension or []) + + def __set__(self, obj, value): + self.set(None, obj, value) + def __delete__(self, obj): + self.delete(None, obj) + def __get__(self, obj, owner): + if obj is None: + return self + return self.get(obj) + + def is_equal(self, x, y): + return self._compare_function(x, y) + def copy(self, value): + return self._copyfunc(value) + + def check_mutable_modified(self, obj): + if self.mutable_scalars: + h = self.get_history(obj, passive=True) + if h is not None and h.is_modified(): + obj._state['modified'] = True + return True + else: + return False + else: + return False + + + def hasparent(self, item, optimistic=False): + """return the boolean value of a "hasparent" flag attached to the given item. + + the 'optimistic' flag determines what the default return value should be if + no "hasparent" flag can be located. as this function is used to determine if + an instance is an "orphan", instances that were loaded from storage should be assumed + to not be orphans, until a True/False value for this flag is set. an instance attribute + that is loaded by a callable function will also not have a "hasparent" flag. + """ + return item._state.get(('hasparent', id(self)), optimistic) + + def sethasparent(self, item, value): + """sets a boolean flag on the given item corresponding to whether or not it is + attached to a parent object via the attribute represented by this InstrumentedAttribute.""" + item._state[('hasparent', id(self))] = value + + def get_history(self, obj, passive=False): + """return a new AttributeHistory object for the given object/this attribute's key. + + if passive is True, then dont execute any callables; if the attribute's value + can only be achieved via executing a callable, then return None.""" + # get the current state. this may trigger a lazy load if + # passive is False. + current = self.get(obj, passive=passive, raiseerr=False) + if current is InstrumentedAttribute.PASSIVE_NORESULT: + return None + return AttributeHistory(self, obj, current, passive=passive) + + def set_callable(self, obj, callable_): + """set a callable function for this attribute on the given object. + + this callable will be executed when the attribute is next accessed, + and is assumed to construct part of the instances previously stored state. When + its value or values are loaded, they will be established as part of the + instance's "committed state". while "trackparent" information will be assembled + for these instances, attribute-level event handlers will not be fired. + + the callable overrides the class level callable set in the InstrumentedAttribute + constructor. + """ + if callable_ is None: + self.initialize(obj) + else: + obj._state[('callable', self)] = callable_ + + def reset(self, obj): + """removes any per-instance callable functions corresponding to this InstrumentedAttribute's attribute + from the given object, and removes this InstrumentedAttribute's + attribute from the given object's dictionary.""" + try: + del obj._state[('callable', self)] + except KeyError: + pass + self.clear(obj) + + def clear(self, obj): + """removes this InstrumentedAttribute's attribute from the given object's dictionary. subsequent calls to + getattr(obj, key) will raise an AttributeError by default.""" + try: + del obj.__dict__[self.key] + except KeyError: + pass + + def _get_callable(self, obj): + if obj._state.has_key(('callable', self)): + return obj._state[('callable', self)] + elif self.callable_ is not None: + return self.callable_(obj) + else: + return None + + def _blank_list(self): + if self.typecallable is not None: + return self.typecallable() + else: + return [] + + def _adapt_list(self, data): + if self.typecallable is not None: + t = self.typecallable() + if data is not None: + [t.append(x) for x in data] + return t + else: + return data + + def initialize(self, obj): + """initialize this attribute on the given object instance. + + if this is a list-based attribute, a new, blank list will be created. + if a scalar attribute, the value will be initialized to None.""" + if self.uselist: + l = InstrumentedList(self, obj, self._blank_list()) + obj.__dict__[self.key] = l + return l + else: + obj.__dict__[self.key] = None + return None + + def get(self, obj, passive=False, raiseerr=True): + """retrieves a value from the given object. if a callable is assembled + on this object's attribute, and passive is False, the callable will be executed + and the resulting value will be set as the new value for this attribute.""" + try: + return obj.__dict__[self.key] + except KeyError: + state = obj._state + # if an instance-wide "trigger" was set, call that + # and start again + if state.has_key('trigger'): + trig = state['trigger'] + del state['trigger'] + trig() + return self.get(obj, passive=passive, raiseerr=raiseerr) + + if self.uselist: + callable_ = self._get_callable(obj) + if callable_ is not None: + if passive: + return InstrumentedAttribute.PASSIVE_NORESULT + self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key)) + values = callable_() + l = InstrumentedList(self, obj, self._adapt_list(values), init=False) + + # if a callable was executed, then its part of the "committed state" + # if any, so commit the newly loaded data + orig = state.get('original', None) + if orig is not None: + orig.commit_attribute(self, obj, l) + + else: + # note that we arent raising AttributeErrors, just creating a new + # blank list and setting it. + # this might be a good thing to be changeable by options. + l = InstrumentedList(self, obj, self._blank_list(), init=False) + obj.__dict__[self.key] = l + return l + else: + callable_ = self._get_callable(obj) + if callable_ is not None: + if passive: + return InstrumentedAttribute.PASSIVE_NORESULT + self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key)) + value = callable_() + obj.__dict__[self.key] = value + + # if a callable was executed, then its part of the "committed state" + # if any, so commit the newly loaded data + orig = state.get('original', None) + if orig is not None: + orig.commit_attribute(self, obj) + return obj.__dict__[self.key] + else: + # note that we arent raising AttributeErrors, just returning None. + # this might be a good thing to be changeable by options. + return None + + def set(self, event, obj, value): + """sets a value on the given object. 'event' is the InstrumentedAttribute that + initiated the set() operation and is used to control the depth of a circular setter + operation.""" + if event is not self: + state = obj._state + # if an instance-wide "trigger" was set, call that + if state.has_key('trigger'): + trig = state['trigger'] + del state['trigger'] + trig() + if self.uselist: + value = InstrumentedList(self, obj, value) + old = self.get(obj) + obj.__dict__[self.key] = value + state['modified'] = True + if not self.uselist: + if self.trackparent: + if value is not None: + self.sethasparent(value, True) + if old is not None: + self.sethasparent(old, False) + for ext in self.extensions: + ext.set(event or self, obj, value, old) + else: + # mark all the old elements as detached from the parent + old.list_replaced() + + def delete(self, event, obj): + """deletes a value from the given object. 'event' is the InstrumentedAttribute that + initiated the delete() operation and is used to control the depth of a circular delete + operation.""" + if event is not self: + try: + if not self.uselist and (self.trackparent or len(self.extensions)): + old = self.get(obj) + del obj.__dict__[self.key] + except KeyError: + # TODO: raise this? not consistent with get() ? + raise AttributeError(self.key) + obj._state['modified'] = True + if not self.uselist: + if self.trackparent: + if old is not None: + self.sethasparent(old, False) + for ext in self.extensions: + ext.delete(event or self, obj, old) + + def append(self, event, obj, value): + """appends an element to a list based element or sets a scalar based element to the given value. + Used by GenericBackrefExtension to "append" an item independent of list/scalar semantics. + 'event' is the InstrumentedAttribute that initiated the append() operation and is used to control + the depth of a circular append operation.""" + if self.uselist: + if event is not self: + self.get(obj).append_with_event(value, event) + else: + self.set(event, obj, value) + + def remove(self, event, obj, value): + """removes an element from a list based element or sets a scalar based element to None. + Used by GenericBackrefExtension to "remove" an item independent of list/scalar semantics. + 'event' is the InstrumentedAttribute that initiated the remove() operation and is used to control + the depth of a circular remove operation.""" + if self.uselist: + if event is not self: + self.get(obj).remove_with_event(value, event) + else: + self.set(event, obj, None) + + def append_event(self, event, obj, value): + """called by InstrumentedList when an item is appended""" + obj._state['modified'] = True + if self.trackparent and value is not None: + self.sethasparent(value, True) + for ext in self.extensions: + ext.append(event or self, obj, value) + + def remove_event(self, event, obj, value): + """called by InstrumentedList when an item is removed""" + obj._state['modified'] = True + if self.trackparent and value is not None: + self.sethasparent(value, False) + for ext in self.extensions: + ext.delete(event or self, obj, value) +InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) + +class InstrumentedList(object): + """instruments a list-based attribute. all mutator operations (i.e. append, remove, etc.) will fire off events to the + InstrumentedAttribute that manages the object's attribute. those events in turn trigger things like + backref operations and whatever is implemented by do_list_value_changed on InstrumentedAttribute. + + note that this list does a lot less than earlier versions of SA list-based attributes, which used HistoryArraySet. + this list wrapper does *not* maintain setlike semantics, meaning you can add as many duplicates as + you want (which can break a lot of SQL), and also does not do anything related to history tracking. + + Please see ticket #213 for information on the future of this class, where it will be broken out into more + collection-specific subtypes.""" + def __init__(self, attr, obj, data, init=True): + self.attr = attr + # this weakref is to prevent circular references between the parent object + # and the list attribute, which interferes with immediate garbage collection. + self.__obj = weakref.ref(obj) + self.key = attr.key + self.data = data or attr._blank_list() + + # adapt to lists or sets + # TODO: make three subclasses of InstrumentedList that come off from a + # metaclass, based on the type of data sent in + if hasattr(self.data, 'append'): + self._data_appender = self.data.append + self._clear_data = self._clear_list + elif hasattr(self.data, 'add'): + self._data_appender = self.data.add + self._clear_data = self._clear_set + if isinstance(self.data, dict): + self._clear_data = self._clear_dict + + if init: + for x in self.data: + self.__setrecord(x) + + def list_replaced(self): + """fires off delete event handlers for each item in the list but + doesnt affect the original data list""" + [self.__delrecord(x) for x in self.data] + + def clear(self): + """clears all items in this InstrumentedList and fires off delete event handlers for each item""" + self._clear_data() + def _clear_dict(self): + [self.__delrecord(x) for x in self.data.values()] + self.data.clear() + def _clear_set(self): + [self.__delrecord(x) for x in self.data] + self.data.clear() + def _clear_list(self): + self[:] = [] + + def __getstate__(self): + """implemented to allow pickling, since __obj is a weakref, also the InstrumentedAttribute has callables + attached to it""" + return {'key':self.key, 'obj':self.obj, 'data':self.data} + def __setstate__(self, d): + """implemented to allow pickling, since __obj is a weakref, also the InstrumentedAttribute has callables + attached to it""" + self.key = d['key'] + self.__obj = weakref.ref(d['obj']) + self.data = d['data'] + self.attr = getattr(d['obj'].__class__, self.key) + + obj = property(lambda s:s.__obj()) + + def unchanged_items(self): + """deprecated""" + return self.attr.get_history(self.obj).unchanged_items + def added_items(self): + """deprecated""" + return self.attr.get_history(self.obj).added_items + def deleted_items(self): + """deprecated""" + return self.attr.get_history(self.obj).deleted_items + + def __iter__(self): + return iter(self.data) + def __repr__(self): + return repr(self.data) + + def __getattr__(self, attr): + """proxies unknown methods and attributes to the underlying + data array. this allows custom list classes to be used.""" + return getattr(self.data, attr) + + def __setrecord(self, item, event=None): + self.attr.append_event(event, self.obj, item) + return True + + def __delrecord(self, item, event=None): + self.attr.remove_event(event, self.obj, item) + return True + + def append_with_event(self, item, event): + self.__setrecord(item, event) + self._data_appender(item) + + def append_without_event(self, item): + self._data_appender(item) + + def remove_with_event(self, item, event): + self.__delrecord(item, event) + self.data.remove(item) + + def append(self, item, _mapper_nohistory=False): + """fires off dependent events, and appends the given item to the underlying list. + _mapper_nohistory is a backwards compatibility hack; call append_without_event instead.""" + if _mapper_nohistory: + self.append_without_event(item) + else: + self.__setrecord(item) + self._data_appender(item) + + + def __getitem__(self, i): + return self.data[i] + def __setitem__(self, i, item): + if isinstance(i, slice): + self.__setslice__(i.start, i.stop, item) + else: + self.__setrecord(item) + self.data[i] = item + def __delitem__(self, i): + if isinstance(i, slice): + self.__delslice__(i.start, i.stop) + else: + self.__delrecord(self.data[i], None) + del self.data[i] + + def __lt__(self, other): return self.data < self.__cast(other) + def __le__(self, other): return self.data <= self.__cast(other) + def __eq__(self, other): return self.data == self.__cast(other) + def __ne__(self, other): return self.data != self.__cast(other) + def __gt__(self, other): return self.data > self.__cast(other) + def __ge__(self, other): return self.data >= self.__cast(other) + def __cast(self, other): + if isinstance(other, InstrumentedList): return other.data + else: return other + def __cmp__(self, other): + return cmp(self.data, self.__cast(other)) + def __contains__(self, item): return item in self.data + def __len__(self): return len(self.data) + def __setslice__(self, i, j, other): + i = max(i, 0); j = max(j, 0) + [self.__delrecord(x) for x in self.data[i:]] + g = [a for a in list(other) if self.__setrecord(a)] + self.data[i:] = g + def __delslice__(self, i, j): + i = max(i, 0); j = max(j, 0) + for a in self.data[i:j]: + self.__delrecord(a) + del self.data[i:j] + def insert(self, i, item): + if self.__setrecord(item): + self.data.insert(i, item) + def pop(self, i=-1): + item = self.data[i] + self.__delrecord(item) + return self.data.pop(i) + def remove(self, item): + self.__delrecord(item) + self.data.remove(item) + def extend(self, item_list): + for item in item_list: + self.append(item) + def __add__(self, other): + raise NotImplementedError() + def __radd__(self, other): + raise NotImplementedError() + def __iadd__(self, other): + raise NotImplementedError() + +class AttributeExtension(object): + """an abstract class which specifies "append", "delete", and "set" + event handlers to be attached to an object property.""" + def append(self, event, obj, child): + pass + def delete(self, event, obj, child): + pass + def set(self, event, obj, child, oldchild): + pass + +class GenericBackrefExtension(AttributeExtension): + """an extension which synchronizes a two-way relationship. A typical two-way + relationship is a parent object containing a list of child objects, where each + child object references the parent. The other are two objects which contain + scalar references to each other.""" + def __init__(self, key): + self.key = key + def set(self, event, obj, child, oldchild): + if oldchild is child: + return + if oldchild is not None: + getattr(oldchild.__class__, self.key).remove(event, oldchild, obj) + if child is not None: + getattr(child.__class__, self.key).append(event, child, obj) + def append(self, event, obj, child): + getattr(child.__class__, self.key).append(event, child, obj) + def delete(self, event, obj, child): + getattr(child.__class__, self.key).remove(event, child, obj) + +class CommittedState(object): + """stores the original state of an object when the commit() method on the attribute manager + is called.""" + NO_VALUE = object() + + def __init__(self, manager, obj): + self.data = {} + for attr in manager.managed_attributes(obj.__class__): + self.commit_attribute(attr, obj) + + def commit_attribute(self, attr, obj, value=NO_VALUE): + """establish the value of attribute 'attr' on instance 'obj' as "committed". + + this corresponds to a previously saved state being restored. """ + if value is CommittedState.NO_VALUE: + if obj.__dict__.has_key(attr.key): + value = obj.__dict__[attr.key] + if value is not CommittedState.NO_VALUE: + self.data[attr.key] = attr.copy(value) + + # not tracking parent on lazy-loaded instances at the moment. + # its not needed since they will be "optimistically" tested + #if attr.uselist: + #if attr.trackparent: + # [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None] + #else: + #if attr.trackparent and value is not None: + # attr.sethasparent(value, True) + + def rollback(self, manager, obj): + for attr in manager.managed_attributes(obj.__class__): + if self.data.has_key(attr.key): + if attr.uselist: + obj.__dict__[attr.key][:] = self.data[attr.key] + else: + obj.__dict__[attr.key] = self.data[attr.key] + else: + del obj.__dict__[attr.key] + + def __repr__(self): + return "CommittedState: %s" % repr(self.data) + +class AttributeHistory(object): + """calculates the "history" of a particular attribute on a particular instance, based on the CommittedState + associated with the instance, if any.""" + def __init__(self, attr, obj, current, passive=False): + self.attr = attr + + # get the "original" value. if a lazy load was fired when we got + # the 'current' value, this "original" was also populated just + # now as well (therefore we have to get it second) + orig = obj._state.get('original', None) + if orig is not None: + original = orig.data.get(attr.key) + else: + original = None + + if attr.uselist: + self._current = current + else: + self._current = [current] + if attr.uselist: + s = util.Set(original or []) + self._added_items = [] + self._unchanged_items = [] + self._deleted_items = [] + if current: + for a in current: + if a in s: + self._unchanged_items.append(a) + else: + self._added_items.append(a) + for a in s: + if a not in self._unchanged_items: + self._deleted_items.append(a) + else: + if attr.is_equal(current, original): + self._unchanged_items = [current] + self._added_items = [] + self._deleted_items = [] + else: + self._added_items = [current] + if original is not None: + self._deleted_items = [original] + else: + self._deleted_items = [] + self._unchanged_items = [] + #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items + def __iter__(self): + return iter(self._current) + def is_modified(self): + return len(self._deleted_items) > 0 or len(self._added_items) > 0 + def added_items(self): + return self._added_items + def unchanged_items(self): + return self._unchanged_items + def deleted_items(self): + return self._deleted_items + def hasparent(self, obj): + """deprecated. this should be called directly from the appropriate InstrumentedAttribute object.""" + return self.attr.hasparent(obj) + +class AttributeManager(object): + """allows the instrumentation of object attributes. AttributeManager is stateless, but can be + overridden by subclasses to redefine some of its factory operations.""" + + def rollback(self, *obj): + """retrieves the committed history for each object in the given list, and rolls back the attributes + each instance to their original value.""" + for o in obj: + orig = o._state.get('original') + if orig is not None: + orig.rollback(self, o) + else: + self._clear(o) + + def _clear(self, obj): + for attr in self.managed_attributes(obj.__class__): + try: + del obj.__dict__[attr.key] + except KeyError: + pass + + def commit(self, *obj): + """creates a CommittedState instance for each object in the given list, representing + its "unchanged" state, and associates it with the instance. AttributeHistory objects + will indicate the modified state of instance attributes as compared to its value in this + CommittedState object.""" + for o in obj: + o._state['original'] = CommittedState(self, o) + o._state['modified'] = False + + def managed_attributes(self, class_): + """returns an iterator of all InstrumentedAttribute objects associated with the given class.""" + if not isinstance(class_, type): + raise repr(class_) + " is not a type" + for key in dir(class_): + value = getattr(class_, key, None) + if isinstance(value, InstrumentedAttribute): + yield value + + def noninherited_managed_attributes(self, class_): + if not isinstance(class_, type): + raise repr(class_) + " is not a type" + for key in list(class_.__dict__): + value = getattr(class_, key, None) + if isinstance(value, InstrumentedAttribute): + yield value + + def is_modified(self, object): + for attr in self.managed_attributes(object.__class__): + if attr.check_mutable_modified(object): + return True + return object._state.get('modified', False) + + def init_attr(self, obj): + """sets up the __sa_attr_state dictionary on the given instance. This dictionary is + automatically created when the '_state' attribute of the class is first accessed, but calling + it here will save a single throw of an AttributeError that occurs in that creation step.""" + setattr(obj, '_%s__sa_attr_state' % obj.__class__.__name__, {}) + + def get_history(self, obj, key, **kwargs): + """returns a new AttributeHistory object for the given attribute on the given object.""" + return getattr(obj.__class__, key).get_history(obj, **kwargs) + + def get_as_list(self, obj, key, passive=False): + """returns an attribute of the given name from the given object. if the attribute + is a scalar, returns it as a single-item list, otherwise returns the list based attribute. + if the attribute's value is to be produced by an unexecuted callable, + the callable will only be executed if the given 'passive' flag is False. + """ + attr = getattr(obj.__class__, key) + x = attr.get(obj, passive=passive) + if x is InstrumentedAttribute.PASSIVE_NORESULT: + return [] + elif attr.uselist: + return x + else: + return [x] + + def trigger_history(self, obj, callable): + """clears all managed object attributes and places the given callable + as an attribute-wide "trigger", which will execute upon the next attribute access, after + which the trigger is removed.""" + self._clear(obj) + try: + del obj._state['original'] + except KeyError: + pass + obj._state['trigger'] = callable + + def untrigger_history(self, obj): + """removes a trigger function set by trigger_history. does not restore the previous state of the object.""" + del obj._state['trigger'] + + def has_trigger(self, obj): + """returns True if the given object has a trigger function set by trigger_history().""" + return obj._state.has_key('trigger') + + def reset_instance_attribute(self, obj, key): + """removes any per-instance callable functions corresponding to given attribute key + from the given object, and removes this attribute from the given object's dictionary.""" + attr = getattr(obj.__class__, key) + attr.reset(obj) + + def reset_class_managed(self, class_): + """removes all InstrumentedAttribute property objects from the given class.""" + for attr in self.noninherited_managed_attributes(class_): + delattr(class_, attr.key) + + def is_class_managed(self, class_, key): + """returns True if the given key correponds to an instrumented property on the given class.""" + return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute) + + def init_instance_attribute(self, obj, key, uselist, callable_=None, **kwargs): + """initializes an attribute on an instance to either a blank value, cancelling + out any class- or instance-level callables that were present, or if a callable + is supplied sets the callable to be invoked when the attribute is next accessed.""" + getattr(obj.__class__, key).set_callable(obj, callable_) + + def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs): + """creates a scalar property object, defaulting to InstrumentedAttribute, which + will communicate change events back to this AttributeManager.""" + return InstrumentedAttribute(self, key, uselist, callable_, typecallable, **kwargs) + + def register_attribute(self, class_, key, uselist, callable_=None, **kwargs): + """registers an attribute at the class level to be instrumented for all instances + of the class.""" + #print self, "register attribute", key, "for class", class_ + if not hasattr(class_, '_state'): + def _get_state(self): + try: + return self.__sa_attr_state + except AttributeError: + self.__sa_attr_state = {} + return self.__sa_attr_state + class_._state = property(_get_state) + + typecallable = kwargs.pop('typecallable', None) + if typecallable is None: + typecallable = getattr(class_, key, None) + if isinstance(typecallable, InstrumentedAttribute): + typecallable = None + setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs)) + diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index b3dbe7a92..1327644b2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -21,7 +21,7 @@ mapper_registry = weakref.WeakKeyDictionary() # a list of MapperExtensions that will be installed in all mappers by default global_extensions = [] -# a constant returned by _getattrbycolumn to indicate +# a constant returned by get_attr_by_column to indicate # this mapper is not handling an attribute for a particular # column NO_ATTRIBUTE = object() @@ -289,10 +289,6 @@ class Mapper(object): def _compile_extensions(self): """goes through the global_extensions list as well as the list of MapperExtensions specified for this Mapper and creates a linked list of those extensions.""" - # uber-pendantic style of making mapper chain, as various testbase/ - # threadlocal/assignmapper combinations keep putting dupes etc. in the list - # TODO: do something that isnt 21 lines.... - extlist = util.Set() for ext_class in global_extensions: if isinstance(ext_class, MapperExtension): @@ -307,7 +303,7 @@ class Mapper(object): self.extension = _ExtensionCarrier() for ext in extlist: - self.extension.elements.append(ext) + self.extension.append(ext) def _compile_inheritance(self): """determines if this Mapper inherits from another mapper, and if so calculates the mapped_table @@ -744,7 +740,7 @@ class Mapper(object): def primary_key_from_instance(self, instance): """return the list of primary key values for the given instance.""" - return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.mapped_table]] + return [self.get_attr_by_column(instance, column) for column in self.pks_by_table[self.mapped_table]] def instance_key(self, instance): """deprecated. a synonym for identity_key_from_instance.""" @@ -773,21 +769,27 @@ class Mapper(object): raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self))) return prop[0] - def _getattrbycolumn(self, obj, column, raiseerror=True): + def get_attr_by_column(self, obj, column, raiseerror=True): + """return an instance attribute using a Column as the key.""" prop = self._getpropbycolumn(column, raiseerror) if prop is None: return NO_ATTRIBUTE #self.__log_debug("get column attribute '%s' from instance %s" % (column.key, mapperutil.instance_str(obj))) return prop.getattr(obj) - def _setattrbycolumn(self, obj, column, value): + def set_attr_by_column(self, obj, column, value): + """set the value of an instance attribute using a Column as the key.""" self.columntoproperty[column][0].setattr(obj, value) def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False): - """save a list of objects. + """issue INSERT and/or UPDATE statements for a list of objects. + + this is called within the context of a UOWTransaction during a flush operation. - this method is called within a unit of work flush() process. It saves objects that are mapped not just - by this mapper, but inherited mappers as well, so that insert ordering of polymorphic objects is maintained.""" + save_obj issues SQL statements not just for instances mapped directly by this mapper, but + for instances mapped by all inheriting mappers as well. This is to maintain proper insert + ordering among a polymorphic chain of instances. Therefore save_obj is typically + called only on a "base mapper", or a mapper which does not inherit from any other mapper.""" self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched")) @@ -848,7 +850,7 @@ class Mapper(object): for col in table.columns: if col is mapper.version_id_col: if not isinsert: - params[col._label] = mapper._getattrbycolumn(obj, col) + params[col._label] = mapper.get_attr_by_column(obj, col) params[col.key] = params[col._label] + 1 else: params[col.key] = 1 @@ -857,14 +859,14 @@ class Mapper(object): if not isinsert: # doing an UPDATE? put primary key values as "WHERE" parameters # matching the bindparam we are creating below, i.e. "<tablename>_<colname>" - params[col._label] = mapper._getattrbycolumn(obj, col) + params[col._label] = mapper.get_attr_by_column(obj, col) else: # doing an INSERT, primary key col ? # if the primary key values are not populated, # leave them out of the INSERT altogether, since PostGres doesn't want # them to be present for SERIAL to take effect. A SQLEngine that uses # explicit sequences will put them back in if they are needed - value = mapper._getattrbycolumn(obj, col) + value = mapper.get_attr_by_column(obj, col) if value is not None: params[col.key] = value elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): @@ -882,7 +884,7 @@ class Mapper(object): if post_update_cols is not None and col not in post_update_cols: continue elif is_row_switch: - params[col.key] = self._getattrbycolumn(obj, col) + params[col.key] = self.get_attr_by_column(obj, col) hasdata = True continue prop = mapper._getpropbycolumn(col, False) @@ -901,7 +903,7 @@ class Mapper(object): # default. if its None and theres no default, we still might # not want to put it in the col list but SQLIte doesnt seem to like that # if theres no columns at all - value = mapper._getattrbycolumn(obj, col, False) + value = mapper.get_attr_by_column(obj, col, False) if value is NO_ATTRIBUTE: continue if col.default is None or value is not None: @@ -955,8 +957,8 @@ class Mapper(object): if primary_key is not None: i = 0 for col in mapper.pks_by_table[table]: - if mapper._getattrbycolumn(obj, col) is None and len(primary_key) > i: - mapper._setattrbycolumn(obj, col, primary_key[i]) + if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i: + mapper.set_attr_by_column(obj, col, primary_key[i]) i+=1 mapper._postfetch(connection, table, obj, c, c.last_inserted_params()) @@ -987,26 +989,29 @@ class Mapper(object): if resultproxy.lastrow_has_defaults(): clause = sql.and_() for p in self.pks_by_table[table]: - clause.clauses.append(p == self._getattrbycolumn(obj, p)) + clause.clauses.append(p == self.get_attr_by_column(obj, p)) row = connection.execute(table.select(clause), None).fetchone() for c in table.c: - if self._getattrbycolumn(obj, c, False) is None: - self._setattrbycolumn(obj, c, row[c]) + if self.get_attr_by_column(obj, c, False) is None: + self.set_attr_by_column(obj, c, row[c]) else: for c in table.c: if c.primary_key or not params.has_key(c.name): continue - v = self._getattrbycolumn(obj, c, False) + v = self.get_attr_by_column(obj, c, False) if v is NO_ATTRIBUTE: continue elif v != params.get_original(c.name): - self._setattrbycolumn(obj, c, params.get_original(c.name)) + self.set_attr_by_column(obj, c, params.get_original(c.name)) def delete_obj(self, objects, uowtransaction): - """called by a UnitOfWork object to delete objects, which involves a - DELETE statement for each table used by this mapper, for each object in the list.""" + """issue DELETE statements for a list of objects. + + this is called within the context of a UOWTransaction during a flush operation.""" + + self.__log_debug("delete_obj() start") + connection = uowtransaction.transaction.connection(self) - #print "DELETE_OBJ MAPPER", self.class_.__name__, objects [self.extension.before_delete(self, connection, obj) for obj in objects] deleted_objects = util.Set() @@ -1021,9 +1026,9 @@ class Mapper(object): else: delete.append(params) for col in self.pks_by_table[table]: - params[col.key] = self._getattrbycolumn(obj, col) + params[col.key] = self.get_attr_by_column(obj, col) if self.version_id_col is not None: - params[self.version_id_col.key] = self._getattrbycolumn(obj, self.version_id_col) + params[self.version_id_col.key] = self.get_attr_by_column(obj, self.version_id_col) deleted_objects.add(obj) if len(delete): def comparator(a, b): @@ -1123,8 +1128,8 @@ class Mapper(object): instance = context.session._get(identitykey) self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) isnew = False - if context.version_check and self.version_id_col is not None and self._getattrbycolumn(instance, self.version_id_col) != row[self.version_id_col]: - raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._getattrbycolumn(instance, self.version_id_col), row[self.version_id_col])) + if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: + raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) if populate_existing or context.session.is_expired(instance, unexpire=True): if not context.identity_map.has_key(identitykey): @@ -1186,8 +1191,11 @@ class Mapper(object): return obj def translate_row(self, tomapper, row): - """attempts to take a row and translate its values to a row that can - be understood by another mapper.""" + """translate the column keys of a row into a new or proxied row that + can be understood by another mapper. + + This can be used in conjunction with populate_instance to populate + an instance using an alternate mapper.""" newrow = util.DictDecorator(row) for c in tomapper.mapped_table.c: c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True) @@ -1195,9 +1203,11 @@ class Mapper(object): newrow[c] = row[c2] return newrow - def populate_instance(self, selectcontext, instance, row, identitykey, isnew, frommapper=None): - if frommapper is not None: - row = frommapper.translate_row(self, row) + def populate_instance(self, selectcontext, instance, row, identitykey, isnew): + """populate an instance from a result row. + + This method iterates through the list of MapperProperty objects attached to this Mapper + and calls each properties execute() method.""" for prop in self.__props.values(): prop.execute(selectcontext, instance, row, identitykey, isnew) @@ -1213,6 +1223,24 @@ class MapperExtension(object): Note: this is not called if a session is provided with the __init__ params (i.e. _sa_session)""" return EXT_PASS + def load(self, query, *args, **kwargs): + """override the load method of the Query object. + + the return value of this method is used as the result of query.load() if the + value is anything other than EXT_PASS.""" + return EXT_PASS + def get(self, query, *args, **kwargs): + """override the get method of the Query object. + + the return value of this method is used as the result of query.get() if the + value is anything other than EXT_PASS.""" + return EXT_PASS + def get_by(self, query, *args, **kwargs): + """override the get_by method of the Query object. + + the return value of this method is used as the result of query.get_by() if the + value is anything other than EXT_PASS.""" + return EXT_PASS def select_by(self, query, *args, **kwargs): """override the select_by method of the Query object. @@ -1271,14 +1299,7 @@ class MapperExtension(object): as relationships to other classes). If this method returns EXT_PASS, instance population will proceed normally. If any other value or None is returned, instance population will not proceed, giving this extension an opportunity to populate the instance itself, - if desired.. - - A common usage of this method is to have population performed by an alternate mapper. This can - be acheived via the populate_instance() call on Mapper. - - def populate_instance(self, mapper, selectcontext, instance, row, identitykey, isnew): - othermapper.populate_instance(selectcontext, instance, row, identitykey, isnew, frommapper=mapper) - return None + if desired. """ return EXT_PASS def before_insert(self, mapper, connection, instance): @@ -1304,41 +1325,24 @@ class MapperExtension(object): class _ExtensionCarrier(MapperExtension): def __init__(self): - self.elements = [] + self.__elements = [] + self.__callables = {} def insert(self, extension): """insert a MapperExtension at the beginning of this ExtensionCarrier's list.""" - self.elements.insert(0, extension) + self.__elements.insert(0, extension) def append(self, extension): """append a MapperExtension at the end of this ExtensionCarrier's list.""" - self.elements.append(extension) - # TODO: shrink down this approach using __getattribute__ or similar - def get_session(self): - return self._do('get_session') - def select_by(self, *args, **kwargs): - return self._do('select_by', *args, **kwargs) - def select(self, *args, **kwargs): - return self._do('select', *args, **kwargs) - def create_instance(self, *args, **kwargs): - return self._do('create_instance', *args, **kwargs) - def append_result(self, *args, **kwargs): - return self._do('append_result', *args, **kwargs) - def populate_instance(self, *args, **kwargs): - return self._do('populate_instance', *args, **kwargs) - def before_insert(self, *args, **kwargs): - return self._do('before_insert', *args, **kwargs) - def before_update(self, *args, **kwargs): - return self._do('before_update', *args, **kwargs) - def after_update(self, *args, **kwargs): - return self._do('after_update', *args, **kwargs) - def after_insert(self, *args, **kwargs): - return self._do('after_insert', *args, **kwargs) - def before_delete(self, *args, **kwargs): - return self._do('before_delete', *args, **kwargs) - def after_delete(self, *args, **kwargs): - return self._do('after_delete', *args, **kwargs) - + self.__elements.append(extension) + def __getattribute__(self, key): + if key in MapperExtension.__dict__: + try: + return self.__callables[key] + except KeyError: + return self.__callables.setdefault(key, lambda *args, **kwargs:self._do(key, *args, **kwargs)) + else: + return super(_ExtensionCarrier, self).__getattribute__(key) def _do(self, funcname, *args, **kwargs): - for elem in self.elements: + for elem in self.__elements: if elem is self: raise exceptions.AssertionError("ExtensionCarrier set to itself") ret = getattr(elem, funcname)(*args, **kwargs) @@ -1346,7 +1350,7 @@ class _ExtensionCarrier(MapperExtension): return ret else: return EXT_PASS - + class ExtensionOption(MapperExtension): def __init__(self, ext): self.ext = ext diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 768b51959..501903983 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -9,10 +9,11 @@ well as relationships. the objects rely upon the LoaderStrategy objects in the module to handle load operations. PropertyLoader also relies upon the dependency.py module to handle flush-time dependency sorting and processing.""" -from sqlalchemy import sql, schema, util, attributes, exceptions, sql_util, logging +from sqlalchemy import sql, schema, util, exceptions, sql_util, logging import mapper import sync import strategies +import attributes import session as sessionlib import dependency import util as mapperutil diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 82ed4e1a0..607b37fd9 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -54,6 +54,9 @@ class Query(object): The ident argument is a scalar or tuple of primary key column values in the order of the table def's primary key columns.""" + ret = self.extension.get(self, ident, **kwargs) + if ret is not mapper.EXT_PASS: + return ret key = self.mapper.identity_key(ident) return self._get(key, ident, **kwargs) @@ -63,6 +66,9 @@ class Query(object): If not found, raises an exception. The method will *remove all pending changes* to the object already existing in the Session. The ident argument is a scalar or tuple of primary key column values in the order of the table def's primary key columns.""" + ret = self.extension.load(self, ident, **kwargs) + if ret is not mapper.EXT_PASS: + return ret key = self.mapper.identity_key(ident) instance = self._get(key, ident, reload=True, **kwargs) if instance is None: @@ -83,6 +89,9 @@ class Query(object): e.g. u = usermapper.get_by(user_name = 'fred') """ + ret = self.extension.get_by(self, *args, **params) + if ret is not mapper.EXT_PASS: + return ret x = self.select_whereclause(self.join_by(*args, **params), limit=1) if x: return x[0] diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 916a60d25..ac6b37e33 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -4,6 +4,8 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" + from sqlalchemy import sql, schema, util, attributes, exceptions, sql_util, logging import mapper, query from interfaces import * @@ -11,7 +13,6 @@ import session as sessionlib import util as mapperutil import sets, random -"""sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" class ColumnLoader(LoaderStrategy): def init(self): @@ -73,7 +74,7 @@ class DeferredColumnLoader(LoaderStrategy): clause = sql.and_() for primary_key in pk: - attr = self.parent._getattrbycolumn(instance, primary_key) + attr = self.parent.get_attr_by_column(instance, primary_key) if not attr: return None clause.clauses.append(primary_key == attr) @@ -178,7 +179,7 @@ class LazyLoader(AbstractRelationLoader): return None #print "setting up loader, lazywhere", str(self.lazywhere), "binds", self.lazybinds for col, bind in self.lazybinds.iteritems(): - params[bind.key] = self.parent._getattrbycolumn(instance, col) + params[bind.key] = self.parent.get_attr_by_column(instance, col) if params[bind.key] is None: allparams = False break diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 5f0331e16..6fe848a6f 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -132,7 +132,7 @@ class SyncRule(object): if clearkeys or source is None: value = None else: - value = self.source_mapper._getattrbycolumn(source, self.source_column) + value = self.source_mapper.get_attr_by_column(source, self.source_column) if isinstance(dest, dict): dest[self.dest_column.key] = value else: @@ -141,7 +141,7 @@ class SyncRule(object): if logging.is_debug_enabled(self.logger): self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value)) - self.dest_mapper._setattrbycolumn(dest, self.dest_column, value) + self.dest_mapper.set_attr_by_column(dest, self.dest_column, value) SyncRule.logger = logging.class_logger(SyncRule) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index c4fd92e36..815ffe029 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -14,8 +14,9 @@ an "identity map" pattern. The Unit of Work then maintains lists of objects tha dirty, or deleted and provides the capability to flush all those changes at once. """ -from sqlalchemy import attributes, util, logging, topological +from sqlalchemy import util, logging, topological import sqlalchemy +import attributes from sqlalchemy.exceptions import * import StringIO import weakref |
