diff options
Diffstat (limited to 'Lib/test/test_heapq.py')
-rw-r--r-- | Lib/test/test_heapq.py | 94 |
1 files changed, 48 insertions, 46 deletions
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index efd1080328..5932a40c28 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -1,31 +1,31 @@ """Unittests for heapq.""" -import random -import unittest -from test import test_support import sys +import random -# We do a bit of trickery here to be able to test both the C implementation -# and the Python implementation of the module. +from test import test_support +from unittest import TestCase, skipUnless -# Make it impossible to import the C implementation anymore. -sys.modules['_heapq'] = 0 -# We must also handle the case that heapq was imported before. -if 'heapq' in sys.modules: - del sys.modules['heapq'] +py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq']) +c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq']) -# Now we can import the module and get the pure Python implementation. -import heapq as py_heapq +# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when +# _heapq is imported, so check them there +func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', + 'heapreplace', '_nlargest', '_nsmallest'] -# Restore everything to normal. -del sys.modules['_heapq'] -del sys.modules['heapq'] +class TestModules(TestCase): + def test_py_functions(self): + for fname in func_names: + self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq') -# This is now the module with the C implementation. -import heapq as c_heapq + @skipUnless(c_heapq, 'requires _heapq') + def test_c_functions(self): + for fname in func_names: + self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq') -class TestHeap(unittest.TestCase): +class TestHeap(TestCase): module = None def test_push_pop(self): @@ -61,7 +61,7 @@ class TestHeap(unittest.TestCase): for pos, item in enumerate(heap): if pos: # pos 0 has no parent parentpos = (pos-1) >> 1 - self.assert_(heap[parentpos] <= item) + self.assertTrue(heap[parentpos] <= item) def test_heapify(self): for size in range(30): @@ -190,14 +190,8 @@ class TestHeap(unittest.TestCase): self.assertEqual(self.module.nlargest(n, data, key=f), sorted(data, key=f, reverse=True)[:n]) -class TestHeapPython(TestHeap): - module = py_heapq - -class TestHeapC(TestHeap): - module = c_heapq - def test_comparison_operator(self): - # Issue 3501: Make sure heapq works with both __lt__ and __le__ + # Issue 3051: Make sure heapq works with both __lt__ and __le__ def hsort(data, comp): data = map(comp, data) self.module.heapify(data) @@ -218,6 +212,15 @@ class TestHeapC(TestHeap): self.assertEqual(hsort(data, LE), target) +class TestHeapPython(TestHeap): + module = py_heapq + + +@skipUnless(c_heapq, 'requires _heapq') +class TestHeapC(TestHeap): + module = c_heapq + + #============================================================================== class LenOnly: @@ -312,34 +315,25 @@ def L(seqn): 'Test multiple tiers of iterators' return chain(imap(lambda x:x, R(Ig(G(seqn))))) -class TestErrorHandling(unittest.TestCase): - # only for C implementation - module = c_heapq +class TestErrorHandling(TestCase): + module = None def test_non_sequence(self): for f in (self.module.heapify, self.module.heappop): - self.assertRaises(TypeError, f, 10) + self.assertRaises((TypeError, AttributeError), f, 10) for f in (self.module.heappush, self.module.heapreplace, self.module.nlargest, self.module.nsmallest): - self.assertRaises(TypeError, f, 10, 10) + self.assertRaises((TypeError, AttributeError), f, 10, 10) def test_len_only(self): for f in (self.module.heapify, self.module.heappop): - self.assertRaises(TypeError, f, LenOnly()) + self.assertRaises((TypeError, AttributeError), f, LenOnly()) for f in (self.module.heappush, self.module.heapreplace): - self.assertRaises(TypeError, f, LenOnly(), 10) + self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10) for f in (self.module.nlargest, self.module.nsmallest): self.assertRaises(TypeError, f, 2, LenOnly()) def test_get_only(self): - for f in (self.module.heapify, self.module.heappop): - self.assertRaises(TypeError, f, GetOnly()) - for f in (self.module.heappush, self.module.heapreplace): - self.assertRaises(TypeError, f, GetOnly(), 10) - for f in (self.module.nlargest, self.module.nsmallest): - self.assertRaises(TypeError, f, 2, GetOnly()) - - def test_get_only(self): seq = [CmpErr(), CmpErr(), CmpErr()] for f in (self.module.heapify, self.module.heappop): self.assertRaises(ZeroDivisionError, f, seq) @@ -352,13 +346,13 @@ class TestErrorHandling(unittest.TestCase): for f in (self.module.heapify, self.module.heappop, self.module.heappush, self.module.heapreplace, self.module.nlargest, self.module.nsmallest): - self.assertRaises(TypeError, f, 10) + self.assertRaises((TypeError, AttributeError), f, 10) def test_iterable_args(self): for f in (self.module.nlargest, self.module.nsmallest): for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): for g in (G, I, Ig, L, R): - with test_support._check_py3k_warnings( + with test_support.check_py3k_warnings( ("comparing unequal types not supported", DeprecationWarning), quiet=True): self.assertEqual(f(2, g(s)), f(2,s)) @@ -368,13 +362,21 @@ class TestErrorHandling(unittest.TestCase): self.assertRaises(ZeroDivisionError, f, 2, E(s)) +class TestErrorHandlingPython(TestErrorHandling): + module = py_heapq + + +@skipUnless(c_heapq, 'requires _heapq') +class TestErrorHandlingC(TestErrorHandling): + module = c_heapq + + #============================================================================== def test_main(verbose=None): - from types import BuiltinFunctionType - - test_classes = [TestHeapPython, TestHeapC, TestErrorHandling] + test_classes = [TestModules, TestHeapPython, TestHeapC, + TestErrorHandlingPython, TestErrorHandlingC] test_support.run_unittest(*test_classes) # verify reference counting |