diff options
Diffstat (limited to 'lib/sqlalchemy/ext/associationproxy.py')
| -rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 288 |
1 files changed, 189 insertions, 99 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index ff9433d4d..56b91ce0b 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -76,7 +76,7 @@ def association_proxy(target_collection, attr, **kw): return AssociationProxy(target_collection, attr, **kw) -ASSOCIATION_PROXY = util.symbol('ASSOCIATION_PROXY') +ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY") """Symbol indicating an :class:`InspectionAttr` that's of type :class:`.AssociationProxy`. @@ -92,10 +92,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): is_attribute = False extension_type = ASSOCIATION_PROXY - def __init__(self, target_collection, attr, creator=None, - getset_factory=None, proxy_factory=None, - proxy_bulk_set=None, info=None, - cascade_scalar_deletes=False): + def __init__( + self, + target_collection, + attr, + creator=None, + getset_factory=None, + proxy_factory=None, + proxy_bulk_set=None, + info=None, + cascade_scalar_deletes=False, + ): """Construct a new :class:`.AssociationProxy`. The :func:`.association_proxy` function is provided as the usual @@ -162,8 +169,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo): self.proxy_bulk_set = proxy_bulk_set self.cascade_scalar_deletes = cascade_scalar_deletes - self.key = '_%s_%s_%s' % ( - type(self).__name__, target_collection, id(self)) + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) if info: self.info = info @@ -264,12 +274,17 @@ class AssociationProxy(interfaces.InspectionAttrInfo): def getter(target): return _getter(target) if target is not None else None + if collection_class is dict: + def setter(o, k, v): setattr(o, attr, v) + else: + def setter(o, v): setattr(o, attr, v) + return getter, setter @@ -325,20 +340,21 @@ class AssociationProxyInstance(object): def for_proxy(cls, parent, owning_class, parent_instance): target_collection = parent.target_collection value_attr = parent.value_attr - prop = orm.class_mapper(owning_class).\ - get_property(target_collection) + prop = orm.class_mapper(owning_class).get_property(target_collection) # this was never asserted before but this should be made clear. if not isinstance(prop, orm.RelationshipProperty): raise NotImplementedError( "association proxy to a non-relationship " - "intermediary is not supported") + "intermediary is not supported" + ) target_class = prop.mapper.class_ try: target_assoc = cls._cls_unwrap_target_assoc_proxy( - target_class, value_attr) + target_class, value_attr + ) except AttributeError: # the proxied attribute doesn't exist on the target class; # return an "ambiguous" instance that will work on a per-object @@ -353,8 +369,8 @@ class AssociationProxyInstance(object): @classmethod def _construct_for_assoc( - cls, target_assoc, parent, owning_class, - target_class, value_attr): + cls, target_assoc, parent, owning_class, target_class, value_attr + ): if target_assoc is not None: return ObjectAssociationProxyInstance( parent, owning_class, target_class, value_attr @@ -371,8 +387,9 @@ class AssociationProxyInstance(object): ) def _get_property(self): - return orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + return orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) @property def _comparator(self): @@ -388,7 +405,8 @@ class AssociationProxyInstance(object): @util.memoized_property def _unwrap_target_assoc_proxy(self): return self._cls_unwrap_target_assoc_proxy( - self.target_class, self.value_attr) + self.target_class, self.value_attr + ) @property def remote_attr(self): @@ -448,8 +466,11 @@ class AssociationProxyInstance(object): @util.memoized_property def _value_is_scalar(self): - return not self._get_property().\ - mapper.get_property(self.value_attr).uselist + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) @property def _target_is_object(self): @@ -468,12 +489,17 @@ class AssociationProxyInstance(object): def getter(target): return _getter(target) if target is not None else None + if collection_class is dict: + def setter(o, k, v): return setattr(o, attr, v) + else: + def setter(o, v): return setattr(o, attr, v) + return getter, setter @property @@ -500,14 +526,18 @@ class AssociationProxyInstance(object): return proxy self.collection_class, proxy = self._new( - _lazy_collection(obj, self.target_collection)) + _lazy_collection(obj, self.target_collection) + ) setattr(obj, self.key, (id(obj), id(self), proxy)) return proxy def set(self, obj, values): if self.scalar: - creator = self.parent.creator \ - if self.parent.creator else self.target_class + creator = ( + self.parent.creator + if self.parent.creator + else self.target_class + ) target = getattr(obj, self.target_collection) if target is None: if values is None: @@ -535,35 +565,52 @@ class AssociationProxyInstance(object): delattr(obj, self.target_collection) def _new(self, lazy_collection): - creator = self.parent.creator if self.parent.creator else \ - self.target_class + creator = ( + self.parent.creator if self.parent.creator else self.target_class + ) collection_class = util.duck_type_collection(lazy_collection()) if self.parent.proxy_factory: - return collection_class, self.parent.proxy_factory( - lazy_collection, creator, self.value_attr, self) + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) if self.parent.getset_factory: - getter, setter = self.parent.getset_factory( - collection_class, self) + getter, setter = self.parent.getset_factory(collection_class, self) else: getter, setter = self.parent._default_getset(collection_class) if collection_class is list: - return collection_class, _AssociationList( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is dict: - return collection_class, _AssociationDict( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ) elif collection_class is set: - return collection_class, _AssociationSet( - lazy_collection, creator, getter, setter, self) + return ( + collection_class, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ) else: raise exc.ArgumentError( - 'could not guess which interface to use for ' + "could not guess which interface to use for " 'collection_class "%s" backing "%s"; specify a ' - 'proxy_factory and proxy_bulk_set manually' % - (self.collection_class.__name__, self.target_collection)) + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class.__name__, self.target_collection) + ) def _set(self, proxy, values): if self.parent.proxy_bulk_set: @@ -576,16 +623,19 @@ class AssociationProxyInstance(object): proxy.update(values) else: raise exc.ArgumentError( - 'no proxy_bulk_set supplied for custom ' - 'collection_class implementation') + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) def _inflate(self, proxy): - creator = self.parent.creator and \ - self.parent.creator or self.target_class + creator = ( + self.parent.creator and self.parent.creator or self.target_class + ) if self.parent.getset_factory: getter, setter = self.parent.getset_factory( - self.collection_class, self) + self.collection_class, self + ) else: getter, setter = self.parent._default_getset(self.collection_class) @@ -594,12 +644,13 @@ class AssociationProxyInstance(object): proxy.setter = setter def _criterion_exists(self, criterion=None, **kwargs): - is_has = kwargs.pop('is_has', None) + is_has = kwargs.pop("is_has", None) target_assoc = self._unwrap_target_assoc_proxy if target_assoc is not None: inner = target_assoc._criterion_exists( - criterion=criterion, **kwargs) + criterion=criterion, **kwargs + ) return self._comparator._criterion_exists(inner) if self._target_is_object: @@ -631,15 +682,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - self.scalar and ( - not self._target_is_object or self._value_is_scalar) + self.scalar + and (not self._target_is_object or self._value_is_scalar) ): raise exc.InvalidRequestError( - "'any()' not implemented for scalar " - "attributes. Use has()." + "'any()' not implemented for scalar " "attributes. Use has()." ) return self._criterion_exists( - criterion=criterion, is_has=False, **kwargs) + criterion=criterion, is_has=False, **kwargs + ) def has(self, criterion=None, **kwargs): """Produce a proxied 'has' expression using EXISTS. @@ -651,14 +702,15 @@ class AssociationProxyInstance(object): """ if self._unwrap_target_assoc_proxy is None and ( - not self.scalar or ( - self._target_is_object and not self._value_is_scalar) + not self.scalar + or (self._target_is_object and not self._value_is_scalar) ): raise exc.InvalidRequestError( - "'has()' not implemented for collections. " - "Use any().") + "'has()' not implemented for collections. " "Use any()." + ) return self._criterion_exists( - criterion=criterion, is_has=True, **kwargs) + criterion=criterion, is_has=True, **kwargs + ) class AmbiguousAssociationProxyInstance(AssociationProxyInstance): @@ -673,10 +725,14 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): "Association proxy %s.%s refers to an attribute '%s' that is not " "directly mapped on class %s; therefore this operation cannot " "proceed since we don't know what type of object is referred " - "towards" % ( - self.owning_class.__name__, self.target_collection, - self.value_attr, self.target_class - )) + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) def get(self, obj): self._ambiguous() @@ -718,27 +774,32 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance): return self def _populate_cache(self, instance_class): - prop = orm.class_mapper(self.owning_class).\ - get_property(self.target_collection) + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) if inspect(instance_class).mapper.isa(prop.mapper): target_class = instance_class try: target_assoc = self._cls_unwrap_target_assoc_proxy( - target_class, self.value_attr) + target_class, self.value_attr + ) except AttributeError: pass else: - self._lookup_cache[instance_class] = \ - self._construct_for_assoc( - target_assoc, self.parent, self.owning_class, - target_class, self.value_attr + self._lookup_cache[instance_class] = self._construct_for_assoc( + target_assoc, + self.parent, + self.owning_class, + target_class, + self.value_attr, ) class ObjectAssociationProxyInstance(AssociationProxyInstance): """an :class:`.AssociationProxyInstance` that has an object as a target. """ + _target_is_object = True _is_canonical = True @@ -756,17 +817,21 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if target_assoc is not None: return self._comparator._criterion_exists( target_assoc.contains(obj) - if not target_assoc.scalar else target_assoc == obj + if not target_assoc.scalar + else target_assoc == obj ) - elif self._target_is_object and self.scalar and \ - not self._value_is_scalar: + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): return self._comparator.has( getattr(self.target_class, self.value_attr).contains(obj) ) - elif self._target_is_object and self.scalar and \ - self._value_is_scalar: + elif self._target_is_object and self.scalar and self._value_is_scalar: raise exc.InvalidRequestError( - "contains() doesn't apply to a scalar object endpoint; use ==") + "contains() doesn't apply to a scalar object endpoint; use ==" + ) else: return self._comparator._criterion_exists(**{self.value_attr: obj}) @@ -777,7 +842,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): if obj is None: return or_( self._comparator.has(**{self.value_attr: obj}), - self._comparator == None + self._comparator == None, ) else: return self._comparator.has(**{self.value_attr: obj}) @@ -786,14 +851,17 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): # note the has() here will fail for collections; eq_() # is only allowed with a scalar. return self._comparator.has( - getattr(self.target_class, self.value_attr) != obj) + getattr(self.target_class, self.value_attr) != obj + ) class ColumnAssociationProxyInstance( - ColumnOperators, AssociationProxyInstance): + ColumnOperators, AssociationProxyInstance +): """an :class:`.AssociationProxyInstance` that has a database column as a target. """ + _target_is_object = False _is_canonical = True @@ -803,9 +871,7 @@ class ColumnAssociationProxyInstance( self.remote_attr.operate(operator.eq, other) ) if other is None: - return or_( - expr, self._comparator == None - ) + return or_(expr, self._comparator == None) else: return expr @@ -824,11 +890,11 @@ class _lazy_collection(object): return getattr(self.parent, self.target) def __getstate__(self): - return {'obj': self.parent, 'target': self.target} + return {"obj": self.parent, "target": self.target} def __setstate__(self, state): - self.parent = state['obj'] - self.target = state['target'] + self.parent = state["obj"] + self.target = state["target"] class _AssociationCollection(object): @@ -874,11 +940,11 @@ class _AssociationCollection(object): __nonzero__ = __bool__ def __getstate__(self): - return {'parent': self.parent, 'lazy_collection': self.lazy_collection} + return {"parent": self.parent, "lazy_collection": self.lazy_collection} def __setstate__(self, state): - self.parent = state['parent'] - self.lazy_collection = state['lazy_collection'] + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] self.parent._inflate(self) @@ -925,8 +991,8 @@ class _AssociationList(_AssociationCollection): if len(value) != len(rng): raise ValueError( "attempt to assign sequence of size %s to " - "extended slice of size %s" % (len(value), - len(rng))) + "extended slice of size %s" % (len(value), len(rng)) + ) for i, item in zip(rng, value): self._set(self.col[i], item) @@ -968,8 +1034,14 @@ class _AssociationList(_AssociationCollection): col.append(item) def count(self, value): - return sum([1 for _ in - util.itertools_filter(lambda v: v == value, iter(self))]) + return sum( + [ + 1 + for _ in util.itertools_filter( + lambda v: v == value, iter(self) + ) + ] + ) def extend(self, values): for v in values: @@ -999,7 +1071,7 @@ class _AssociationList(_AssociationCollection): raise NotImplementedError def clear(self): - del self.col[0:len(self.col)] + del self.col[0 : len(self.col)] def __eq__(self, other): return list(self) == other @@ -1040,6 +1112,7 @@ class _AssociationList(_AssociationCollection): if not isinstance(n, int): return NotImplemented return list(self) * n + __rmul__ = __mul__ def __iadd__(self, iterable): @@ -1072,13 +1145,17 @@ class _AssociationList(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(list, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func -_NotProvided = util.symbol('_NotProvided') +_NotProvided = util.symbol("_NotProvided") class _AssociationDict(_AssociationCollection): @@ -1160,6 +1237,7 @@ class _AssociationDict(_AssociationCollection): return self.col.keys() if util.py2k: + def iteritems(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1174,7 +1252,9 @@ class _AssociationDict(_AssociationCollection): def items(self): return [(k, self._get(self.col[k])) for k in self] + else: + def items(self): return ((key, self._get(self.col[key])) for key in self.col) @@ -1194,14 +1274,15 @@ class _AssociationDict(_AssociationCollection): def update(self, *a, **kw): if len(a) > 1: - raise TypeError('update expected at most 1 arguments, got %i' % - len(a)) + raise TypeError( + "update expected at most 1 arguments, got %i" % len(a) + ) elif len(a) == 1: seq_or_map = a[0] # discern dict from sequence - took the advice from # http://www.voidspace.org.uk/python/articles/duck_typing.shtml # still not perfect :( - if hasattr(seq_or_map, 'keys'): + if hasattr(seq_or_map, "keys"): for item in seq_or_map: self[item] = seq_or_map[item] else: @@ -1211,7 +1292,8 @@ class _AssociationDict(_AssociationCollection): except ValueError: raise ValueError( "dictionary update sequence " - "requires 2-element tuples") + "requires 2-element tuples" + ) for key, value in kw: self[key] = value @@ -1223,8 +1305,12 @@ class _AssociationDict(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(dict, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): func.__doc__ = getattr(dict, func_name).__doc__ del func_name, func @@ -1288,7 +1374,7 @@ class _AssociationSet(_AssociationCollection): def pop(self): if not self.col: - raise KeyError('pop from an empty set') + raise KeyError("pop from an empty set") member = self.col.pop() return self._get(member) @@ -1420,7 +1506,11 @@ class _AssociationSet(_AssociationCollection): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in list(locals().items()): - if (util.callable(func) and func.__name__ == func_name and - not func.__doc__ and hasattr(set, func_name)): + if ( + util.callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): func.__doc__ = getattr(set, func_name).__doc__ del func_name, func |
