summaryrefslogtreecommitdiff
path: root/Lib/test/test_support.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_support.py')
-rw-r--r--Lib/test/test_support.py47
1 files changed, 24 insertions, 23 deletions
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index cb2e03b905..81bb3ca017 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -35,7 +35,8 @@ __all__ = ["Error", "TestFailed", "ResourceDenied", "import_module",
"run_with_locale", "set_memlimit", "bigmemtest", "bigaddrspacetest",
"BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
"threading_cleanup", "reap_children", "cpython_only",
- "check_impl_detail", "get_attribute", "py3k_bytes"]
+ "check_impl_detail", "get_attribute", "py3k_bytes",
+ "import_fresh_module"]
class Error(Exception):
@@ -83,23 +84,20 @@ def import_module(name, deprecated=False):
def _save_and_remove_module(name, orig_modules):
"""Helper function to save and remove a module from sys.modules
- Return value is True if the module was in sys.modules and
- False otherwise."""
- saved = True
- try:
- orig_modules[name] = sys.modules[name]
- except KeyError:
- saved = False
- else:
+ Raise ImportError if the module can't be imported."""
+ # try to import the module and raise an error if it can't be imported
+ if name not in sys.modules:
+ __import__(name)
del sys.modules[name]
- return saved
-
+ for modname in list(sys.modules):
+ if modname == name or modname.startswith(name + '.'):
+ orig_modules[modname] = sys.modules[modname]
+ del sys.modules[modname]
def _save_and_block_module(name, orig_modules):
"""Helper function to save and block a module in sys.modules
- Return value is True if the module was in sys.modules and
- False otherwise."""
+ Return True if the module was in sys.modules, False otherwise."""
saved = True
try:
orig_modules[name] = sys.modules[name]
@@ -115,14 +113,15 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
the sys.modules cache is restored to its original state.
Modules named in fresh are also imported anew if needed by the import.
+ If one of these modules can't be imported, None is returned.
Importing of modules named in blocked is prevented while the fresh import
takes place.
If deprecated is True, any module or package deprecation messages
will be suppressed."""
- # NOTE: test_heapq and test_warnings include extra sanity checks to make
- # sure that this utility function is working as expected
+ # NOTE: test_heapq, test_json, and test_warnings include extra sanity
+ # checks to make sure that this utility function is working as expected
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
@@ -136,6 +135,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
if not _save_and_block_module(blocked_name, orig_modules):
names_to_remove.append(blocked_name)
fresh_module = importlib.import_module(name)
+ except ImportError:
+ fresh_module = None
finally:
for orig_name, module in orig_modules.items():
sys.modules[orig_name] = module
@@ -813,14 +814,8 @@ def transient_internet(resource_name, timeout=30.0, errnos=()):
@contextlib.contextmanager
def captured_output(stream_name):
- """Run the 'with' statement body using a StringIO object in place of a
- specific attribute on the sys module.
- Example use (with 'stream_name=stdout')::
-
- with captured_stdout() as s:
- print "hello"
- assert s.getvalue() == "hello"
- """
+ """Return a context manager used by captured_stdout and captured_stdin
+ that temporarily replaces the sys stream *stream_name* with a StringIO."""
import StringIO
orig_stdout = getattr(sys, stream_name)
setattr(sys, stream_name, StringIO.StringIO())
@@ -830,6 +825,12 @@ def captured_output(stream_name):
setattr(sys, stream_name, orig_stdout)
def captured_stdout():
+ """Capture the output of sys.stdout:
+
+ with captured_stdout() as s:
+ print "hello"
+ self.assertEqual(s.getvalue(), "hello")
+ """
return captured_output("stdout")
def captured_stdin():