summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNils Philippsen <nils@tiptoe.de>2016-01-18 03:03:33 +0100
committerNils Philippsen <nils@tiptoe.de>2016-01-19 11:21:33 +0100
commit9a55c8d57ce8469d3788ccb5fb377ea1636b8809 (patch)
treed8b15133682dcec89d0e944000d08aa974a22cac
parentd4d9a6524886eb33644e8ce42212267fa569e555 (diff)
downloadsqlalchemy-pr/228.tar.gz
association_proxy: allow more flexible use of creatorpr/228
Add code to deal with unbound and bound methods. Allow the creator callable to have one additional parameter which would receive the instance object through which the association proxy is accessed.
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py80
1 files changed, 77 insertions, 3 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 31f16287d..78184c48b 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -13,6 +13,7 @@ transparent proxied access to the endpoint of an association object.
See the example ``examples/association/proxied_association.py``.
"""
+import inspect
import itertools
import operator
import weakref
@@ -486,8 +487,12 @@ class _AssociationCollection(object):
object attribute managed by a SQLAlchemy relationship())
creator
- A function that creates new target entities. Given one parameter:
- value. This assertion is assumed::
+ A function or method that creates new target entities. Given one
+ mandatory parameter 'value' and one optional parameter 'parent' which
+ would receive the instance object on which the association proxy is
+ accessed.
+
+ This assertion is assumed::
obj = creator(somevalue)
assert getter(obj) == somevalue
@@ -500,12 +505,77 @@ class _AssociationCollection(object):
value on the object.
"""
+
+ # Find out what kind of creator we are dealing with.
+
+ # save the instance or class on bound methods
+ self._creator_self_or_cls = getattr(
+ creator, '__self__', getattr(creator, 'im_self', None))
+
+ # get the function for unbound static, class methods, save the original
+ self._creator_orig = creator
+ if isinstance(creator, (classmethod, staticmethod)):
+ creator = creator.__func__
+
+ self._creator_argspec = argspec = inspect.getargspec(creator)
+
+ # check if creator needs the class prepended
+ self._creator_add_cls = isinstance(self._creator_orig, classmethod)
+
+ # check if creator needs the instance prepended
+ self._creator_add_self = \
+ not isinstance(self._creator_orig, staticmethod) and \
+ inspect.isfunction(creator) and \
+ argspec.args and argspec.args[0] == 'self' and \
+ creator.__name__ in dir(lazy_collection.ref())
+
+ self._creator_is_cls = creator_is_cls = inspect.isclass(creator)
+
+ # number of arguments (sans cls or self)
+ self._creator_args_len = len(inspect.getargspec(creator)[0])
+ if (creator_is_cls or self._creator_self_or_cls or
+ self._creator_add_cls or self._creator_add_self):
+ self._creator_args_len -= 1
+
+ if self._creator_args_len < 1:
+ raise ValueError(
+ "creator function needs to have at least one argument")
+
+ # If cls or self need to be prepended or parent appended, wrap the
+ # creator with a closure doing the magic.
+ if (not self._creator_self_or_cls and (
+ self._creator_add_self or self._creator_add_cls) or
+ self._creator_add_parent):
+ def _call_creator(*p):
+ self_or_cls = self._creator_self_or_cls
+ obj = self.lazy_collection.ref()
+ if not self_or_cls:
+ self_or_cls = obj
+ if self._creator_add_self:
+ p = (obj,) + p
+ elif self._creator_add_cls:
+ p = (type(obj),) + p
+ if self._creator_add_parent:
+ if not obj:
+ raise exc.InvalidRequestError(
+ "stale association proxy, parent object has gone out of "
+ "scope")
+ p += (obj,)
+ return creator(*p)
+ _call_creator.creator = creator
+ self.creator = _call_creator
+ else:
+ self.creator = creator
+
self.lazy_collection = lazy_collection
- self.creator = creator
self.getter = getter
self.setter = setter
self.parent = parent
+ @property
+ def _creator_add_parent(self):
+ return self._creator_args_len == 2
+
col = property(lambda self: self.lazy_collection())
def __len__(self):
@@ -723,6 +793,10 @@ _NotProvided = util.symbol('_NotProvided')
class _AssociationDict(_AssociationCollection):
"""Generic, converting, dict-to-dict proxy."""
+ @property
+ def _creator_add_parent(self):
+ return self._creator_args_len == 3
+
def _create(self, key, value):
return self.creator(key, value)