summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2021-09-30 19:56:41 +0300
committerGitHub <noreply@github.com>2021-09-30 19:56:41 +0300
commit7873884d4730d7e637a968011b8958bd79fd3398 (patch)
treeb9d4ae24dcb6bb037558299e05726af0611e35de /Lib
parent80285ecc8deaa2b0e7351bf4be863d1a0ad3c188 (diff)
downloadcpython-git-7873884d4730d7e637a968011b8958bd79fd3398.tar.gz
[3.10] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28657)
* Work correctly if an additional fresh module imports other additional fresh module which imports a blocked module. * Raises ImportError if the specified module cannot be imported while all additional fresh modules are successfully imported. * Support blocking packages. * Always restore the import state of fresh and blocked modules and their submodules. * Fix test_decimal and test_xml_etree which depended on an undesired side effect of import_fresh_module(). (cherry picked from commit ec4d917a6a68824f1895f75d113add9410283da7)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/support/import_helper.py67
-rw-r--r--Lib/test/test_decimal.py2
-rw-r--r--Lib/test/test_xml_etree.py13
3 files changed, 30 insertions, 52 deletions
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index 5d1e940687..43ae314834 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -80,33 +80,13 @@ def import_module(name, deprecated=False, *, required_on=()):
raise unittest.SkipTest(str(msg))
-def _save_and_remove_module(name, orig_modules):
- """Helper function to save and remove a module from sys.modules
-
- Raise ImportError if the module can't be imported.
- """
- # try to import the module and raise an error if it can't be imported
- if name not in sys.modules:
- __import__(name)
- del sys.modules[name]
+def _save_and_remove_modules(names):
+ orig_modules = {}
+ prefixes = tuple(name + '.' for name in names)
for modname in list(sys.modules):
- if modname == name or modname.startswith(name + '.'):
- orig_modules[modname] = sys.modules[modname]
- del sys.modules[modname]
-
-
-def _save_and_block_module(name, orig_modules):
- """Helper function to save and block a module in sys.modules
-
- Return True if the module was in sys.modules, False otherwise.
- """
- saved = True
- try:
- orig_modules[name] = sys.modules[name]
- except KeyError:
- saved = False
- sys.modules[name] = None
- return saved
+ if modname in names or modname.startswith(prefixes):
+ orig_modules[modname] = sys.modules.pop(modname)
+ return orig_modules
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
@@ -118,7 +98,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
this operation.
*fresh* is an iterable of additional module names that are also removed
- from the sys.modules cache before doing the import.
+ from the sys.modules cache before doing the import. If one of these
+ modules can't be imported, None is returned.
*blocked* is an iterable of module names that are replaced with None
in the module cache during the import to ensure that attempts to import
@@ -139,24 +120,24 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
- orig_modules = {}
- names_to_remove = []
- _save_and_remove_module(name, orig_modules)
+ fresh = list(fresh)
+ blocked = list(blocked)
+ names = {name, *fresh, *blocked}
+ orig_modules = _save_and_remove_modules(names)
+ for modname in blocked:
+ sys.modules[modname] = None
+
try:
- for fresh_name in fresh:
- _save_and_remove_module(fresh_name, orig_modules)
- for blocked_name in blocked:
- if not _save_and_block_module(blocked_name, orig_modules):
- names_to_remove.append(blocked_name)
- fresh_module = importlib.import_module(name)
- except ImportError:
- fresh_module = None
+ # Return None when one of the "fresh" modules can not be imported.
+ try:
+ for modname in fresh:
+ __import__(modname)
+ except ImportError:
+ return None
+ return importlib.import_module(name)
finally:
- for orig_name, module in orig_modules.items():
- sys.modules[orig_name] = module
- for name_to_remove in names_to_remove:
- del sys.modules[name_to_remove]
- return fresh_module
+ _save_and_remove_modules(names)
+ sys.modules.update(orig_modules)
class CleanImport(object):
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 99263bb13b..b6173a5ffe 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -62,7 +62,7 @@ if sys.platform == 'darwin':
C = import_fresh_module('decimal', fresh=['_decimal'])
P = import_fresh_module('decimal', blocked=['_decimal'])
-orig_sys_decimal = sys.modules['decimal']
+import decimal as orig_sys_decimal
# fractions module must import the correct decimal module.
cfractions = import_fresh_module('fractions', fresh=['fractions'])
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index c79b5462b9..5a8824a78f 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -26,7 +26,7 @@ from itertools import product, islice
from test import support
from test.support import os_helper
from test.support import warnings_helper
-from test.support import findfile, gc_collect, swap_attr
+from test.support import findfile, gc_collect, swap_attr, swap_item
from test.support.import_helper import import_fresh_module
from test.support.os_helper import TESTFN
@@ -167,12 +167,11 @@ class ElementTestCase:
cls.modules = {pyET, ET}
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
- save_m = sys.modules[name]
try:
- sys.modules[name] = dumper
- temp = pickle.dumps(obj, proto)
- sys.modules[name] = loader
- result = pickle.loads(temp)
+ with swap_item(sys.modules, name, dumper):
+ temp = pickle.dumps(obj, proto)
+ with swap_item(sys.modules, name, loader):
+ result = pickle.loads(temp)
except pickle.PicklingError as pe:
# pyET must be second, because pyET may be (equal to) ET.
human = dict([(ET, "cET"), (pyET, "pyET")])
@@ -180,8 +179,6 @@ class ElementTestCase:
% (obj,
human.get(dumper, dumper),
human.get(loader, loader))) from pe
- finally:
- sys.modules[name] = save_m
return result
def assertEqualElements(self, alice, bob):