diff options
Diffstat (limited to 'Lib/test/support/import_helper.py')
-rw-r--r-- | Lib/test/support/import_helper.py | 67 |
1 files changed, 24 insertions, 43 deletions
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 5d1e940687..43ae314834 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()): raise unittest.SkipTest(str(msg)) -def _save_and_remove_module(name, orig_modules): - """Helper function to save and remove a module from sys.modules - - Raise ImportError if the module can't be imported. - """ - # try to import the module and raise an error if it can't be imported - if name not in sys.modules: - __import__(name) - del sys.modules[name] +def _save_and_remove_modules(names): + orig_modules = {} + prefixes = tuple(name + '.' for name in names) for modname in list(sys.modules): - if modname == name or modname.startswith(name + '.'): - orig_modules[modname] = sys.modules[modname] - del sys.modules[modname] - - -def _save_and_block_module(name, orig_modules): - """Helper function to save and block a module in sys.modules - - Return True if the module was in sys.modules, False otherwise. - """ - saved = True - try: - orig_modules[name] = sys.modules[name] - except KeyError: - saved = False - sys.modules[name] = None - return saved + if modname in names or modname.startswith(prefixes): + orig_modules[modname] = sys.modules.pop(modname) + return orig_modules def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): @@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): this operation. *fresh* is an iterable of additional module names that are also removed - from the sys.modules cache before doing the import. + from the sys.modules cache before doing the import. If one of these + modules can't be imported, None is returned. *blocked* is an iterable of module names that are replaced with None in the module cache during the import to ensure that attempts to import @@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): with _ignore_deprecated_imports(deprecated): # Keep track of modules saved for later restoration as well # as those which just need a blocking entry removed - orig_modules = {} - names_to_remove = [] - _save_and_remove_module(name, orig_modules) + fresh = list(fresh) + blocked = list(blocked) + names = {name, *fresh, *blocked} + orig_modules = _save_and_remove_modules(names) + for modname in blocked: + sys.modules[modname] = None + try: - for fresh_name in fresh: - _save_and_remove_module(fresh_name, orig_modules) - for blocked_name in blocked: - if not _save_and_block_module(blocked_name, orig_modules): - names_to_remove.append(blocked_name) - fresh_module = importlib.import_module(name) - except ImportError: - fresh_module = None + # Return None when one of the "fresh" modules can not be imported. + try: + for modname in fresh: + __import__(modname) + except ImportError: + return None + return importlib.import_module(name) finally: - for orig_name, module in orig_modules.items(): - sys.modules[orig_name] = module - for name_to_remove in names_to_remove: - del sys.modules[name_to_remove] - return fresh_module + _save_and_remove_modules(names) + sys.modules.update(orig_modules) class CleanImport(object): |