summaryrefslogtreecommitdiff
path: root/Lib/test/test_set.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_set.py')
-rw-r--r--Lib/test/test_set.py181
1 files changed, 102 insertions, 79 deletions
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 3539a14065..b8753a1206 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -6,7 +6,6 @@ import weakref
import operator
import copy
import pickle
-import os
from random import randrange, shuffle
import sys
import collections
@@ -66,7 +65,7 @@ class TestJointOps(unittest.TestCase):
self.assertEqual(c in self.s, c in self.d)
self.assertRaises(TypeError, self.s.__contains__, [[]])
s = self.thetype([frozenset(self.letters)])
- self.assert_(self.thetype(self.letters) in s)
+ self.assertIn(self.thetype(self.letters), s)
def test_union(self):
u = self.s.union(self.otherword)
@@ -130,7 +129,7 @@ class TestJointOps(unittest.TestCase):
actual = s1.isdisjoint(s2)
expected = f(s1, s2)
self.assertEqual(actual, expected)
- self.assert_(actual is True or actual is False)
+ self.assertTrue(actual is True or actual is False)
def test_and(self):
i = self.s.intersection(self.otherword)
@@ -213,19 +212,19 @@ class TestJointOps(unittest.TestCase):
def test_sub_and_super(self):
p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
- self.assert_(p < q)
- self.assert_(p <= q)
- self.assert_(q <= q)
- self.assert_(q > p)
- self.assert_(q >= p)
- self.failIf(q < r)
- self.failIf(q <= r)
- self.failIf(q > r)
- self.failIf(q >= r)
- self.assert_(set('a').issubset('abc'))
- self.assert_(set('abc').issuperset('a'))
- self.failIf(set('a').issubset('cbs'))
- self.failIf(set('cbs').issuperset('a'))
+ self.assertTrue(p < q)
+ self.assertTrue(p <= q)
+ self.assertTrue(q <= q)
+ self.assertTrue(q > p)
+ self.assertTrue(q >= p)
+ self.assertFalse(q < r)
+ self.assertFalse(q <= r)
+ self.assertFalse(q > r)
+ self.assertFalse(q >= r)
+ self.assertTrue(set('a').issubset('abc'))
+ self.assertTrue(set('abc').issuperset('a'))
+ self.assertFalse(set('a').issubset('cbs'))
+ self.assertFalse(set('cbs').issuperset('a'))
def test_pickling(self):
for i in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -273,7 +272,7 @@ class TestJointOps(unittest.TestCase):
s=H()
f=set()
f.add(s)
- self.assert_(s in f)
+ self.assertIn(s, f)
f.remove(s)
f.add(s)
f.discard(s)
@@ -339,7 +338,7 @@ class TestJointOps(unittest.TestCase):
obj.x = iter(container)
del obj, container
gc.collect()
- self.assert_(ref() is None, "Cycle was not collected")
+ self.assertTrue(ref() is None, "Cycle was not collected")
class TestSet(TestJointOps):
thetype = set
@@ -373,7 +372,7 @@ class TestSet(TestJointOps):
def test_add(self):
self.s.add('Q')
- self.assert_('Q' in self.s)
+ self.assertIn('Q', self.s)
dup = self.s.copy()
self.s.add('Q')
self.assertEqual(self.s, dup)
@@ -381,13 +380,13 @@ class TestSet(TestJointOps):
def test_remove(self):
self.s.remove('a')
- self.assert_('a' not in self.s)
+ self.assertNotIn('a', self.s)
self.assertRaises(KeyError, self.s.remove, 'Q')
self.assertRaises(TypeError, self.s.remove, [])
s = self.thetype([frozenset(self.word)])
- self.assert_(self.thetype(self.word) in s)
+ self.assertIn(self.thetype(self.word), s)
s.remove(self.thetype(self.word))
- self.assert_(self.thetype(self.word) not in s)
+ self.assertNotIn(self.thetype(self.word), s)
self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))
def test_remove_keyerror_unpacking(self):
@@ -406,7 +405,7 @@ class TestSet(TestJointOps):
try:
self.s.remove(key)
except KeyError as e:
- self.assert_(e.args[0] is key,
+ self.assertTrue(e.args[0] is key,
"KeyError should be {0}, not {1}".format(key,
e.args[0]))
else:
@@ -414,26 +413,26 @@ class TestSet(TestJointOps):
def test_discard(self):
self.s.discard('a')
- self.assert_('a' not in self.s)
+ self.assertNotIn('a', self.s)
self.s.discard('Q')
self.assertRaises(TypeError, self.s.discard, [])
s = self.thetype([frozenset(self.word)])
- self.assert_(self.thetype(self.word) in s)
+ self.assertIn(self.thetype(self.word), s)
s.discard(self.thetype(self.word))
- self.assert_(self.thetype(self.word) not in s)
+ self.assertNotIn(self.thetype(self.word), s)
s.discard(self.thetype(self.word))
def test_pop(self):
for i in xrange(len(self.s)):
elem = self.s.pop()
- self.assert_(elem not in self.s)
+ self.assertNotIn(elem, self.s)
self.assertRaises(KeyError, self.s.pop)
def test_update(self):
retval = self.s.update(self.otherword)
self.assertEqual(retval, None)
for c in (self.word + self.otherword):
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
self.assertRaises(PassThru, self.s.update, check_pass_thru())
self.assertRaises(TypeError, self.s.update, [[]])
for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
@@ -451,16 +450,16 @@ class TestSet(TestJointOps):
def test_ior(self):
self.s |= set(self.otherword)
for c in (self.word + self.otherword):
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
def test_intersection_update(self):
retval = self.s.intersection_update(self.otherword)
self.assertEqual(retval, None)
for c in (self.word + self.otherword):
if c in self.otherword and c in self.word:
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
else:
- self.assert_(c not in self.s)
+ self.assertNotIn(c, self.s)
self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
self.assertRaises(TypeError, self.s.intersection_update, [[]])
for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
@@ -478,18 +477,18 @@ class TestSet(TestJointOps):
self.s &= set(self.otherword)
for c in (self.word + self.otherword):
if c in self.otherword and c in self.word:
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
else:
- self.assert_(c not in self.s)
+ self.assertNotIn(c, self.s)
def test_difference_update(self):
retval = self.s.difference_update(self.otherword)
self.assertEqual(retval, None)
for c in (self.word + self.otherword):
if c in self.word and c not in self.otherword:
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
else:
- self.assert_(c not in self.s)
+ self.assertNotIn(c, self.s)
self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
self.assertRaises(TypeError, self.s.difference_update, [[]])
self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
@@ -515,18 +514,18 @@ class TestSet(TestJointOps):
self.s -= set(self.otherword)
for c in (self.word + self.otherword):
if c in self.word and c not in self.otherword:
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
else:
- self.assert_(c not in self.s)
+ self.assertNotIn(c, self.s)
def test_symmetric_difference_update(self):
retval = self.s.symmetric_difference_update(self.otherword)
self.assertEqual(retval, None)
for c in (self.word + self.otherword):
if (c in self.word) ^ (c in self.otherword):
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
else:
- self.assert_(c not in self.s)
+ self.assertNotIn(c, self.s)
self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
@@ -539,9 +538,9 @@ class TestSet(TestJointOps):
self.s ^= set(self.otherword)
for c in (self.word + self.otherword):
if (c in self.word) ^ (c in self.otherword):
- self.assert_(c in self.s)
+ self.assertIn(c, self.s)
else:
- self.assert_(c not in self.s)
+ self.assertNotIn(c, self.s)
def test_inplace_on_self(self):
t = self.s.copy()
@@ -565,7 +564,7 @@ class TestSet(TestJointOps):
# C API test only available in a debug build
if hasattr(set, "test_c_api"):
def test_c_api(self):
- self.assertEqual(set('abc').test_c_api(), True)
+ self.assertEqual(set().test_c_api(), True)
class SetSubclass(set):
pass
@@ -769,7 +768,7 @@ class TestBasicOps(unittest.TestCase):
def test_iteration(self):
for v in self.set:
- self.assert_(v in self.values)
+ self.assertIn(v, self.values)
setiter = iter(self.set)
# note: __length_hint__ is an internal undocumented API,
# don't rely on it in your own programs
@@ -804,10 +803,10 @@ class TestBasicOpsSingleton(TestBasicOps):
self.repr = "set([3])"
def test_in(self):
- self.failUnless(3 in self.set)
+ self.assertIn(3, self.set)
def test_not_in(self):
- self.failUnless(2 not in self.set)
+ self.assertNotIn(2, self.set)
#------------------------------------------------------------------------------
@@ -821,10 +820,10 @@ class TestBasicOpsTuple(TestBasicOps):
self.repr = "set([(0, 'zero')])"
def test_in(self):
- self.failUnless((0, "zero") in self.set)
+ self.assertIn((0, "zero"), self.set)
def test_not_in(self):
- self.failUnless(9 not in self.set)
+ self.assertNotIn(9, self.set)
#------------------------------------------------------------------------------
@@ -1116,7 +1115,7 @@ class TestMutate(unittest.TestCase):
popped[self.set.pop()] = None
self.assertEqual(len(popped), len(self.values))
for v in self.values:
- self.failUnless(v in popped)
+ self.assertIn(v, popped)
def test_update_empty_tuple(self):
self.set.update(())
@@ -1349,7 +1348,7 @@ class TestOnlySetsOperator(TestOnlySetsInBinaryOps):
self.otherIsIterable = False
def test_ge_gt_le_lt(self):
- with test_support._check_py3k_warnings():
+ with test_support.check_py3k_warnings():
super(TestOnlySetsOperator, self).test_ge_gt_le_lt()
#------------------------------------------------------------------------------
@@ -1384,23 +1383,17 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps):
class TestCopying(unittest.TestCase):
def test_copy(self):
- dup = self.set.copy()
- with test_support.check_warnings():
- dup_list = sorted(dup)
- set_list = sorted(self.set)
- self.assertEqual(len(dup_list), len(set_list))
- for i in range(len(dup_list)):
- self.failUnless(dup_list[i] is set_list[i])
+ dup = list(self.set.copy())
+ self.assertEqual(len(dup), len(self.set))
+ for el in self.set:
+ self.assertIn(el, dup)
+ pos = dup.index(el)
+ self.assertIs(el, dup.pop(pos))
+ self.assertFalse(dup)
def test_deep_copy(self):
dup = copy.deepcopy(self.set)
- ##print type(dup), repr(dup)
- with test_support.check_warnings():
- dup_list = sorted(dup)
- set_list = sorted(self.set)
- self.assertEqual(len(dup_list), len(set_list))
- for i in range(len(dup_list)):
- self.assertEqual(dup_list[i], set_list[i])
+ self.assertSetEqual(dup, self.set)
#------------------------------------------------------------------------------
@@ -1441,13 +1434,13 @@ class TestIdentities(unittest.TestCase):
def test_binopsVsSubsets(self):
a, b = self.a, self.b
- self.assert_(a - b < a)
- self.assert_(b - a < b)
- self.assert_(a & b < a)
- self.assert_(a & b < b)
- self.assert_(a | b > a)
- self.assert_(a | b > b)
- self.assert_(a ^ b < a | b)
+ self.assertTrue(a - b < a)
+ self.assertTrue(b - a < b)
+ self.assertTrue(a & b < a)
+ self.assertTrue(a & b < b)
+ self.assertTrue(a | b > a)
+ self.assertTrue(a | b > b)
+ self.assertTrue(a ^ b < a | b)
def test_commutativity(self):
a, b = self.a, self.b
@@ -1559,9 +1552,8 @@ class TestVariousIteratorArgs(unittest.TestCase):
def test_constructor(self):
for cons in (set, frozenset):
for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
- with test_support.check_warnings():
- for g in (G, I, Ig, S, L, R):
- self.assertEqual(sorted(cons(g(s))), sorted(g(s)))
+ for g in (G, I, Ig, S, L, R):
+ self.assertSetEqual(cons(g(s)), set(g(s)))
self.assertRaises(TypeError, cons , X(s))
self.assertRaises(TypeError, cons , N(s))
self.assertRaises(ZeroDivisionError, cons , E(s))
@@ -1576,8 +1568,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
if isinstance(expected, bool):
self.assertEqual(actual, expected)
else:
- with test_support.check_warnings():
- self.assertEqual(sorted(actual), sorted(expected))
+ self.assertSetEqual(actual, expected)
self.assertRaises(TypeError, meth, X(s))
self.assertRaises(TypeError, meth, N(s))
self.assertRaises(ZeroDivisionError, meth, E(s))
@@ -1591,13 +1582,45 @@ class TestVariousIteratorArgs(unittest.TestCase):
t = s.copy()
getattr(s, methname)(list(g(data)))
getattr(t, methname)(g(data))
- with test_support.check_warnings():
- self.assertEqual(sorted(s), sorted(t))
+ self.assertSetEqual(s, t)
self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
+class bad_eq:
+ def __eq__(self, other):
+ if be_bad:
+ set2.clear()
+ raise ZeroDivisionError
+ return self is other
+ def __hash__(self):
+ return 0
+
+class bad_dict_clear:
+ def __eq__(self, other):
+ if be_bad:
+ dict2.clear()
+ return self is other
+ def __hash__(self):
+ return 0
+
+class TestWeirdBugs(unittest.TestCase):
+ def test_8420_set_merge(self):
+ # This used to segfault
+ global be_bad, set2, dict2
+ be_bad = False
+ set1 = {bad_eq()}
+ set2 = {bad_eq() for i in range(75)}
+ be_bad = True
+ self.assertRaises(ZeroDivisionError, set1.update, set2)
+
+ be_bad = False
+ set1 = {bad_dict_clear()}
+ dict2 = {bad_dict_clear(): None}
+ be_bad = True
+ set1.symmetric_difference_update(dict2)
+
# Application tests (based on David Eppstein's graph recipes ====================================
def powerset(U):
@@ -1699,13 +1722,12 @@ class TestGraphs(unittest.TestCase):
edge = vertex # Cuboctahedron vertices are edges in Cube
self.assertEqual(len(edge), 2) # Two cube vertices define an edge
for cubevert in edge:
- self.assert_(cubevert in g)
+ self.assertIn(cubevert, g)
#==============================================================================
def test_main(verbose=None):
- from test import test_sets
test_classes = (
TestSet,
TestSetSubclass,
@@ -1740,6 +1762,7 @@ def test_main(verbose=None):
TestIdentities,
TestVariousIteratorArgs,
TestGraphs,
+ TestWeirdBugs,
)
test_support.run_unittest(*test_classes)