summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-03-31 13:51:47 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-03-31 13:51:47 +0000
commitcfd31fa67d622c9b1e9811dde2f1f87173ed648e (patch)
tree66413373adf63f39d85a375277d5c1e36265bae1
parent5f73d2f6332dfaadb1b816f8dce539433014c654 (diff)
parenta979b6dc5ebefedfd8c85f5695cc5be8882eaa29 (diff)
downloadsqlalchemy-cfd31fa67d622c9b1e9811dde2f1f87173ed648e.tar.gz
Merge "Add missing methods to OrderedSet." into main
-rw-r--r--doc/build/changelog/unreleased_20/9487.rst6
-rw-r--r--lib/sqlalchemy/cyextension/collections.pyx65
-rw-r--r--lib/sqlalchemy/util/_py_collections.py37
-rw-r--r--test/base/test_utils.py128
4 files changed, 198 insertions, 38 deletions
diff --git a/doc/build/changelog/unreleased_20/9487.rst b/doc/build/changelog/unreleased_20/9487.rst
new file mode 100644
index 000000000..627be0e61
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9487.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: bug, util
+ :tickets: 9487
+
+ Implemented missing methods ``copy`` and ``pop`` in
+ OrderedSet class.
diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx
index e6667dddd..d08fa3aab 100644
--- a/lib/sqlalchemy/cyextension/collections.pyx
+++ b/lib/sqlalchemy/cyextension/collections.pyx
@@ -1,8 +1,9 @@
cimport cython
from cpython.dict cimport PyDict_Merge, PyDict_Update
-from cpython.long cimport PyLong_FromLong
+from cpython.long cimport PyLong_FromLongLong
from cpython.set cimport PySet_Add
+from collections.abc import Collection
from itertools import filterfalse
cdef bint add_not_present(set seen, object item, hashfunc):
@@ -39,8 +40,7 @@ cdef class OrderedSet(set):
else:
self._list = []
- @cython.final
- cdef OrderedSet _copy(self):
+ cpdef OrderedSet copy(self):
cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
cp._list = list(self._list)
set.update(cp, cp._list)
@@ -63,6 +63,14 @@ cdef class OrderedSet(set):
set.remove(self, element)
self._list.remove(element)
+ def pop(self):
+ try:
+ value = self._list.pop()
+ except IndexError:
+ raise KeyError("pop from an empty set") from None
+ set.remove(self, value)
+ return value
+
def insert(self, Py_ssize_t pos, element):
if element not in self:
self._list.insert(pos, element)
@@ -91,34 +99,25 @@ cdef class OrderedSet(set):
__str__ = __repr__
- cpdef OrderedSet update(self, iterable):
- for e in iterable:
- if e not in self:
- self._list.append(e)
- set.add(self, e)
- return self
+ def update(self, *iterables):
+ for iterable in iterables:
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ set.add(self, e)
def __ior__(self, iterable):
- return self.update(iterable)
+ self.update(iterable)
+ return self
def union(self, *other):
- result = self._copy()
- for o in other:
- result.update(o)
+ result = self.copy()
+ result.update(*other)
return result
def __or__(self, other):
return self.union(other)
- @cython.final
- cdef set _to_set(self, other):
- cdef set other_set
- if isinstance(other, set):
- other_set = <set> other
- else:
- other_set = set(other)
- return other_set
-
def intersection(self, *other):
cdef set other_set = set.intersection(self, *other)
return self._from_list([a for a in self._list if a in other_set])
@@ -127,10 +126,18 @@ cdef class OrderedSet(set):
return self.intersection(other)
def symmetric_difference(self, other):
- cdef set other_set = self._to_set(other)
+ cdef set other_set
+ if isinstance(other, set):
+ other_set = <set> other
+ collection = other_set
+ elif isinstance(other, Collection):
+ collection = other
+ other_set = set(other)
+ else:
+ collection = list(other)
+ other_set = set(collection)
result = self._from_list([a for a in self._list if a not in other_set])
- # use other here to keep the order
- result.update(a for a in other if a not in self)
+ result.update(a for a in collection if a not in self)
return result
def __xor__(self, other):
@@ -152,9 +159,10 @@ cdef class OrderedSet(set):
return self
cpdef symmetric_difference_update(self, other):
- set.symmetric_difference_update(self, other)
+ collection = other if isinstance(other, Collection) else list(other)
+ set.symmetric_difference_update(self, collection)
self._list = [a for a in self._list if a in self]
- self._list += [a for a in other if a in self]
+ self._list += [a for a in collection if a in self]
def __ixor__(self, other):
self.symmetric_difference_update(other)
@@ -169,13 +177,12 @@ cdef class OrderedSet(set):
return self
cdef object cy_id(object item):
- return PyLong_FromLong(<long> (<void *>item))
+ return PyLong_FromLongLong(<long long> (<void *>item))
# NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped
# instead of the __rmeth__, so they need to check that also self is of the
# correct type. This is fixed in cython 3.x. See:
# https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods
-
cdef class IdentitySet:
"""A set that considers only object id() for uniqueness.
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index 8810800c4..9962493b5 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -168,8 +168,11 @@ class OrderedSet(Set[_T]):
else:
self._list = []
- def __reduce__(self):
- return (OrderedSet, (self._list,))
+ def copy(self) -> OrderedSet[_T]:
+ cp = self.__class__()
+ cp._list = self._list.copy()
+ set.update(cp, cp._list)
+ return cp
def add(self, element: _T) -> None:
if element not in self:
@@ -180,6 +183,14 @@ class OrderedSet(Set[_T]):
super().remove(element)
self._list.remove(element)
+ def pop(self) -> _T:
+ try:
+ value = self._list.pop()
+ except IndexError:
+ raise KeyError("pop from an empty set") from None
+ super().remove(value)
+ return value
+
def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
@@ -220,9 +231,8 @@ class OrderedSet(Set[_T]):
return self # type: ignore
def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
- result: OrderedSet[Union[_T, _S]] = self.__class__(self) # type: ignore # noqa: E501
- for o in other:
- result.update(o)
+ result: OrderedSet[Union[_T, _S]] = self.copy() # type: ignore
+ result.update(*other)
return result
def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
@@ -237,9 +247,17 @@ class OrderedSet(Set[_T]):
return self.intersection(other)
def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]:
- other_set = other if isinstance(other, set) else set(other)
+ collection: Collection[_T]
+ if isinstance(other, set):
+ collection = other_set = other
+ elif isinstance(other, Collection):
+ collection = other
+ other_set = set(other)
+ else:
+ collection = list(other)
+ other_set = set(collection)
result = self.__class__(a for a in self if a not in other_set)
- result.update(a for a in other if a not in self)
+ result.update(a for a in collection if a not in self)
return result
def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
@@ -263,9 +281,10 @@ class OrderedSet(Set[_T]):
return self
def symmetric_difference_update(self, other: Iterable[Any]) -> None:
- super().symmetric_difference_update(other)
+ collection = other if isinstance(other, Collection) else list(other)
+ super().symmetric_difference_update(collection)
self._list = [a for a in self._list if a in self]
- self._list += [a for a in other if a in self]
+ self._list += [a for a in collection if a in self]
def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
self.symmetric_difference_update(other)
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index 01877f776..d77e1b0ae 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -19,9 +19,12 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.testing import in_
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_instance_of
+from sqlalchemy.testing import is_none
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing import ne_
+from sqlalchemy.testing import not_in
from sqlalchemy.testing.util import gc_collect
from sqlalchemy.testing.util import picklers
from sqlalchemy.util import classproperty
@@ -209,6 +212,27 @@ class OrderedSetTest(fixtures.TestBase):
eq_(o.difference(iter([3, 4])), util.OrderedSet([2, 5]))
eq_(o.intersection(iter([3, 4, 6])), util.OrderedSet([3, 4]))
eq_(o.union(iter([3, 4, 6])), util.OrderedSet([3, 2, 4, 5, 6]))
+ eq_(
+ o.symmetric_difference(iter([3, 4, 6])), util.OrderedSet([2, 5, 6])
+ )
+
+ def test_mutators_against_iter_update(self):
+ # testing a set modified against an iterator
+ o = util.OrderedSet([3, 2, 4, 5])
+ o.difference_update(iter([3, 4]))
+ eq_(list(o), [2, 5])
+
+ o = util.OrderedSet([3, 2, 4, 5])
+ o.intersection_update(iter([3, 4]))
+ eq_(list(o), [3, 4])
+
+ o = util.OrderedSet([3, 2, 4, 5])
+ o.update(iter([3, 4, 6]))
+ eq_(list(o), [3, 2, 4, 5, 6])
+
+ o = util.OrderedSet([3, 2, 4, 5])
+ o.symmetric_difference_update(iter([3, 4, 6]))
+ eq_(list(o), [2, 5, 6])
def test_len(self):
eq_(len(util.OrderedSet([1, 2, 3])), 3)
@@ -229,6 +253,110 @@ class OrderedSetTest(fixtures.TestBase):
o = util.OrderedSet([3, 2, 4, 5])
eq_(str(o), "OrderedSet([3, 2, 4, 5])")
+ def test_modify(self):
+ o = util.OrderedSet([3, 9, 11])
+ is_none(o.add(42))
+ in_(42, o)
+ in_(3, o)
+
+ is_none(o.remove(9))
+ not_in(9, o)
+ in_(3, o)
+
+ is_none(o.discard(11))
+ in_(3, o)
+
+ o.add(99)
+ is_none(o.insert(1, 13))
+ eq_(list(o), [3, 13, 42, 99])
+ eq_(o[2], 42)
+
+ val = o.pop()
+ eq_(val, 99)
+ not_in(99, o)
+ eq_(list(o), [3, 13, 42])
+
+ is_none(o.clear())
+ not_in(3, o)
+ is_false(bool(o))
+
+ def test_empty_pop(self):
+ with expect_raises_message(KeyError, "pop from an empty set"):
+ util.OrderedSet().pop()
+
+ @testing.combinations(
+ lambda o: o + util.OrderedSet([11, 22]),
+ lambda o: o | util.OrderedSet([11, 22]),
+ lambda o: o.union(util.OrderedSet([11, 22])),
+ lambda o: o.union([11, 2], [22, 8]),
+ )
+ def test_op(self, fn):
+ o = util.OrderedSet(range(10))
+ x = fn(o)
+ is_instance_of(x, util.OrderedSet)
+ in_(9, x)
+ in_(11, x)
+ not_in(11, o)
+
+ def test_update(self):
+ o = util.OrderedSet(range(10))
+ is_none(o.update([22, 2], [33, 11]))
+ in_(11, o)
+ in_(22, o)
+
+ def test_set_ops(self):
+ o1, o2 = util.OrderedSet([1, 3, 5, 7]), {2, 3, 4, 5}
+ eq_(o1 & o2, {3, 5})
+ eq_(o1.intersection(o2), {3, 5})
+ o3 = o1.copy()
+ o3 &= o2
+ eq_(o3, {3, 5})
+ o3 = o1.copy()
+ is_none(o3.intersection_update(o2))
+ eq_(o3, {3, 5})
+
+ eq_(o1 | o2, {1, 2, 3, 4, 5, 7})
+ eq_(o1.union(o2), {1, 2, 3, 4, 5, 7})
+ o3 = o1.copy()
+ o3 |= o2
+ eq_(o3, {1, 2, 3, 4, 5, 7})
+ o3 = o1.copy()
+ is_none(o3.update(o2))
+ eq_(o3, {1, 2, 3, 4, 5, 7})
+
+ eq_(o1 - o2, {1, 7})
+ eq_(o1.difference(o2), {1, 7})
+ o3 = o1.copy()
+ o3 -= o2
+ eq_(o3, {1, 7})
+ o3 = o1.copy()
+ is_none(o3.difference_update(o2))
+ eq_(o3, {1, 7})
+
+ eq_(o1 ^ o2, {1, 2, 4, 7})
+ eq_(o1.symmetric_difference(o2), {1, 2, 4, 7})
+ o3 = o1.copy()
+ o3 ^= o2
+ eq_(o3, {1, 2, 4, 7})
+ o3 = o1.copy()
+ is_none(o3.symmetric_difference_update(o2))
+ eq_(o3, {1, 2, 4, 7})
+
+ def test_copy(self):
+ o = util.OrderedSet([3, 2, 4, 5])
+ cp = o.copy()
+ is_instance_of(cp, util.OrderedSet)
+ eq_(o, cp)
+ o.add(42)
+ is_false(42 in cp)
+
+ def test_pickle(self):
+ o = util.OrderedSet([2, 4, 9, 42])
+ for loads, dumps in picklers():
+ l = loads(dumps(o))
+ is_instance_of(l, util.OrderedSet)
+ eq_(list(l), [2, 4, 9, 42])
+
class ImmutableDictTest(fixtures.TestBase):
def test_union_no_change(self):