diff options
Diffstat (limited to 'Lib/test/test_set.py')
-rw-r--r-- | Lib/test/test_set.py | 181 |
1 files changed, 102 insertions, 79 deletions
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 3539a14065..b8753a1206 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -6,7 +6,6 @@ import weakref import operator import copy import pickle -import os from random import randrange, shuffle import sys import collections @@ -66,7 +65,7 @@ class TestJointOps(unittest.TestCase): self.assertEqual(c in self.s, c in self.d) self.assertRaises(TypeError, self.s.__contains__, [[]]) s = self.thetype([frozenset(self.letters)]) - self.assert_(self.thetype(self.letters) in s) + self.assertIn(self.thetype(self.letters), s) def test_union(self): u = self.s.union(self.otherword) @@ -130,7 +129,7 @@ class TestJointOps(unittest.TestCase): actual = s1.isdisjoint(s2) expected = f(s1, s2) self.assertEqual(actual, expected) - self.assert_(actual is True or actual is False) + self.assertTrue(actual is True or actual is False) def test_and(self): i = self.s.intersection(self.otherword) @@ -213,19 +212,19 @@ class TestJointOps(unittest.TestCase): def test_sub_and_super(self): p, q, r = map(self.thetype, ['ab', 'abcde', 'def']) - self.assert_(p < q) - self.assert_(p <= q) - self.assert_(q <= q) - self.assert_(q > p) - self.assert_(q >= p) - self.failIf(q < r) - self.failIf(q <= r) - self.failIf(q > r) - self.failIf(q >= r) - self.assert_(set('a').issubset('abc')) - self.assert_(set('abc').issuperset('a')) - self.failIf(set('a').issubset('cbs')) - self.failIf(set('cbs').issuperset('a')) + self.assertTrue(p < q) + self.assertTrue(p <= q) + self.assertTrue(q <= q) + self.assertTrue(q > p) + self.assertTrue(q >= p) + self.assertFalse(q < r) + self.assertFalse(q <= r) + self.assertFalse(q > r) + self.assertFalse(q >= r) + self.assertTrue(set('a').issubset('abc')) + self.assertTrue(set('abc').issuperset('a')) + self.assertFalse(set('a').issubset('cbs')) + self.assertFalse(set('cbs').issuperset('a')) def test_pickling(self): for i in range(pickle.HIGHEST_PROTOCOL + 1): @@ -273,7 +272,7 @@ class TestJointOps(unittest.TestCase): s=H() f=set() f.add(s) - self.assert_(s in f) + self.assertIn(s, f) f.remove(s) f.add(s) f.discard(s) @@ -339,7 +338,7 @@ class TestJointOps(unittest.TestCase): obj.x = iter(container) del obj, container gc.collect() - self.assert_(ref() is None, "Cycle was not collected") + self.assertTrue(ref() is None, "Cycle was not collected") class TestSet(TestJointOps): thetype = set @@ -373,7 +372,7 @@ class TestSet(TestJointOps): def test_add(self): self.s.add('Q') - self.assert_('Q' in self.s) + self.assertIn('Q', self.s) dup = self.s.copy() self.s.add('Q') self.assertEqual(self.s, dup) @@ -381,13 +380,13 @@ class TestSet(TestJointOps): def test_remove(self): self.s.remove('a') - self.assert_('a' not in self.s) + self.assertNotIn('a', self.s) self.assertRaises(KeyError, self.s.remove, 'Q') self.assertRaises(TypeError, self.s.remove, []) s = self.thetype([frozenset(self.word)]) - self.assert_(self.thetype(self.word) in s) + self.assertIn(self.thetype(self.word), s) s.remove(self.thetype(self.word)) - self.assert_(self.thetype(self.word) not in s) + self.assertNotIn(self.thetype(self.word), s) self.assertRaises(KeyError, self.s.remove, self.thetype(self.word)) def test_remove_keyerror_unpacking(self): @@ -406,7 +405,7 @@ class TestSet(TestJointOps): try: self.s.remove(key) except KeyError as e: - self.assert_(e.args[0] is key, + self.assertTrue(e.args[0] is key, "KeyError should be {0}, not {1}".format(key, e.args[0])) else: @@ -414,26 +413,26 @@ class TestSet(TestJointOps): def test_discard(self): self.s.discard('a') - self.assert_('a' not in self.s) + self.assertNotIn('a', self.s) self.s.discard('Q') self.assertRaises(TypeError, self.s.discard, []) s = self.thetype([frozenset(self.word)]) - self.assert_(self.thetype(self.word) in s) + self.assertIn(self.thetype(self.word), s) s.discard(self.thetype(self.word)) - self.assert_(self.thetype(self.word) not in s) + self.assertNotIn(self.thetype(self.word), s) s.discard(self.thetype(self.word)) def test_pop(self): for i in xrange(len(self.s)): elem = self.s.pop() - self.assert_(elem not in self.s) + self.assertNotIn(elem, self.s) self.assertRaises(KeyError, self.s.pop) def test_update(self): retval = self.s.update(self.otherword) self.assertEqual(retval, None) for c in (self.word + self.otherword): - self.assert_(c in self.s) + self.assertIn(c, self.s) self.assertRaises(PassThru, self.s.update, check_pass_thru()) self.assertRaises(TypeError, self.s.update, [[]]) for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')): @@ -451,16 +450,16 @@ class TestSet(TestJointOps): def test_ior(self): self.s |= set(self.otherword) for c in (self.word + self.otherword): - self.assert_(c in self.s) + self.assertIn(c, self.s) def test_intersection_update(self): retval = self.s.intersection_update(self.otherword) self.assertEqual(retval, None) for c in (self.word + self.otherword): if c in self.otherword and c in self.word: - self.assert_(c in self.s) + self.assertIn(c, self.s) else: - self.assert_(c not in self.s) + self.assertNotIn(c, self.s) self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru()) self.assertRaises(TypeError, self.s.intersection_update, [[]]) for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')): @@ -478,18 +477,18 @@ class TestSet(TestJointOps): self.s &= set(self.otherword) for c in (self.word + self.otherword): if c in self.otherword and c in self.word: - self.assert_(c in self.s) + self.assertIn(c, self.s) else: - self.assert_(c not in self.s) + self.assertNotIn(c, self.s) def test_difference_update(self): retval = self.s.difference_update(self.otherword) self.assertEqual(retval, None) for c in (self.word + self.otherword): if c in self.word and c not in self.otherword: - self.assert_(c in self.s) + self.assertIn(c, self.s) else: - self.assert_(c not in self.s) + self.assertNotIn(c, self.s) self.assertRaises(PassThru, self.s.difference_update, check_pass_thru()) self.assertRaises(TypeError, self.s.difference_update, [[]]) self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) @@ -515,18 +514,18 @@ class TestSet(TestJointOps): self.s -= set(self.otherword) for c in (self.word + self.otherword): if c in self.word and c not in self.otherword: - self.assert_(c in self.s) + self.assertIn(c, self.s) else: - self.assert_(c not in self.s) + self.assertNotIn(c, self.s) def test_symmetric_difference_update(self): retval = self.s.symmetric_difference_update(self.otherword) self.assertEqual(retval, None) for c in (self.word + self.otherword): if (c in self.word) ^ (c in self.otherword): - self.assert_(c in self.s) + self.assertIn(c, self.s) else: - self.assert_(c not in self.s) + self.assertNotIn(c, self.s) self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru()) self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')): @@ -539,9 +538,9 @@ class TestSet(TestJointOps): self.s ^= set(self.otherword) for c in (self.word + self.otherword): if (c in self.word) ^ (c in self.otherword): - self.assert_(c in self.s) + self.assertIn(c, self.s) else: - self.assert_(c not in self.s) + self.assertNotIn(c, self.s) def test_inplace_on_self(self): t = self.s.copy() @@ -565,7 +564,7 @@ class TestSet(TestJointOps): # C API test only available in a debug build if hasattr(set, "test_c_api"): def test_c_api(self): - self.assertEqual(set('abc').test_c_api(), True) + self.assertEqual(set().test_c_api(), True) class SetSubclass(set): pass @@ -769,7 +768,7 @@ class TestBasicOps(unittest.TestCase): def test_iteration(self): for v in self.set: - self.assert_(v in self.values) + self.assertIn(v, self.values) setiter = iter(self.set) # note: __length_hint__ is an internal undocumented API, # don't rely on it in your own programs @@ -804,10 +803,10 @@ class TestBasicOpsSingleton(TestBasicOps): self.repr = "set([3])" def test_in(self): - self.failUnless(3 in self.set) + self.assertIn(3, self.set) def test_not_in(self): - self.failUnless(2 not in self.set) + self.assertNotIn(2, self.set) #------------------------------------------------------------------------------ @@ -821,10 +820,10 @@ class TestBasicOpsTuple(TestBasicOps): self.repr = "set([(0, 'zero')])" def test_in(self): - self.failUnless((0, "zero") in self.set) + self.assertIn((0, "zero"), self.set) def test_not_in(self): - self.failUnless(9 not in self.set) + self.assertNotIn(9, self.set) #------------------------------------------------------------------------------ @@ -1116,7 +1115,7 @@ class TestMutate(unittest.TestCase): popped[self.set.pop()] = None self.assertEqual(len(popped), len(self.values)) for v in self.values: - self.failUnless(v in popped) + self.assertIn(v, popped) def test_update_empty_tuple(self): self.set.update(()) @@ -1349,7 +1348,7 @@ class TestOnlySetsOperator(TestOnlySetsInBinaryOps): self.otherIsIterable = False def test_ge_gt_le_lt(self): - with test_support._check_py3k_warnings(): + with test_support.check_py3k_warnings(): super(TestOnlySetsOperator, self).test_ge_gt_le_lt() #------------------------------------------------------------------------------ @@ -1384,23 +1383,17 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps): class TestCopying(unittest.TestCase): def test_copy(self): - dup = self.set.copy() - with test_support.check_warnings(): - dup_list = sorted(dup) - set_list = sorted(self.set) - self.assertEqual(len(dup_list), len(set_list)) - for i in range(len(dup_list)): - self.failUnless(dup_list[i] is set_list[i]) + dup = list(self.set.copy()) + self.assertEqual(len(dup), len(self.set)) + for el in self.set: + self.assertIn(el, dup) + pos = dup.index(el) + self.assertIs(el, dup.pop(pos)) + self.assertFalse(dup) def test_deep_copy(self): dup = copy.deepcopy(self.set) - ##print type(dup), repr(dup) - with test_support.check_warnings(): - dup_list = sorted(dup) - set_list = sorted(self.set) - self.assertEqual(len(dup_list), len(set_list)) - for i in range(len(dup_list)): - self.assertEqual(dup_list[i], set_list[i]) + self.assertSetEqual(dup, self.set) #------------------------------------------------------------------------------ @@ -1441,13 +1434,13 @@ class TestIdentities(unittest.TestCase): def test_binopsVsSubsets(self): a, b = self.a, self.b - self.assert_(a - b < a) - self.assert_(b - a < b) - self.assert_(a & b < a) - self.assert_(a & b < b) - self.assert_(a | b > a) - self.assert_(a | b > b) - self.assert_(a ^ b < a | b) + self.assertTrue(a - b < a) + self.assertTrue(b - a < b) + self.assertTrue(a & b < a) + self.assertTrue(a & b < b) + self.assertTrue(a | b > a) + self.assertTrue(a | b > b) + self.assertTrue(a ^ b < a | b) def test_commutativity(self): a, b = self.a, self.b @@ -1559,9 +1552,8 @@ class TestVariousIteratorArgs(unittest.TestCase): def test_constructor(self): for cons in (set, frozenset): for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): - with test_support.check_warnings(): - for g in (G, I, Ig, S, L, R): - self.assertEqual(sorted(cons(g(s))), sorted(g(s))) + for g in (G, I, Ig, S, L, R): + self.assertSetEqual(cons(g(s)), set(g(s))) self.assertRaises(TypeError, cons , X(s)) self.assertRaises(TypeError, cons , N(s)) self.assertRaises(ZeroDivisionError, cons , E(s)) @@ -1576,8 +1568,7 @@ class TestVariousIteratorArgs(unittest.TestCase): if isinstance(expected, bool): self.assertEqual(actual, expected) else: - with test_support.check_warnings(): - self.assertEqual(sorted(actual), sorted(expected)) + self.assertSetEqual(actual, expected) self.assertRaises(TypeError, meth, X(s)) self.assertRaises(TypeError, meth, N(s)) self.assertRaises(ZeroDivisionError, meth, E(s)) @@ -1591,13 +1582,45 @@ class TestVariousIteratorArgs(unittest.TestCase): t = s.copy() getattr(s, methname)(list(g(data))) getattr(t, methname)(g(data)) - with test_support.check_warnings(): - self.assertEqual(sorted(s), sorted(t)) + self.assertSetEqual(s, t) self.assertRaises(TypeError, getattr(set('january'), methname), X(data)) self.assertRaises(TypeError, getattr(set('january'), methname), N(data)) self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data)) +class bad_eq: + def __eq__(self, other): + if be_bad: + set2.clear() + raise ZeroDivisionError + return self is other + def __hash__(self): + return 0 + +class bad_dict_clear: + def __eq__(self, other): + if be_bad: + dict2.clear() + return self is other + def __hash__(self): + return 0 + +class TestWeirdBugs(unittest.TestCase): + def test_8420_set_merge(self): + # This used to segfault + global be_bad, set2, dict2 + be_bad = False + set1 = {bad_eq()} + set2 = {bad_eq() for i in range(75)} + be_bad = True + self.assertRaises(ZeroDivisionError, set1.update, set2) + + be_bad = False + set1 = {bad_dict_clear()} + dict2 = {bad_dict_clear(): None} + be_bad = True + set1.symmetric_difference_update(dict2) + # Application tests (based on David Eppstein's graph recipes ==================================== def powerset(U): @@ -1699,13 +1722,12 @@ class TestGraphs(unittest.TestCase): edge = vertex # Cuboctahedron vertices are edges in Cube self.assertEqual(len(edge), 2) # Two cube vertices define an edge for cubevert in edge: - self.assert_(cubevert in g) + self.assertIn(cubevert, g) #============================================================================== def test_main(verbose=None): - from test import test_sets test_classes = ( TestSet, TestSetSubclass, @@ -1740,6 +1762,7 @@ def test_main(verbose=None): TestIdentities, TestVariousIteratorArgs, TestGraphs, + TestWeirdBugs, ) test_support.run_unittest(*test_classes) |