diff options
Diffstat (limited to 'lib/sqlalchemy/util/_collections.py')
-rw-r--r-- | lib/sqlalchemy/util/_collections.py | 34 |
1 files changed, 26 insertions, 8 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 86a90828a..c0a24ba4f 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -1,16 +1,17 @@ # util/_collections.py -# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file> +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """Collection classes and helpers.""" -import itertools +from __future__ import absolute_import import weakref import operator from .compat import threading, itertools_filterfalse from . import py2k +import types EMPTY_SET = frozenset() @@ -650,19 +651,31 @@ class IdentitySet(object): class WeakSequence(object): - def __init__(self, elements): - self._storage = weakref.WeakValueDictionary( - (idx, element) for idx, element in enumerate(elements) - ) + def __init__(self, __elements=()): + self._storage = [ + weakref.ref(element, self._remove) for element in __elements + ] + + def append(self, item): + self._storage.append(weakref.ref(item, self._remove)) + + def _remove(self, ref): + self._storage.remove(ref) + + def __len__(self): + return len(self._storage) def __iter__(self): - return iter(self._storage.values()) + return (obj for obj in + (ref() for ref in self._storage) if obj is not None) def __getitem__(self, index): try: - return self._storage[index] + obj = self._storage[index] except KeyError: raise IndexError("Index %s out of range" % index) + else: + return obj() class OrderedIdentitySet(IdentitySet): @@ -743,6 +756,11 @@ class UniqueAppender(object): def __iter__(self): return iter(self.data) +def coerce_generator_arg(arg): + if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): + return list(arg[0]) + else: + return arg def to_list(x, default=None): if x is None: |