diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2021-09-30 19:56:41 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-30 19:56:41 +0300 |
commit | 7873884d4730d7e637a968011b8958bd79fd3398 (patch) | |
tree | b9d4ae24dcb6bb037558299e05726af0611e35de /Lib | |
parent | 80285ecc8deaa2b0e7351bf4be863d1a0ad3c188 (diff) | |
download | cpython-git-7873884d4730d7e637a968011b8958bd79fd3398.tar.gz |
[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)
* Work correctly if an additional fresh module imports other
additional fresh module which imports a blocked module.
* Raises ImportError if the specified module cannot be imported
while all additional fresh modules are successfully imported.
* Support blocking packages.
* Always restore the import state of fresh and blocked modules
and their submodules.
* Fix test_decimal and test_xml_etree which depended on an undesired
side effect of import_fresh_module().
(cherry picked from commit ec4d917a6a68824f1895f75d113add9410283da7)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/support/import_helper.py | 67 | ||||
-rw-r--r-- | Lib/test/test_decimal.py | 2 | ||||
-rw-r--r-- | Lib/test/test_xml_etree.py | 13 |
3 files changed, 30 insertions, 52 deletions
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 5d1e940687..43ae314834 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()): raise unittest.SkipTest(str(msg)) -def _save_and_remove_module(name, orig_modules): - """Helper function to save and remove a module from sys.modules - - Raise ImportError if the module can't be imported. - """ - # try to import the module and raise an error if it can't be imported - if name not in sys.modules: - __import__(name) - del sys.modules[name] +def _save_and_remove_modules(names): + orig_modules = {} + prefixes = tuple(name + '.' for name in names) for modname in list(sys.modules): - if modname == name or modname.startswith(name + '.'): - orig_modules[modname] = sys.modules[modname] - del sys.modules[modname] - - -def _save_and_block_module(name, orig_modules): - """Helper function to save and block a module in sys.modules - - Return True if the module was in sys.modules, False otherwise. - """ - saved = True - try: - orig_modules[name] = sys.modules[name] - except KeyError: - saved = False - sys.modules[name] = None - return saved + if modname in names or modname.startswith(prefixes): + orig_modules[modname] = sys.modules.pop(modname) + return orig_modules def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): @@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): this operation. *fresh* is an iterable of additional module names that are also removed - from the sys.modules cache before doing the import. + from the sys.modules cache before doing the import. If one of these + modules can't be imported, None is returned. *blocked* is an iterable of module names that are replaced with None in the module cache during the import to ensure that attempts to import @@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): with _ignore_deprecated_imports(deprecated): # Keep track of modules saved for later restoration as well # as those which just need a blocking entry removed - orig_modules = {} - names_to_remove = [] - _save_and_remove_module(name, orig_modules) + fresh = list(fresh) + blocked = list(blocked) + names = {name, *fresh, *blocked} + orig_modules = _save_and_remove_modules(names) + for modname in blocked: + sys.modules[modname] = None + try: - for fresh_name in fresh: - _save_and_remove_module(fresh_name, orig_modules) - for blocked_name in blocked: - if not _save_and_block_module(blocked_name, orig_modules): - names_to_remove.append(blocked_name) - fresh_module = importlib.import_module(name) - except ImportError: - fresh_module = None + # Return None when one of the "fresh" modules can not be imported. + try: + for modname in fresh: + __import__(modname) + except ImportError: + return None + return importlib.import_module(name) finally: - for orig_name, module in orig_modules.items(): - sys.modules[orig_name] = module - for name_to_remove in names_to_remove: - del sys.modules[name_to_remove] - return fresh_module + _save_and_remove_modules(names) + sys.modules.update(orig_modules) class CleanImport(object): diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 99263bb13b..b6173a5ffe 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -62,7 +62,7 @@ if sys.platform == 'darwin': C = import_fresh_module('decimal', fresh=['_decimal']) P = import_fresh_module('decimal', blocked=['_decimal']) -orig_sys_decimal = sys.modules['decimal'] +import decimal as orig_sys_decimal # fractions module must import the correct decimal module. cfractions = import_fresh_module('fractions', fresh=['fractions']) diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index c79b5462b9..5a8824a78f 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -26,7 +26,7 @@ from itertools import product, islice from test import support from test.support import os_helper from test.support import warnings_helper -from test.support import findfile, gc_collect, swap_attr +from test.support import findfile, gc_collect, swap_attr, swap_item from test.support.import_helper import import_fresh_module from test.support.os_helper import TESTFN @@ -167,12 +167,11 @@ class ElementTestCase: cls.modules = {pyET, ET} def pickleRoundTrip(self, obj, name, dumper, loader, proto): - save_m = sys.modules[name] try: - sys.modules[name] = dumper - temp = pickle.dumps(obj, proto) - sys.modules[name] = loader - result = pickle.loads(temp) + with swap_item(sys.modules, name, dumper): + temp = pickle.dumps(obj, proto) + with swap_item(sys.modules, name, loader): + result = pickle.loads(temp) except pickle.PicklingError as pe: # pyET must be second, because pyET may be (equal to) ET. human = dict([(ET, "cET"), (pyET, "pyET")]) @@ -180,8 +179,6 @@ class ElementTestCase: % (obj, human.get(dumper, dumper), human.get(loader, loader))) from pe - finally: - sys.modules[name] = save_m return result def assertEqualElements(self, alice, bob): |