summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/associationproxy.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/associationproxy.py')
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py288
1 files changed, 189 insertions, 99 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index ff9433d4d..56b91ce0b 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -76,7 +76,7 @@ def association_proxy(target_collection, attr, **kw):
return AssociationProxy(target_collection, attr, **kw)
-ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY')
+ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.AssociationProxy`.
@@ -92,10 +92,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
is_attribute = False
extension_type = ASSOCIATION_PROXY
- def __init__(self, target_collection, attr, creator=None,
- getset_factory=None, proxy_factory=None,
- proxy_bulk_set=None, info=None,
- cascade_scalar_deletes=False):
+ def __init__(
+ self,
+ target_collection,
+ attr,
+ creator=None,
+ getset_factory=None,
+ proxy_factory=None,
+ proxy_bulk_set=None,
+ info=None,
+ cascade_scalar_deletes=False,
+ ):
"""Construct a new :class:`.AssociationProxy`.
The :func:`.association_proxy` function is provided as the usual
@@ -162,8 +169,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
self.proxy_bulk_set = proxy_bulk_set
self.cascade_scalar_deletes = cascade_scalar_deletes
- self.key = '_%s_%s_%s' % (
- type(self).__name__, target_collection, id(self))
+ self.key = "_%s_%s_%s" % (
+ type(self).__name__,
+ target_collection,
+ id(self),
+ )
if info:
self.info = info
@@ -264,12 +274,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
def getter(target):
return _getter(target) if target is not None else None
+
if collection_class is dict:
+
def setter(o, k, v):
setattr(o, attr, v)
+
else:
+
def setter(o, v):
setattr(o, attr, v)
+
return getter, setter
@@ -325,20 +340,21 @@ class AssociationProxyInstance(object):
def for_proxy(cls, parent, owning_class, parent_instance):
target_collection = parent.target_collection
value_attr = parent.value_attr
- prop = orm.class_mapper(owning_class).\
- get_property(target_collection)
+ prop = orm.class_mapper(owning_class).get_property(target_collection)
# this was never asserted before but this should be made clear.
if not isinstance(prop, orm.RelationshipProperty):
raise NotImplementedError(
"association proxy to a non-relationship "
- "intermediary is not supported")
+ "intermediary is not supported"
+ )
target_class = prop.mapper.class_
try:
target_assoc = cls._cls_unwrap_target_assoc_proxy(
- target_class, value_attr)
+ target_class, value_attr
+ )
except AttributeError:
# the proxied attribute doesn't exist on the target class;
# return an "ambiguous" instance that will work on a per-object
@@ -353,8 +369,8 @@ class AssociationProxyInstance(object):
@classmethod
def _construct_for_assoc(
- cls, target_assoc, parent, owning_class,
- target_class, value_attr):
+ cls, target_assoc, parent, owning_class, target_class, value_attr
+ ):
if target_assoc is not None:
return ObjectAssociationProxyInstance(
parent, owning_class, target_class, value_attr
@@ -371,8 +387,9 @@ class AssociationProxyInstance(object):
)
def _get_property(self):
- return orm.class_mapper(self.owning_class).\
- get_property(self.target_collection)
+ return orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
@property
def _comparator(self):
@@ -388,7 +405,8 @@ class AssociationProxyInstance(object):
@util.memoized_property
def _unwrap_target_assoc_proxy(self):
return self._cls_unwrap_target_assoc_proxy(
- self.target_class, self.value_attr)
+ self.target_class, self.value_attr
+ )
@property
def remote_attr(self):
@@ -448,8 +466,11 @@ class AssociationProxyInstance(object):
@util.memoized_property
def _value_is_scalar(self):
- return not self._get_property().\
- mapper.get_property(self.value_attr).uselist
+ return (
+ not self._get_property()
+ .mapper.get_property(self.value_attr)
+ .uselist
+ )
@property
def _target_is_object(self):
@@ -468,12 +489,17 @@ class AssociationProxyInstance(object):
def getter(target):
return _getter(target) if target is not None else None
+
if collection_class is dict:
+
def setter(o, k, v):
return setattr(o, attr, v)
+
else:
+
def setter(o, v):
return setattr(o, attr, v)
+
return getter, setter
@property
@@ -500,14 +526,18 @@ class AssociationProxyInstance(object):
return proxy
self.collection_class, proxy = self._new(
- _lazy_collection(obj, self.target_collection))
+ _lazy_collection(obj, self.target_collection)
+ )
setattr(obj, self.key, (id(obj), id(self), proxy))
return proxy
def set(self, obj, values):
if self.scalar:
- creator = self.parent.creator \
- if self.parent.creator else self.target_class
+ creator = (
+ self.parent.creator
+ if self.parent.creator
+ else self.target_class
+ )
target = getattr(obj, self.target_collection)
if target is None:
if values is None:
@@ -535,35 +565,52 @@ class AssociationProxyInstance(object):
delattr(obj, self.target_collection)
def _new(self, lazy_collection):
- creator = self.parent.creator if self.parent.creator else \
- self.target_class
+ creator = (
+ self.parent.creator if self.parent.creator else self.target_class
+ )
collection_class = util.duck_type_collection(lazy_collection())
if self.parent.proxy_factory:
- return collection_class, self.parent.proxy_factory(
- lazy_collection, creator, self.value_attr, self)
+ return (
+ collection_class,
+ self.parent.proxy_factory(
+ lazy_collection, creator, self.value_attr, self
+ ),
+ )
if self.parent.getset_factory:
- getter, setter = self.parent.getset_factory(
- collection_class, self)
+ getter, setter = self.parent.getset_factory(collection_class, self)
else:
getter, setter = self.parent._default_getset(collection_class)
if collection_class is list:
- return collection_class, _AssociationList(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationList(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
elif collection_class is dict:
- return collection_class, _AssociationDict(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationDict(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
elif collection_class is set:
- return collection_class, _AssociationSet(
- lazy_collection, creator, getter, setter, self)
+ return (
+ collection_class,
+ _AssociationSet(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
else:
raise exc.ArgumentError(
- 'could not guess which interface to use for '
+ "could not guess which interface to use for "
'collection_class "%s" backing "%s"; specify a '
- 'proxy_factory and proxy_bulk_set manually' %
- (self.collection_class.__name__, self.target_collection))
+ "proxy_factory and proxy_bulk_set manually"
+ % (self.collection_class.__name__, self.target_collection)
+ )
def _set(self, proxy, values):
if self.parent.proxy_bulk_set:
@@ -576,16 +623,19 @@ class AssociationProxyInstance(object):
proxy.update(values)
else:
raise exc.ArgumentError(
- 'no proxy_bulk_set supplied for custom '
- 'collection_class implementation')
+ "no proxy_bulk_set supplied for custom "
+ "collection_class implementation"
+ )
def _inflate(self, proxy):
- creator = self.parent.creator and \
- self.parent.creator or self.target_class
+ creator = (
+ self.parent.creator and self.parent.creator or self.target_class
+ )
if self.parent.getset_factory:
getter, setter = self.parent.getset_factory(
- self.collection_class, self)
+ self.collection_class, self
+ )
else:
getter, setter = self.parent._default_getset(self.collection_class)
@@ -594,12 +644,13 @@ class AssociationProxyInstance(object):
proxy.setter = setter
def _criterion_exists(self, criterion=None, **kwargs):
- is_has = kwargs.pop('is_has', None)
+ is_has = kwargs.pop("is_has", None)
target_assoc = self._unwrap_target_assoc_proxy
if target_assoc is not None:
inner = target_assoc._criterion_exists(
- criterion=criterion, **kwargs)
+ criterion=criterion, **kwargs
+ )
return self._comparator._criterion_exists(inner)
if self._target_is_object:
@@ -631,15 +682,15 @@ class AssociationProxyInstance(object):
"""
if self._unwrap_target_assoc_proxy is None and (
- self.scalar and (
- not self._target_is_object or self._value_is_scalar)
+ self.scalar
+ and (not self._target_is_object or self._value_is_scalar)
):
raise exc.InvalidRequestError(
- "'any()' not implemented for scalar "
- "attributes. Use has()."
+ "'any()' not implemented for scalar " "attributes. Use has()."
)
return self._criterion_exists(
- criterion=criterion, is_has=False, **kwargs)
+ criterion=criterion, is_has=False, **kwargs
+ )
def has(self, criterion=None, **kwargs):
"""Produce a proxied 'has' expression using EXISTS.
@@ -651,14 +702,15 @@ class AssociationProxyInstance(object):
"""
if self._unwrap_target_assoc_proxy is None and (
- not self.scalar or (
- self._target_is_object and not self._value_is_scalar)
+ not self.scalar
+ or (self._target_is_object and not self._value_is_scalar)
):
raise exc.InvalidRequestError(
- "'has()' not implemented for collections. "
- "Use any().")
+ "'has()' not implemented for collections. " "Use any()."
+ )
return self._criterion_exists(
- criterion=criterion, is_has=True, **kwargs)
+ criterion=criterion, is_has=True, **kwargs
+ )
class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
@@ -673,10 +725,14 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
"Association proxy %s.%s refers to an attribute '%s' that is not "
"directly mapped on class %s; therefore this operation cannot "
"proceed since we don't know what type of object is referred "
- "towards" % (
- self.owning_class.__name__, self.target_collection,
- self.value_attr, self.target_class
- ))
+ "towards"
+ % (
+ self.owning_class.__name__,
+ self.target_collection,
+ self.value_attr,
+ self.target_class,
+ )
+ )
def get(self, obj):
self._ambiguous()
@@ -718,27 +774,32 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
return self
def _populate_cache(self, instance_class):
- prop = orm.class_mapper(self.owning_class).\
- get_property(self.target_collection)
+ prop = orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
if inspect(instance_class).mapper.isa(prop.mapper):
target_class = instance_class
try:
target_assoc = self._cls_unwrap_target_assoc_proxy(
- target_class, self.value_attr)
+ target_class, self.value_attr
+ )
except AttributeError:
pass
else:
- self._lookup_cache[instance_class] = \
- self._construct_for_assoc(
- target_assoc, self.parent, self.owning_class,
- target_class, self.value_attr
+ self._lookup_cache[instance_class] = self._construct_for_assoc(
+ target_assoc,
+ self.parent,
+ self.owning_class,
+ target_class,
+ self.value_attr,
)
class ObjectAssociationProxyInstance(AssociationProxyInstance):
"""an :class:`.AssociationProxyInstance` that has an object as a target.
"""
+
_target_is_object = True
_is_canonical = True
@@ -756,17 +817,21 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
if target_assoc is not None:
return self._comparator._criterion_exists(
target_assoc.contains(obj)
- if not target_assoc.scalar else target_assoc == obj
+ if not target_assoc.scalar
+ else target_assoc == obj
)
- elif self._target_is_object and self.scalar and \
- not self._value_is_scalar:
+ elif (
+ self._target_is_object
+ and self.scalar
+ and not self._value_is_scalar
+ ):
return self._comparator.has(
getattr(self.target_class, self.value_attr).contains(obj)
)
- elif self._target_is_object and self.scalar and \
- self._value_is_scalar:
+ elif self._target_is_object and self.scalar and self._value_is_scalar:
raise exc.InvalidRequestError(
- "contains() doesn't apply to a scalar object endpoint; use ==")
+ "contains() doesn't apply to a scalar object endpoint; use =="
+ )
else:
return self._comparator._criterion_exists(**{self.value_attr: obj})
@@ -777,7 +842,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
if obj is None:
return or_(
self._comparator.has(**{self.value_attr: obj}),
- self._comparator == None
+ self._comparator == None,
)
else:
return self._comparator.has(**{self.value_attr: obj})
@@ -786,14 +851,17 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
# note the has() here will fail for collections; eq_()
# is only allowed with a scalar.
return self._comparator.has(
- getattr(self.target_class, self.value_attr) != obj)
+ getattr(self.target_class, self.value_attr) != obj
+ )
class ColumnAssociationProxyInstance(
- ColumnOperators, AssociationProxyInstance):
+ ColumnOperators, AssociationProxyInstance
+):
"""an :class:`.AssociationProxyInstance` that has a database column as a
target.
"""
+
_target_is_object = False
_is_canonical = True
@@ -803,9 +871,7 @@ class ColumnAssociationProxyInstance(
self.remote_attr.operate(operator.eq, other)
)
if other is None:
- return or_(
- expr, self._comparator == None
- )
+ return or_(expr, self._comparator == None)
else:
return expr
@@ -824,11 +890,11 @@ class _lazy_collection(object):
return getattr(self.parent, self.target)
def __getstate__(self):
- return {'obj': self.parent, 'target': self.target}
+ return {"obj": self.parent, "target": self.target}
def __setstate__(self, state):
- self.parent = state['obj']
- self.target = state['target']
+ self.parent = state["obj"]
+ self.target = state["target"]
class _AssociationCollection(object):
@@ -874,11 +940,11 @@ class _AssociationCollection(object):
__nonzero__ = __bool__
def __getstate__(self):
- return {'parent': self.parent, 'lazy_collection': self.lazy_collection}
+ return {"parent": self.parent, "lazy_collection": self.lazy_collection}
def __setstate__(self, state):
- self.parent = state['parent']
- self.lazy_collection = state['lazy_collection']
+ self.parent = state["parent"]
+ self.lazy_collection = state["lazy_collection"]
self.parent._inflate(self)
@@ -925,8 +991,8 @@ class _AssociationList(_AssociationCollection):
if len(value) != len(rng):
raise ValueError(
"attempt to assign sequence of size %s to "
- "extended slice of size %s" % (len(value),
- len(rng)))
+ "extended slice of size %s" % (len(value), len(rng))
+ )
for i, item in zip(rng, value):
self._set(self.col[i], item)
@@ -968,8 +1034,14 @@ class _AssociationList(_AssociationCollection):
col.append(item)
def count(self, value):
- return sum([1 for _ in
- util.itertools_filter(lambda v: v == value, iter(self))])
+ return sum(
+ [
+ 1
+ for _ in util.itertools_filter(
+ lambda v: v == value, iter(self)
+ )
+ ]
+ )
def extend(self, values):
for v in values:
@@ -999,7 +1071,7 @@ class _AssociationList(_AssociationCollection):
raise NotImplementedError
def clear(self):
- del self.col[0:len(self.col)]
+ del self.col[0 : len(self.col)]
def __eq__(self, other):
return list(self) == other
@@ -1040,6 +1112,7 @@ class _AssociationList(_AssociationCollection):
if not isinstance(n, int):
return NotImplemented
return list(self) * n
+
__rmul__ = __mul__
def __iadd__(self, iterable):
@@ -1072,13 +1145,17 @@ class _AssociationList(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(list, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
-_NotProvided = util.symbol('_NotProvided')
+_NotProvided = util.symbol("_NotProvided")
class _AssociationDict(_AssociationCollection):
@@ -1160,6 +1237,7 @@ class _AssociationDict(_AssociationCollection):
return self.col.keys()
if util.py2k:
+
def iteritems(self):
return ((key, self._get(self.col[key])) for key in self.col)
@@ -1174,7 +1252,9 @@ class _AssociationDict(_AssociationCollection):
def items(self):
return [(k, self._get(self.col[k])) for k in self]
+
else:
+
def items(self):
return ((key, self._get(self.col[key])) for key in self.col)
@@ -1194,14 +1274,15 @@ class _AssociationDict(_AssociationCollection):
def update(self, *a, **kw):
if len(a) > 1:
- raise TypeError('update expected at most 1 arguments, got %i' %
- len(a))
+ raise TypeError(
+ "update expected at most 1 arguments, got %i" % len(a)
+ )
elif len(a) == 1:
seq_or_map = a[0]
# discern dict from sequence - took the advice from
# http://www.voidspace.org.uk/python/articles/duck_typing.shtml
# still not perfect :(
- if hasattr(seq_or_map, 'keys'):
+ if hasattr(seq_or_map, "keys"):
for item in seq_or_map:
self[item] = seq_or_map[item]
else:
@@ -1211,7 +1292,8 @@ class _AssociationDict(_AssociationCollection):
except ValueError:
raise ValueError(
"dictionary update sequence "
- "requires 2-element tuples")
+ "requires 2-element tuples"
+ )
for key, value in kw:
self[key] = value
@@ -1223,8 +1305,12 @@ class _AssociationDict(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(dict, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(dict, func_name)
+ ):
func.__doc__ = getattr(dict, func_name).__doc__
del func_name, func
@@ -1288,7 +1374,7 @@ class _AssociationSet(_AssociationCollection):
def pop(self):
if not self.col:
- raise KeyError('pop from an empty set')
+ raise KeyError("pop from an empty set")
member = self.col.pop()
return self._get(member)
@@ -1420,7 +1506,11 @@ class _AssociationSet(_AssociationCollection):
raise TypeError("%s objects are unhashable" % type(self).__name__)
for func_name, func in list(locals().items()):
- if (util.callable(func) and func.__name__ == func_name and
- not func.__doc__ and hasattr(set, func_name)):
+ if (
+ util.callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(set, func_name)
+ ):
func.__doc__ = getattr(set, func_name).__doc__
del func_name, func