diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-11-28 15:32:09 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-11-28 15:32:09 -0500 |
| commit | 2d9387139af453f855d6f20ec2e5f86023f24e2f (patch) | |
| tree | 9d564c34d3743792e013aea45932de2747d3bf6b /lib/sqlalchemy/util/_collections.py | |
| parent | 58b29394337b5a51ce71e96cc4ba6fd68218a999 (diff) | |
| download | sqlalchemy-2d9387139af453f855d6f20ec2e5f86023f24e2f.tar.gz | |
- replace util.py with util/ package, [ticket:1986]
Diffstat (limited to 'lib/sqlalchemy/util/_collections.py')
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 885 |
1 files changed, 885 insertions, 0 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py new file mode 100644 index 000000000..6739c6d09 --- /dev/null +++ b/lib/sqlalchemy/util/_collections.py @@ -0,0 +1,885 @@ +"""Collection classes and helpers.""" + +import sys +import itertools +import weakref +import operator +from langhelpers import symbol +from compat import time_func, threading + +EMPTY_SET = frozenset() + + +class NamedTuple(tuple): + """tuple() subclass that adds labeled names. + + Is also pickleable. + + """ + + def __new__(cls, vals, labels=None): + vals = list(vals) + t = tuple.__new__(cls, vals) + if labels: + t.__dict__ = dict(itertools.izip(labels, vals)) + t._labels = labels + return t + + def keys(self): + return [l for l in self._labels if l is not None] + +class ImmutableContainer(object): + def _immutable(self, *arg, **kw): + raise TypeError("%s object is immutable" % self.__class__.__name__) + + __delitem__ = __setitem__ = __setattr__ = _immutable + +class frozendict(ImmutableContainer, dict): + + clear = pop = popitem = setdefault = \ + update = ImmutableContainer._immutable + + def __new__(cls, *args): + new = dict.__new__(cls) + dict.__init__(new, *args) + return new + + def __init__(self, *args): + pass + + def __reduce__(self): + return frozendict, (dict(self), ) + + def union(self, d): + if not self: + return frozendict(d) + else: + d2 = self.copy() + d2.update(d) + return frozendict(d2) + + def __repr__(self): + return "frozendict(%s)" % dict.__repr__(self) + +class Properties(object): + """Provide a __getattr__/__setattr__ interface over a dict.""" + + def __init__(self, data): + self.__dict__['_data'] = data + + def __len__(self): + return len(self._data) + + def __iter__(self): + return self._data.itervalues() + + def __add__(self, other): + return list(self) + list(other) + + def __setitem__(self, key, object): + self._data[key] = object + + def __getitem__(self, key): + return self._data[key] + + def __delitem__(self, key): + del self._data[key] + + def __setattr__(self, key, object): + self._data[key] = object + + def __getstate__(self): + return {'_data': self.__dict__['_data']} + + def __setstate__(self, state): + self.__dict__['_data'] = state['_data'] + + def __getattr__(self, key): + try: + return self._data[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + return key in self._data + + def as_immutable(self): + """Return an immutable proxy for this :class:`.Properties`.""" + + return ImmutableProperties(self._data) + + def update(self, value): + self._data.update(value) + + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def keys(self): + return self._data.keys() + + def has_key(self, key): + return key in self._data + + def clear(self): + self._data.clear() + +class OrderedProperties(Properties): + """Provide a __getattr__/__setattr__ interface with an OrderedDict + as backing store.""" + def __init__(self): + Properties.__init__(self, OrderedDict()) + + +class ImmutableProperties(ImmutableContainer, Properties): + """Provide immutable dict/object attribute to an underlying dictionary.""" + + +class OrderedDict(dict): + """A dict that returns keys/values/items in the order they were added.""" + + def __init__(self, ____sequence=None, **kwargs): + self._list = [] + if ____sequence is None: + if kwargs: + self.update(**kwargs) + else: + self.update(____sequence, **kwargs) + + def clear(self): + self._list = [] + dict.clear(self) + + def copy(self): + return self.__copy__() + + def __copy__(self): + return OrderedDict(self) + + def sort(self, *arg, **kw): + self._list.sort(*arg, **kw) + + def update(self, ____sequence=None, **kwargs): + if ____sequence is not None: + if hasattr(____sequence, 'keys'): + for key in ____sequence.keys(): + self.__setitem__(key, ____sequence[key]) + else: + for key, value in ____sequence: + self[key] = value + if kwargs: + self.update(kwargs) + + def setdefault(self, key, value): + if key not in self: + self.__setitem__(key, value) + return value + else: + return self.__getitem__(key) + + def __iter__(self): + return iter(self._list) + + def values(self): + return [self[key] for key in self._list] + + def itervalues(self): + return iter(self.values()) + + def keys(self): + return list(self._list) + + def iterkeys(self): + return iter(self.keys()) + + def items(self): + return [(key, self[key]) for key in self.keys()] + + def iteritems(self): + return iter(self.items()) + + def __setitem__(self, key, object): + if key not in self: + try: + self._list.append(key) + except AttributeError: + # work around Python pickle loads() with + # dict subclass (seems to ignore __setstate__?) + self._list = [key] + dict.__setitem__(self, key, object) + + def __delitem__(self, key): + dict.__delitem__(self, key) + self._list.remove(key) + + def pop(self, key, *default): + present = key in self + value = dict.pop(self, key, *default) + if present: + self._list.remove(key) + return value + + def popitem(self): + item = dict.popitem(self) + self._list.remove(item[0]) + return item + +class OrderedSet(set): + def __init__(self, d=None): + set.__init__(self) + self._list = [] + if d is not None: + self.update(d) + + def add(self, element): + if element not in self: + self._list.append(element) + set.add(self, element) + + def remove(self, element): + set.remove(self, element) + self._list.remove(element) + + def insert(self, pos, element): + if element not in self: + self._list.insert(pos, element) + set.add(self, element) + + def discard(self, element): + if element in self: + self._list.remove(element) + set.remove(self, element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + add = self.add + for i in iterable: + add(i) + return self + + __ior__ = update + + def union(self, other): + result = self.__class__(self) + result.update(other) + return result + + __or__ = union + + def intersection(self, other): + other = set(other) + return self.__class__(a for a in self if a in other) + + __and__ = intersection + + def symmetric_difference(self, other): + other = set(other) + result = self.__class__(a for a in self if a not in other) + result.update(a for a in other if a not in self) + return result + + __xor__ = symmetric_difference + + def difference(self, other): + other = set(other) + return self.__class__(a for a in self if a not in other) + + __sub__ = difference + + def intersection_update(self, other): + other = set(other) + set.intersection_update(self, other) + self._list = [ a for a in self._list if a in other] + return self + + __iand__ = intersection_update + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + self._list = [ a for a in self._list if a in self] + self._list += [ a for a in other._list if a in self] + return self + + __ixor__ = symmetric_difference_update + + def difference_update(self, other): + set.difference_update(self, other) + self._list = [ a for a in self._list if a in self] + return self + + __isub__ = difference_update + + +class IdentitySet(object): + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. + + """ + + _working_set = set + + def __init__(self, iterable=None): + self._members = dict() + if iterable: + for o in iterable: + self.add(o) + + def add(self, value): + self._members[id(value)] = value + + def __contains__(self, value): + return id(value) in self._members + + def remove(self, value): + del self._members[id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError('pop from an empty set') + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError('cannot compare sets using cmp()') + + def __eq__(self, other): + if isinstance(other, IdentitySet): + return self._members == other._members + else: + return False + + def __ne__(self, other): + if isinstance(other, IdentitySet): + return self._members != other._members + else: + return True + + def issubset(self, iterable): + other = type(self)(iterable) + + if len(self) > len(other): + return False + for m in itertools.ifilterfalse(other._members.__contains__, + self._members.iterkeys()): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + def issuperset(self, iterable): + other = type(self)(iterable) + + if len(self) < len(other): + return False + + for m in itertools.ifilterfalse(self._members.__contains__, + other._members.iterkeys()): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + def union(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).union(_iter_id(iterable))) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.union(other) + + def update(self, iterable): + self._members = self.union(iterable)._members + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + def difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).difference(_iter_id(iterable))) + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.difference(other) + + def difference_update(self, iterable): + self._members = self.difference(iterable)._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + def intersection(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable))) + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.intersection(other) + + def intersection_update(self, iterable): + self._members = self.intersection(iterable)._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + def symmetric_difference(self, iterable): + result = type(self)() + # testlib.pragma exempt:__hash__ + result._members.update( + self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable))) + return result + + def _member_id_tuples(self): + return ((id(v), v) for v in self._members.itervalues()) + + def __xor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference_update(self, iterable): + self._members = self.symmetric_difference(iterable)._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + def copy(self): + return type(self)(self._members.itervalues()) + + __copy__ = copy + + def __len__(self): + return len(self._members) + + def __iter__(self): + return self._members.itervalues() + + def __hash__(self): + raise TypeError('set objects are unhashable') + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, self._members.values()) + + +class OrderedIdentitySet(IdentitySet): + class _working_set(OrderedSet): + # a testing pragma: exempt the OIDS working set from the test suite's + # "never call the user's __hash__" assertions. this is a big hammer, + # but it's safe here: IDS operates on (id, instance) tuples in the + # working set. + __sa_hash_exempt__ = True + + def __init__(self, iterable=None): + IdentitySet.__init__(self) + self._members = OrderedDict() + if iterable: + for o in iterable: + self.add(o) + + +if sys.version_info >= (2, 5): + class PopulateDict(dict): + """A dict which populates missing values via a creation function. + + Note the creation function takes a key, unlike + collections.defaultdict. + + """ + + def __init__(self, creator): + self.creator = creator + + def __missing__(self, key): + self[key] = val = self.creator(key) + return val +else: + class PopulateDict(dict): + """A dict which populates missing values via a creation function.""" + + def __init__(self, creator): + self.creator = creator + + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + self[key] = value = self.creator(key) + return value + +# define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet +populate_column_dict = PopulateDict + +def unique_list(seq, compare_with=set): + seen = compare_with() + return [x for x in seq if x not in seen and not seen.add(x)] + +class UniqueAppender(object): + """Appends items to a collection ensuring uniqueness. + + Additional appends() of the same object are ignored. Membership is + determined by identity (``is a``) not equality (``==``). + """ + + def __init__(self, data, via=None): + self.data = data + self._unique = IdentitySet() + if via: + self._data_appender = getattr(data, via) + elif hasattr(data, 'append'): + self._data_appender = data.append + elif hasattr(data, 'add'): + # TODO: we think its a set here. bypass unneeded uniquing logic ? + self._data_appender = data.add + + def append(self, item): + if item not in self._unique: + self._data_appender(item) + self._unique.add(item) + + def __iter__(self): + return iter(self.data) + +def to_list(x, default=None): + if x is None: + return default + if not isinstance(x, (list, tuple)): + return [x] + else: + return x + +def to_set(x): + if x is None: + return set() + if not isinstance(x, set): + return set(to_list(x)) + else: + return x + +def to_column_set(x): + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + +def update_copy(d, _new=None, **kw): + """Copy the given dict and update with the given values.""" + + d = d.copy() + if _new: + d.update(_new) + d.update(**kw) + return d + +def flatten_iterator(x): + """Given an iterator of which further sub-elements may also be + iterators, flatten the sub-elements into a single iterator. + + """ + for elem in x: + if not isinstance(elem, basestring) and hasattr(elem, '__iter__'): + for y in flatten_iterator(elem): + yield y + else: + yield elem + +class WeakIdentityMapping(weakref.WeakKeyDictionary): + """A WeakKeyDictionary with an object identity index. + + Adds a .by_id dictionary to a regular WeakKeyDictionary. Trades + performance during mutation operations for accelerated lookups by id(). + + The usual cautions about weak dictionaries and iteration also apply to + this subclass. + + """ + _none = symbol('none') + + def __init__(self): + weakref.WeakKeyDictionary.__init__(self) + self.by_id = {} + self._weakrefs = {} + + def __setitem__(self, object, value): + oid = id(object) + self.by_id[oid] = value + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + weakref.WeakKeyDictionary.__setitem__(self, object, value) + + def __delitem__(self, object): + del self._weakrefs[id(object)] + del self.by_id[id(object)] + weakref.WeakKeyDictionary.__delitem__(self, object) + + def setdefault(self, object, default=None): + value = weakref.WeakKeyDictionary.setdefault(self, object, default) + oid = id(object) + if value is default: + self.by_id[oid] = default + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + return value + + def pop(self, object, default=_none): + if default is self._none: + value = weakref.WeakKeyDictionary.pop(self, object) + else: + value = weakref.WeakKeyDictionary.pop(self, object, default) + if id(object) in self.by_id: + del self._weakrefs[id(object)] + del self.by_id[id(object)] + return value + + def popitem(self): + item = weakref.WeakKeyDictionary.popitem(self) + oid = id(item[0]) + del self._weakrefs[oid] + del self.by_id[oid] + return item + + def clear(self): + # Py2K + # in 3k, MutableMapping calls popitem() + self._weakrefs.clear() + self.by_id.clear() + # end Py2K + weakref.WeakKeyDictionary.clear(self) + + def update(self, *a, **kw): + raise NotImplementedError + + def _cleanup(self, wr, key=None): + if key is None: + key = wr.key + try: + del self._weakrefs[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + try: + del self.by_id[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + + class _keyed_weakref(weakref.ref): + def __init__(self, object, callback): + weakref.ref.__init__(self, object, callback) + self.key = id(object) + + def _ref(self, object): + return self._keyed_weakref(object, self._cleanup) + + +class LRUCache(dict): + """Dictionary with 'squishy' removal of least + recently used items. + + """ + def __init__(self, capacity=100, threshold=.5): + self.capacity = capacity + self.threshold = threshold + + def __getitem__(self, key): + item = dict.__getitem__(self, key) + item[2] = time_func() + return item[1] + + def values(self): + return [i[1] for i in dict.values(self)] + + def setdefault(self, key, value): + if key in self: + return self[key] + else: + self[key] = value + return value + + def __setitem__(self, key, value): + item = dict.get(self, key) + if item is None: + item = [key, value, time_func()] + dict.__setitem__(self, key, item) + else: + item[1] = value + self._manage_size() + + def _manage_size(self): + while len(self) > self.capacity + self.capacity * self.threshold: + bytime = sorted(dict.values(self), + key=operator.itemgetter(2), + reverse=True) + for item in bytime[self.capacity:]: + try: + del self[item[0]] + except KeyError: + # if we couldnt find a key, most + # likely some other thread broke in + # on us. loop around and try again + break + + +class ScopedRegistry(object): + """A Registry that can store one or multiple instances of a single + class on the basis of a "scope" function. + + The object implements ``__call__`` as the "getter", so by + calling ``myregistry()`` the contained object is returned + for the current scope. + + :param createfunc: + a callable that returns a new object to be placed in the registry + + :param scopefunc: + a callable that will return a key to store/retrieve an object. + """ + + def __init__(self, createfunc, scopefunc): + """Construct a new :class:`.ScopedRegistry`. + + :param createfunc: A creation function that will generate + a new value for the current scope, if none is present. + + :param scopefunc: A function that returns a hashable + token representing the current scope (such as, current + thread identifier). + + """ + self.createfunc = createfunc + self.scopefunc = scopefunc + self.registry = {} + + def __call__(self): + key = self.scopefunc() + try: + return self.registry[key] + except KeyError: + return self.registry.setdefault(key, self.createfunc()) + + def has(self): + """Return True if an object is present in the current scope.""" + + return self.scopefunc() in self.registry + + def set(self, obj): + """Set the value forthe current scope.""" + + self.registry[self.scopefunc()] = obj + + def clear(self): + """Clear the current scope, if any.""" + + try: + del self.registry[self.scopefunc()] + except KeyError: + pass + +class ThreadLocalRegistry(ScopedRegistry): + """A :class:`.ScopedRegistry` that uses a ``threading.local()`` + variable for storage. + + """ + def __init__(self, createfunc): + self.createfunc = createfunc + self.registry = threading.local() + + def __call__(self): + try: + return self.registry.value + except AttributeError: + val = self.registry.value = self.createfunc() + return val + + def has(self): + return hasattr(self.registry, "value") + + def set(self, obj): + self.registry.value = obj + + def clear(self): + try: + del self.registry.value + except AttributeError: + pass + + +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + + for item in iterable: + yield id(item), item + |
