summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/associationproxy.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-02-13 20:20:34 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2011-02-13 20:20:34 -0500
commit2722035809364af9d6ea533241d34935ca17e6af (patch)
tree3b8503cc261399330c35aa9b9b6e3b829705a6db /lib/sqlalchemy/ext/associationproxy.py
parent2e4da52221c9f231117b93c9709a36dc65b8c9b0 (diff)
downloadsqlalchemy-2722035809364af9d6ea533241d34935ca17e6af.tar.gz
- Association proxy now has correct behavior for
any(), has(), and contains() when proxying a many-to-one scalar attribute to a one-to-many collection (i.e. the reverse of the 'typical' association proxy use case) [ticket:2054]
Diffstat (limited to 'lib/sqlalchemy/ext/associationproxy.py')
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py68
1 files changed, 48 insertions, 20 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 969f60326..31bfa90ff 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -23,7 +23,9 @@ from sqlalchemy.sql import not_
def association_proxy(target_collection, attr, **kw):
- """Return a Python property implementing a view of *attr* over a collection.
+ """Return a Python property implementing a view of a target
+ attribute which references an attribute on members of the
+ target.
Implements a read/write view over an instance's *target_collection*,
extracting *attr* from each member of the collection. The property acts
@@ -35,16 +37,19 @@ def association_proxy(target_collection, attr, **kw):
Unlike the list comprehension, the collection returned by the property is
always in sync with *target_collection*, and mutations made to either
collection will be reflected in both.
+
+ The association proxy also works with scalar attributes, which in
+ turn reference scalar attributes or collections.
Implements a Python property representing a relationship as a collection of
- simpler values. The proxied property will mimic the collection type of
+ simpler values, or a scalar value. The proxied property will mimic the collection type of
the target (list, dict or set), or, in the case of a one to one relationship,
a simple scalar value.
:param target_collection: Name of the relationship attribute we'll proxy to,
usually created with :func:`~sqlalchemy.orm.relationship`.
- :param attr: Attribute on the associated instances we'll proxy for.
+ :param attr: Attribute on the associated instance or instances we'll proxy for.
For example, given a target collection of [obj1, obj2], a list created
by this proxy property would look like [getattr(obj1, *attr*),
@@ -75,7 +80,7 @@ def association_proxy(target_collection, attr, **kw):
situation.
:param \*\*kw: Passes along any other keyword arguments to
- :class:`AssociationProxy`.
+ :class:`.AssociationProxy`.
"""
return AssociationProxy(target_collection, attr, **kw)
@@ -85,7 +90,8 @@ class AssociationProxy(object):
"""A descriptor that presents a read/write view of an object attribute."""
def __init__(self, target_collection, attr, creator=None,
- getset_factory=None, proxy_factory=None, proxy_bulk_set=None):
+ getset_factory=None, proxy_factory=None,
+ proxy_bulk_set=None):
"""Arguments are:
target_collection
@@ -137,7 +143,6 @@ class AssociationProxy(object):
self.proxy_factory = proxy_factory
self.proxy_bulk_set = proxy_bulk_set
- self.scalar = None
self.owning_class = None
self.key = '_%s_%s_%s' % (
type(self).__name__, target_collection, id(self))
@@ -147,23 +152,28 @@ class AssociationProxy(object):
return (orm.class_mapper(self.owning_class).
get_property(self.target_collection))
- @property
+ @util.memoized_property
def target_class(self):
"""The class the proxy is attached to."""
return self._get_property().mapper.class_
- def _target_is_scalar(self):
- return not self._get_property().uselist
+ @util.memoized_property
+ def scalar(self):
+ scalar = not self._get_property().uselist
+ if scalar:
+ self._initialize_scalar_accessors()
+ return scalar
+
+ @util.memoized_property
+ def _value_is_scalar(self):
+ return not self._get_property().\
+ mapper.get_property(self.value_attr).uselist
def __get__(self, obj, class_):
if self.owning_class is None:
self.owning_class = class_ and class_ or type(obj)
if obj is None:
return self
- elif self.scalar is None:
- self.scalar = self._target_is_scalar()
- if self.scalar:
- self._initialize_scalar_accessors()
if self.scalar:
return self._scalar_get(getattr(obj, self.target_collection))
@@ -183,10 +193,6 @@ class AssociationProxy(object):
def __set__(self, obj, values):
if self.owning_class is None:
self.owning_class = type(obj)
- if self.scalar is None:
- self.scalar = self._target_is_scalar()
- if self.scalar:
- self._initialize_scalar_accessors()
if self.scalar:
creator = self.creator and self.creator or self.target_class
@@ -278,13 +284,35 @@ class AssociationProxy(object):
return self._get_property().comparator
def any(self, criterion=None, **kwargs):
- return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
+ if self._value_is_scalar:
+ value_expr = getattr(self.target_class, self.value_attr).has(criterion, **kwargs)
+ else:
+ value_expr = getattr(self.target_class, self.value_attr).any(criterion, **kwargs)
+
+ # check _value_is_scalar here, otherwise
+ # we're scalar->scalar - call .any() so that
+ # the "can't call any() on a scalar" msg is raised.
+ if self.scalar and not self._value_is_scalar:
+ return self._comparator.has(
+ value_expr
+ )
+ else:
+ return self._comparator.any(
+ value_expr
+ )
def has(self, criterion=None, **kwargs):
- return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr).has(criterion, **kwargs)
+ )
def contains(self, obj):
- return self._comparator.any(**{self.value_attr: obj})
+ if self.scalar and not self._value_is_scalar:
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr).contains(obj)
+ )
+ else:
+ return self._comparator.any(**{self.value_attr: obj})
def __eq__(self, obj):
return self._comparator.has(**{self.value_attr: obj})