summaryrefslogtreecommitdiff
path: root/Lib/test/test_contextlib.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_contextlib.py')
-rw-r--r--Lib/test/test_contextlib.py69
1 files changed, 59 insertions, 10 deletions
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index f8db88cc58..97470c78fb 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -2,12 +2,14 @@
from __future__ import with_statement
+import sys
import os
import decimal
import tempfile
import unittest
import threading
from contextlib import * # Tests __all__
+from test.test_support import run_suite
class ContextManagerTestCase(unittest.TestCase):
@@ -45,6 +47,28 @@ class ContextManagerTestCase(unittest.TestCase):
self.fail("Expected ZeroDivisionError")
self.assertEqual(state, [1, 42, 999])
+ def test_contextmanager_no_reraise(self):
+ @contextmanager
+ def whee():
+ yield
+ ctx = whee().__context__()
+ ctx.__enter__()
+ # Calling __exit__ should not result in an exception
+ self.failIf(ctx.__exit__(TypeError, TypeError("foo"), None))
+
+ def test_contextmanager_trap_yield_after_throw(self):
+ @contextmanager
+ def whoo():
+ try:
+ yield
+ except:
+ yield
+ ctx = whoo().__context__()
+ ctx.__enter__()
+ self.assertRaises(
+ RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
+ )
+
def test_contextmanager_except(self):
state = []
@contextmanager
@@ -62,6 +86,21 @@ class ContextManagerTestCase(unittest.TestCase):
raise ZeroDivisionError(999)
self.assertEqual(state, [1, 42, 999])
+ def test_contextmanager_attribs(self):
+ def attribs(**kw):
+ def decorate(func):
+ for k,v in kw.items():
+ setattr(func,k,v)
+ return func
+ return decorate
+ @contextmanager
+ @attribs(foo='bar')
+ def baz(spam):
+ """Whee!"""
+ self.assertEqual(baz.__name__,'baz')
+ self.assertEqual(baz.foo, 'bar')
+ self.assertEqual(baz.__doc__, "Whee!")
+
class NestedTestCase(unittest.TestCase):
# XXX This needs more work
@@ -274,21 +313,31 @@ class DecimalContextTestCase(unittest.TestCase):
def testBasic(self):
ctx = decimal.getcontext()
- ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
- with decimal.ExtendedContext:
- self.assertEqual(decimal.getcontext().prec,
- decimal.ExtendedContext.prec)
- self.assertEqual(decimal.getcontext().prec, save_prec)
+ orig_context = ctx.copy()
try:
+ ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
with decimal.ExtendedContext:
self.assertEqual(decimal.getcontext().prec,
decimal.ExtendedContext.prec)
- 1/0
- except ZeroDivisionError:
self.assertEqual(decimal.getcontext().prec, save_prec)
- else:
- self.fail("Didn't raise ZeroDivisionError")
+ try:
+ with decimal.ExtendedContext:
+ self.assertEqual(decimal.getcontext().prec,
+ decimal.ExtendedContext.prec)
+ 1/0
+ except ZeroDivisionError:
+ self.assertEqual(decimal.getcontext().prec, save_prec)
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+ finally:
+ decimal.setcontext(orig_context)
+
+# This is needed to make the test actually run under regrtest.py!
+def test_main():
+ run_suite(
+ unittest.defaultTestLoader.loadTestsFromModule(sys.modules[__name__])
+ )
if __name__ == "__main__":
- unittest.main()
+ test_main()