diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-02-13 20:20:34 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-02-13 20:20:34 -0500 |
| commit | 2722035809364af9d6ea533241d34935ca17e6af (patch) | |
| tree | 3b8503cc261399330c35aa9b9b6e3b829705a6db /lib/sqlalchemy/ext/associationproxy.py | |
| parent | 2e4da52221c9f231117b93c9709a36dc65b8c9b0 (diff) | |
| download | sqlalchemy-2722035809364af9d6ea533241d34935ca17e6af.tar.gz | |
- Association proxy now has correct behavior for
any(), has(), and contains() when proxying
a many-to-one scalar attribute to a one-to-many
collection (i.e. the reverse of the 'typical'
association proxy use case) [ticket:2054]
Diffstat (limited to 'lib/sqlalchemy/ext/associationproxy.py')
| -rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 68 |
1 files changed, 48 insertions, 20 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 969f60326..31bfa90ff 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -23,7 +23,9 @@ from sqlalchemy.sql import not_ def association_proxy(target_collection, attr, **kw): - """Return a Python property implementing a view of *attr* over a collection. + """Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. Implements a read/write view over an instance's *target_collection*, extracting *attr* from each member of the collection. The property acts @@ -35,16 +37,19 @@ def association_proxy(target_collection, attr, **kw): Unlike the list comprehension, the collection returned by the property is always in sync with *target_collection*, and mutations made to either collection will be reflected in both. + + The association proxy also works with scalar attributes, which in + turn reference scalar attributes or collections. Implements a Python property representing a relationship as a collection of - simpler values. The proxied property will mimic the collection type of + simpler values, or a scalar value. The proxied property will mimic the collection type of the target (list, dict or set), or, in the case of a one to one relationship, a simple scalar value. :param target_collection: Name of the relationship attribute we'll proxy to, usually created with :func:`~sqlalchemy.orm.relationship`. - :param attr: Attribute on the associated instances we'll proxy for. + :param attr: Attribute on the associated instance or instances we'll proxy for. For example, given a target collection of [obj1, obj2], a list created by this proxy property would look like [getattr(obj1, *attr*), @@ -75,7 +80,7 @@ def association_proxy(target_collection, attr, **kw): situation. :param \*\*kw: Passes along any other keyword arguments to - :class:`AssociationProxy`. + :class:`.AssociationProxy`. """ return AssociationProxy(target_collection, attr, **kw) @@ -85,7 +90,8 @@ class AssociationProxy(object): """A descriptor that presents a read/write view of an object attribute.""" def __init__(self, target_collection, attr, creator=None, - getset_factory=None, proxy_factory=None, proxy_bulk_set=None): + getset_factory=None, proxy_factory=None, + proxy_bulk_set=None): """Arguments are: target_collection @@ -137,7 +143,6 @@ class AssociationProxy(object): self.proxy_factory = proxy_factory self.proxy_bulk_set = proxy_bulk_set - self.scalar = None self.owning_class = None self.key = '_%s_%s_%s' % ( type(self).__name__, target_collection, id(self)) @@ -147,23 +152,28 @@ class AssociationProxy(object): return (orm.class_mapper(self.owning_class). get_property(self.target_collection)) - @property + @util.memoized_property def target_class(self): """The class the proxy is attached to.""" return self._get_property().mapper.class_ - def _target_is_scalar(self): - return not self._get_property().uselist + @util.memoized_property + def scalar(self): + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self): + return not self._get_property().\ + mapper.get_property(self.value_attr).uselist def __get__(self, obj, class_): if self.owning_class is None: self.owning_class = class_ and class_ or type(obj) if obj is None: return self - elif self.scalar is None: - self.scalar = self._target_is_scalar() - if self.scalar: - self._initialize_scalar_accessors() if self.scalar: return self._scalar_get(getattr(obj, self.target_collection)) @@ -183,10 +193,6 @@ class AssociationProxy(object): def __set__(self, obj, values): if self.owning_class is None: self.owning_class = type(obj) - if self.scalar is None: - self.scalar = self._target_is_scalar() - if self.scalar: - self._initialize_scalar_accessors() if self.scalar: creator = self.creator and self.creator or self.target_class @@ -278,13 +284,35 @@ class AssociationProxy(object): return self._get_property().comparator def any(self, criterion=None, **kwargs): - return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + if self._value_is_scalar: + value_expr = getattr(self.target_class, self.value_attr).has(criterion, **kwargs) + else: + value_expr = getattr(self.target_class, self.value_attr).any(criterion, **kwargs) + + # check _value_is_scalar here, otherwise + # we're scalar->scalar - call .any() so that + # the "can't call any() on a scalar" msg is raised. + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + value_expr + ) + else: + return self._comparator.any( + value_expr + ) def has(self, criterion=None, **kwargs): - return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + return self._comparator.has( + getattr(self.target_class, self.value_attr).has(criterion, **kwargs) + ) def contains(self, obj): - return self._comparator.any(**{self.value_attr: obj}) + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(obj) + ) + else: + return self._comparator.any(**{self.value_attr: obj}) def __eq__(self, obj): return self._comparator.has(**{self.value_attr: obj}) |
