summaryrefslogtreecommitdiff
path: root/Lib/test/support/import_helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/support/import_helper.py')
-rw-r--r--Lib/test/support/import_helper.py67
1 files changed, 24 insertions, 43 deletions
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index 5d1e940687..43ae314834 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()):
raise unittest.SkipTest(str(msg))
-def _save_and_remove_module(name, orig_modules):
- """Helper function to save and remove a module from sys.modules
-
- Raise ImportError if the module can't be imported.
- """
- # try to import the module and raise an error if it can't be imported
- if name not in sys.modules:
- __import__(name)
- del sys.modules[name]
+def _save_and_remove_modules(names):
+ orig_modules = {}
+ prefixes = tuple(name + '.' for name in names)
for modname in list(sys.modules):
- if modname == name or modname.startswith(name + '.'):
- orig_modules[modname] = sys.modules[modname]
- del sys.modules[modname]
-
-
-def _save_and_block_module(name, orig_modules):
- """Helper function to save and block a module in sys.modules
-
- Return True if the module was in sys.modules, False otherwise.
- """
- saved = True
- try:
- orig_modules[name] = sys.modules[name]
- except KeyError:
- saved = False
- sys.modules[name] = None
- return saved
+ if modname in names or modname.startswith(prefixes):
+ orig_modules[modname] = sys.modules.pop(modname)
+ return orig_modules
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
@@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
this operation.
*fresh* is an iterable of additional module names that are also removed
- from the sys.modules cache before doing the import.
+ from the sys.modules cache before doing the import. If one of these
+ modules can't be imported, None is returned.
*blocked* is an iterable of module names that are replaced with None
in the module cache during the import to ensure that attempts to import
@@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
- orig_modules = {}
- names_to_remove = []
- _save_and_remove_module(name, orig_modules)
+ fresh = list(fresh)
+ blocked = list(blocked)
+ names = {name, *fresh, *blocked}
+ orig_modules = _save_and_remove_modules(names)
+ for modname in blocked:
+ sys.modules[modname] = None
+
try:
- for fresh_name in fresh:
- _save_and_remove_module(fresh_name, orig_modules)
- for blocked_name in blocked:
- if not _save_and_block_module(blocked_name, orig_modules):
- names_to_remove.append(blocked_name)
- fresh_module = importlib.import_module(name)
- except ImportError:
- fresh_module = None
+ # Return None when one of the "fresh" modules can not be imported.
+ try:
+ for modname in fresh:
+ __import__(modname)
+ except ImportError:
+ return None
+ return importlib.import_module(name)
finally:
- for orig_name, module in orig_modules.items():
- sys.modules[orig_name] = module
- for name_to_remove in names_to_remove:
- del sys.modules[name_to_remove]
- return fresh_module
+ _save_and_remove_modules(names)
+ sys.modules.update(orig_modules)
class CleanImport(object):