summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2013-12-06 23:23:15 +0200
committerSerhiy Storchaka <storchaka@gmail.com>2013-12-06 23:23:15 +0200
commit7c573857c7ac63395cc19b4785e0a91acbb83742 (patch)
tree07a97f21a663a1519f4e340dac0b46a631ac4134 /Lib
parentd919da9942290baf606fb03433d1e484d4f0f53d (diff)
downloadcpython-git-7c573857c7ac63395cc19b4785e0a91acbb83742.tar.gz
Issue #16373: Prevent infinite recursion for ABC Set class comparisons.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_abcoll.py4
-rw-r--r--Lib/test/test_collections.py29
2 files changed, 31 insertions, 2 deletions
diff --git a/Lib/_abcoll.py b/Lib/_abcoll.py
index 0438afda28..8b650a7763 100644
--- a/Lib/_abcoll.py
+++ b/Lib/_abcoll.py
@@ -165,12 +165,12 @@ class Set(Sized, Iterable, Container):
def __gt__(self, other):
if not isinstance(other, Set):
return NotImplemented
- return other < self
+ return other.__lt__(self)
def __ge__(self, other):
if not isinstance(other, Set):
return NotImplemented
- return other <= self
+ return other.__le__(self)
def __eq__(self, other):
if not isinstance(other, Set):
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index 3aaecb2a6f..784a31880e 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -594,6 +594,35 @@ class TestCollectionABCs(ABCTestCase):
s |= s
self.assertEqual(s, full)
+ def test_issue16373(self):
+ # Recursion error comparing comparable and noncomparable
+ # Set instances
+ class MyComparableSet(Set):
+ def __contains__(self, x):
+ return False
+ def __len__(self):
+ return 0
+ def __iter__(self):
+ return iter([])
+ class MyNonComparableSet(Set):
+ def __contains__(self, x):
+ return False
+ def __len__(self):
+ return 0
+ def __iter__(self):
+ return iter([])
+ def __le__(self, x):
+ return NotImplemented
+ def __lt__(self, x):
+ return NotImplemented
+
+ cs = MyComparableSet()
+ ncs = MyNonComparableSet()
+ self.assertFalse(ncs < cs)
+ self.assertFalse(ncs <= cs)
+ self.assertFalse(cs > ncs)
+ self.assertFalse(cs >= ncs)
+
def test_Mapping(self):
for sample in [dict]:
self.assertIsInstance(sample(), Mapping)