diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2023-03-31 13:51:47 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-03-31 13:51:47 +0000 |
commit | cfd31fa67d622c9b1e9811dde2f1f87173ed648e (patch) | |
tree | 66413373adf63f39d85a375277d5c1e36265bae1 | |
parent | 5f73d2f6332dfaadb1b816f8dce539433014c654 (diff) | |
parent | a979b6dc5ebefedfd8c85f5695cc5be8882eaa29 (diff) | |
download | sqlalchemy-cfd31fa67d622c9b1e9811dde2f1f87173ed648e.tar.gz |
Merge "Add missing methods to OrderedSet." into main
-rw-r--r-- | doc/build/changelog/unreleased_20/9487.rst | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/collections.pyx | 65 | ||||
-rw-r--r-- | lib/sqlalchemy/util/_py_collections.py | 37 | ||||
-rw-r--r-- | test/base/test_utils.py | 128 |
4 files changed, 198 insertions, 38 deletions
diff --git a/doc/build/changelog/unreleased_20/9487.rst b/doc/build/changelog/unreleased_20/9487.rst new file mode 100644 index 000000000..627be0e61 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9487.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, util + :tickets: 9487 + + Implemented missing methods ``copy`` and ``pop`` in + OrderedSet class. diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index e6667dddd..d08fa3aab 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -1,8 +1,9 @@ cimport cython from cpython.dict cimport PyDict_Merge, PyDict_Update -from cpython.long cimport PyLong_FromLong +from cpython.long cimport PyLong_FromLongLong from cpython.set cimport PySet_Add +from collections.abc import Collection from itertools import filterfalse cdef bint add_not_present(set seen, object item, hashfunc): @@ -39,8 +40,7 @@ cdef class OrderedSet(set): else: self._list = [] - @cython.final - cdef OrderedSet _copy(self): + cpdef OrderedSet copy(self): cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) cp._list = list(self._list) set.update(cp, cp._list) @@ -63,6 +63,14 @@ cdef class OrderedSet(set): set.remove(self, element) self._list.remove(element) + def pop(self): + try: + value = self._list.pop() + except IndexError: + raise KeyError("pop from an empty set") from None + set.remove(self, value) + return value + def insert(self, Py_ssize_t pos, element): if element not in self: self._list.insert(pos, element) @@ -91,34 +99,25 @@ cdef class OrderedSet(set): __str__ = __repr__ - cpdef OrderedSet update(self, iterable): - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - return self + def update(self, *iterables): + for iterable in iterables: + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) def __ior__(self, iterable): - return self.update(iterable) + self.update(iterable) + return self def union(self, *other): - result = self._copy() - for o in other: - result.update(o) + result = self.copy() + result.update(*other) return result def __or__(self, other): return self.union(other) - @cython.final - cdef set _to_set(self, other): - cdef set other_set - if isinstance(other, set): - other_set = <set> other - else: - other_set = set(other) - return other_set - def intersection(self, *other): cdef set other_set = set.intersection(self, *other) return self._from_list([a for a in self._list if a in other_set]) @@ -127,10 +126,18 @@ cdef class OrderedSet(set): return self.intersection(other) def symmetric_difference(self, other): - cdef set other_set = self._to_set(other) + cdef set other_set + if isinstance(other, set): + other_set = <set> other + collection = other_set + elif isinstance(other, Collection): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) result = self._from_list([a for a in self._list if a not in other_set]) - # use other here to keep the order - result.update(a for a in other if a not in self) + result.update(a for a in collection if a not in self) return result def __xor__(self, other): @@ -152,9 +159,10 @@ cdef class OrderedSet(set): return self cpdef symmetric_difference_update(self, other): - set.symmetric_difference_update(self, other) + collection = other if isinstance(other, Collection) else list(other) + set.symmetric_difference_update(self, collection) self._list = [a for a in self._list if a in self] - self._list += [a for a in other if a in self] + self._list += [a for a in collection if a in self] def __ixor__(self, other): self.symmetric_difference_update(other) @@ -169,13 +177,12 @@ cdef class OrderedSet(set): return self cdef object cy_id(object item): - return PyLong_FromLong(<long> (<void *>item)) + return PyLong_FromLongLong(<long long> (<void *>item)) # NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped # instead of the __rmeth__, so they need to check that also self is of the # correct type. This is fixed in cython 3.x. See: # https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods - cdef class IdentitySet: """A set that considers only object id() for uniqueness. diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 8810800c4..9962493b5 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -168,8 +168,11 @@ class OrderedSet(Set[_T]): else: self._list = [] - def __reduce__(self): - return (OrderedSet, (self._list,)) + def copy(self) -> OrderedSet[_T]: + cp = self.__class__() + cp._list = self._list.copy() + set.update(cp, cp._list) + return cp def add(self, element: _T) -> None: if element not in self: @@ -180,6 +183,14 @@ class OrderedSet(Set[_T]): super().remove(element) self._list.remove(element) + def pop(self) -> _T: + try: + value = self._list.pop() + except IndexError: + raise KeyError("pop from an empty set") from None + super().remove(value) + return value + def insert(self, pos: int, element: _T) -> None: if element not in self: self._list.insert(pos, element) @@ -220,9 +231,8 @@ class OrderedSet(Set[_T]): return self # type: ignore def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]: - result: OrderedSet[Union[_T, _S]] = self.__class__(self) # type: ignore # noqa: E501 - for o in other: - result.update(o) + result: OrderedSet[Union[_T, _S]] = self.copy() # type: ignore + result.update(*other) return result def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: @@ -237,9 +247,17 @@ class OrderedSet(Set[_T]): return self.intersection(other) def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]: - other_set = other if isinstance(other, set) else set(other) + collection: Collection[_T] + if isinstance(other, set): + collection = other_set = other + elif isinstance(other, Collection): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) result = self.__class__(a for a in self if a not in other_set) - result.update(a for a in other if a not in self) + result.update(a for a in collection if a not in self) return result def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: @@ -263,9 +281,10 @@ class OrderedSet(Set[_T]): return self def symmetric_difference_update(self, other: Iterable[Any]) -> None: - super().symmetric_difference_update(other) + collection = other if isinstance(other, Collection) else list(other) + super().symmetric_difference_update(collection) self._list = [a for a in self._list if a in self] - self._list += [a for a in other if a in self] + self._list += [a for a in collection if a in self] def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: self.symmetric_difference_update(other) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 01877f776..d77e1b0ae 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -19,9 +19,12 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_instance_of +from sqlalchemy.testing import is_none from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ +from sqlalchemy.testing import not_in from sqlalchemy.testing.util import gc_collect from sqlalchemy.testing.util import picklers from sqlalchemy.util import classproperty @@ -209,6 +212,27 @@ class OrderedSetTest(fixtures.TestBase): eq_(o.difference(iter([3, 4])), util.OrderedSet([2, 5])) eq_(o.intersection(iter([3, 4, 6])), util.OrderedSet([3, 4])) eq_(o.union(iter([3, 4, 6])), util.OrderedSet([3, 2, 4, 5, 6])) + eq_( + o.symmetric_difference(iter([3, 4, 6])), util.OrderedSet([2, 5, 6]) + ) + + def test_mutators_against_iter_update(self): + # testing a set modified against an iterator + o = util.OrderedSet([3, 2, 4, 5]) + o.difference_update(iter([3, 4])) + eq_(list(o), [2, 5]) + + o = util.OrderedSet([3, 2, 4, 5]) + o.intersection_update(iter([3, 4])) + eq_(list(o), [3, 4]) + + o = util.OrderedSet([3, 2, 4, 5]) + o.update(iter([3, 4, 6])) + eq_(list(o), [3, 2, 4, 5, 6]) + + o = util.OrderedSet([3, 2, 4, 5]) + o.symmetric_difference_update(iter([3, 4, 6])) + eq_(list(o), [2, 5, 6]) def test_len(self): eq_(len(util.OrderedSet([1, 2, 3])), 3) @@ -229,6 +253,110 @@ class OrderedSetTest(fixtures.TestBase): o = util.OrderedSet([3, 2, 4, 5]) eq_(str(o), "OrderedSet([3, 2, 4, 5])") + def test_modify(self): + o = util.OrderedSet([3, 9, 11]) + is_none(o.add(42)) + in_(42, o) + in_(3, o) + + is_none(o.remove(9)) + not_in(9, o) + in_(3, o) + + is_none(o.discard(11)) + in_(3, o) + + o.add(99) + is_none(o.insert(1, 13)) + eq_(list(o), [3, 13, 42, 99]) + eq_(o[2], 42) + + val = o.pop() + eq_(val, 99) + not_in(99, o) + eq_(list(o), [3, 13, 42]) + + is_none(o.clear()) + not_in(3, o) + is_false(bool(o)) + + def test_empty_pop(self): + with expect_raises_message(KeyError, "pop from an empty set"): + util.OrderedSet().pop() + + @testing.combinations( + lambda o: o + util.OrderedSet([11, 22]), + lambda o: o | util.OrderedSet([11, 22]), + lambda o: o.union(util.OrderedSet([11, 22])), + lambda o: o.union([11, 2], [22, 8]), + ) + def test_op(self, fn): + o = util.OrderedSet(range(10)) + x = fn(o) + is_instance_of(x, util.OrderedSet) + in_(9, x) + in_(11, x) + not_in(11, o) + + def test_update(self): + o = util.OrderedSet(range(10)) + is_none(o.update([22, 2], [33, 11])) + in_(11, o) + in_(22, o) + + def test_set_ops(self): + o1, o2 = util.OrderedSet([1, 3, 5, 7]), {2, 3, 4, 5} + eq_(o1 & o2, {3, 5}) + eq_(o1.intersection(o2), {3, 5}) + o3 = o1.copy() + o3 &= o2 + eq_(o3, {3, 5}) + o3 = o1.copy() + is_none(o3.intersection_update(o2)) + eq_(o3, {3, 5}) + + eq_(o1 | o2, {1, 2, 3, 4, 5, 7}) + eq_(o1.union(o2), {1, 2, 3, 4, 5, 7}) + o3 = o1.copy() + o3 |= o2 + eq_(o3, {1, 2, 3, 4, 5, 7}) + o3 = o1.copy() + is_none(o3.update(o2)) + eq_(o3, {1, 2, 3, 4, 5, 7}) + + eq_(o1 - o2, {1, 7}) + eq_(o1.difference(o2), {1, 7}) + o3 = o1.copy() + o3 -= o2 + eq_(o3, {1, 7}) + o3 = o1.copy() + is_none(o3.difference_update(o2)) + eq_(o3, {1, 7}) + + eq_(o1 ^ o2, {1, 2, 4, 7}) + eq_(o1.symmetric_difference(o2), {1, 2, 4, 7}) + o3 = o1.copy() + o3 ^= o2 + eq_(o3, {1, 2, 4, 7}) + o3 = o1.copy() + is_none(o3.symmetric_difference_update(o2)) + eq_(o3, {1, 2, 4, 7}) + + def test_copy(self): + o = util.OrderedSet([3, 2, 4, 5]) + cp = o.copy() + is_instance_of(cp, util.OrderedSet) + eq_(o, cp) + o.add(42) + is_false(42 in cp) + + def test_pickle(self): + o = util.OrderedSet([2, 4, 9, 42]) + for loads, dumps in picklers(): + l = loads(dumps(o)) + is_instance_of(l, util.OrderedSet) + eq_(list(l), [2, 4, 9, 42]) + class ImmutableDictTest(fixtures.TestBase): def test_union_no_change(self): |