diff options
Diffstat (limited to 'Lib/test/libregrtest/runtest_mp.py')
-rw-r--r-- | Lib/test/libregrtest/runtest_mp.py | 140 |
1 files changed, 95 insertions, 45 deletions
diff --git a/Lib/test/libregrtest/runtest_mp.py b/Lib/test/libregrtest/runtest_mp.py index 3d503af23b..c83e44aed0 100644 --- a/Lib/test/libregrtest/runtest_mp.py +++ b/Lib/test/libregrtest/runtest_mp.py @@ -9,13 +9,15 @@ import sys import threading import time import traceback -import types +from typing import NamedTuple, NoReturn, Literal, Any + from test import support from test.support import os_helper +from test.libregrtest.cmdline import Namespace +from test.libregrtest.main import Regrtest from test.libregrtest.runtest import ( - runtest, INTERRUPTED, CHILD_ERROR, PROGRESS_MIN_TIME, - format_test_result, TestResult, is_failed, TIMEOUT) + runtest, is_failed, TestResult, Interrupted, Timeout, ChildError, PROGRESS_MIN_TIME) from test.libregrtest.setup import setup_tests from test.libregrtest.utils import format_duration, print_warning @@ -36,21 +38,21 @@ JOIN_TIMEOUT = 30.0 # seconds USE_PROCESS_GROUP = (hasattr(os, "setsid") and hasattr(os, "killpg")) -def must_stop(result, ns): - if result.result == INTERRUPTED: +def must_stop(result: TestResult, ns: Namespace) -> bool: + if isinstance(result, Interrupted): return True if ns.failfast and is_failed(result, ns): return True return False -def parse_worker_args(worker_args): +def parse_worker_args(worker_args) -> tuple[Namespace, str]: ns_dict, test_name = json.loads(worker_args) - ns = types.SimpleNamespace(**ns_dict) + ns = Namespace(**ns_dict) return (ns, test_name) -def run_test_in_subprocess(testname, ns): +def run_test_in_subprocess(testname: str, ns: Namespace) -> subprocess.Popen: ns_dict = vars(ns) worker_args = (ns_dict, testname) worker_args = json.dumps(worker_args) @@ -75,15 +77,15 @@ def run_test_in_subprocess(testname, ns): **kw) -def run_tests_worker(ns, test_name): +def run_tests_worker(ns: Namespace, test_name: str) -> NoReturn: setup_tests(ns) result = runtest(ns, test_name) print() # Force a newline (just in case) - # Serialize TestResult as list in JSON - print(json.dumps(list(result)), flush=True) + # Serialize TestResult as dict in JSON + print(json.dumps(result, cls=EncodeTestResult), flush=True) sys.exit(0) @@ -110,15 +112,23 @@ class MultiprocessIterator: self.tests_iter = None -MultiprocessResult = collections.namedtuple('MultiprocessResult', - 'result stdout stderr error_msg') +class MultiprocessResult(NamedTuple): + result: TestResult + stdout: str + stderr: str + error_msg: str + + +ExcStr = str +QueueOutput = tuple[Literal[False], MultiprocessResult] | tuple[Literal[True], ExcStr] + class ExitThread(Exception): pass class TestWorkerProcess(threading.Thread): - def __init__(self, worker_id, runner): + def __init__(self, worker_id: int, runner: "MultiprocessTestRunner") -> None: super().__init__() self.worker_id = worker_id self.pending = runner.pending @@ -132,7 +142,7 @@ class TestWorkerProcess(threading.Thread): self._killed = False self._stopped = False - def __repr__(self): + def __repr__(self) -> str: info = [f'TestWorkerProcess #{self.worker_id}'] if self.is_alive(): info.append("running") @@ -148,7 +158,7 @@ class TestWorkerProcess(threading.Thread): f'time={format_duration(dt)}')) return '<%s>' % ' '.join(info) - def _kill(self): + def _kill(self) -> None: popen = self._popen if popen is None: return @@ -176,18 +186,22 @@ class TestWorkerProcess(threading.Thread): except OSError as exc: print_warning(f"Failed to kill {what}: {exc!r}") - def stop(self): + def stop(self) -> None: # Method called from a different thread to stop this thread self._stopped = True self._kill() - def mp_result_error(self, test_name, error_type, stdout='', stderr='', - err_msg=None): - test_time = time.monotonic() - self.start_time - result = TestResult(test_name, error_type, test_time, None) - return MultiprocessResult(result, stdout, stderr, err_msg) - - def _run_process(self, test_name): + def mp_result_error( + self, + test_result: TestResult, + stdout: str = '', + stderr: str = '', + err_msg=None + ) -> MultiprocessResult: + test_result.duration_sec = time.monotonic() - self.start_time + return MultiprocessResult(test_result, stdout, stderr, err_msg) + + def _run_process(self, test_name: str) -> tuple[int, str, str]: self.start_time = time.monotonic() self.current_test_name = test_name @@ -246,11 +260,11 @@ class TestWorkerProcess(threading.Thread): self._popen = None self.current_test_name = None - def _runtest(self, test_name): + def _runtest(self, test_name: str) -> MultiprocessResult: retcode, stdout, stderr = self._run_process(test_name) if retcode is None: - return self.mp_result_error(test_name, TIMEOUT, stdout, stderr) + return self.mp_result_error(Timeout(test_name), stdout, stderr) err_msg = None if retcode != 0: @@ -263,18 +277,17 @@ class TestWorkerProcess(threading.Thread): else: try: # deserialize run_tests_worker() output - result = json.loads(result) - result = TestResult(*result) + result = json.loads(result, object_hook=decode_test_result) except Exception as exc: err_msg = "Failed to parse worker JSON: %s" % exc if err_msg is not None: - return self.mp_result_error(test_name, CHILD_ERROR, + return self.mp_result_error(ChildError(test_name), stdout, stderr, err_msg) return MultiprocessResult(result, stdout, stderr, err_msg) - def run(self): + def run(self) -> None: while not self._stopped: try: try: @@ -293,7 +306,7 @@ class TestWorkerProcess(threading.Thread): self.output.put((True, traceback.format_exc())) break - def _wait_completed(self): + def _wait_completed(self) -> None: popen = self._popen # stdout and stderr must be closed to ensure that communicate() @@ -308,7 +321,7 @@ class TestWorkerProcess(threading.Thread): f"(timeout={format_duration(JOIN_TIMEOUT)}): " f"{exc!r}") - def wait_stopped(self, start_time): + def wait_stopped(self, start_time: float) -> None: # bpo-38207: MultiprocessTestRunner.stop_workers() called self.stop() # which killed the process. Sometimes, killing the process from the # main thread does not interrupt popen.communicate() in @@ -332,7 +345,7 @@ class TestWorkerProcess(threading.Thread): break -def get_running(workers): +def get_running(workers: list[TestWorkerProcess]) -> list[TestWorkerProcess]: running = [] for worker in workers: current_test_name = worker.current_test_name @@ -346,11 +359,11 @@ def get_running(workers): class MultiprocessTestRunner: - def __init__(self, regrtest): + def __init__(self, regrtest: Regrtest) -> None: self.regrtest = regrtest self.log = self.regrtest.log self.ns = regrtest.ns - self.output = queue.Queue() + self.output: queue.Queue[QueueOutput] = queue.Queue() self.pending = MultiprocessIterator(self.regrtest.tests) if self.ns.timeout is not None: # Rely on faulthandler to kill a worker process. This timouet is @@ -362,7 +375,7 @@ class MultiprocessTestRunner: self.worker_timeout = None self.workers = None - def start_workers(self): + def start_workers(self) -> None: self.workers = [TestWorkerProcess(index, self) for index in range(1, self.ns.use_mp + 1)] msg = f"Run tests in parallel using {len(self.workers)} child processes" @@ -374,14 +387,14 @@ class MultiprocessTestRunner: for worker in self.workers: worker.start() - def stop_workers(self): + def stop_workers(self) -> None: start_time = time.monotonic() for worker in self.workers: worker.stop() for worker in self.workers: worker.wait_stopped(start_time) - def _get_result(self): + def _get_result(self) -> QueueOutput | None: if not any(worker.is_alive() for worker in self.workers): # all worker threads are done: consume pending results try: @@ -407,21 +420,22 @@ class MultiprocessTestRunner: if running and not self.ns.pgo: self.log('running: %s' % ', '.join(running)) - def display_result(self, mp_result): + def display_result(self, mp_result: MultiprocessResult) -> None: result = mp_result.result - text = format_test_result(result) + text = str(result) if mp_result.error_msg is not None: # CHILD_ERROR text += ' (%s)' % mp_result.error_msg - elif (result.test_time >= PROGRESS_MIN_TIME and not self.ns.pgo): - text += ' (%s)' % format_duration(result.test_time) + elif (result.duration_sec >= PROGRESS_MIN_TIME and not self.ns.pgo): + text += ' (%s)' % format_duration(result.duration_sec) running = get_running(self.workers) if running and not self.ns.pgo: text += ' -- running: %s' % ', '.join(running) self.regrtest.display_progress(self.test_index, text) - def _process_result(self, item): + def _process_result(self, item: QueueOutput) -> bool: + """Returns True if test runner must stop.""" if item[0]: # Thread got an exception format_exc = item[1] @@ -443,7 +457,7 @@ class MultiprocessTestRunner: return False - def run_tests(self): + def run_tests(self) -> None: self.start_workers() self.test_index = 0 @@ -469,5 +483,41 @@ class MultiprocessTestRunner: self.stop_workers() -def run_tests_multiprocess(regrtest): +def run_tests_multiprocess(regrtest: Regrtest) -> None: MultiprocessTestRunner(regrtest).run_tests() + + +class EncodeTestResult(json.JSONEncoder): + """Encode a TestResult (sub)class object into a JSON dict.""" + + def default(self, o: Any) -> dict[str, Any]: + if isinstance(o, TestResult): + result = vars(o) + result["__test_result__"] = o.__class__.__name__ + return result + + return super().default(o) + + +def decode_test_result(d: dict[str, Any]) -> TestResult | dict[str, Any]: + """Decode a TestResult (sub)class object from a JSON dict.""" + + if "__test_result__" not in d: + return d + + cls_name = d.pop("__test_result__") + for cls in get_all_test_result_classes(): + if cls.__name__ == cls_name: + return cls(**d) + + +def get_all_test_result_classes() -> set[type[TestResult]]: + prev_count = 0 + classes = {TestResult} + while len(classes) > prev_count: + prev_count = len(classes) + to_add = [] + for cls in classes: + to_add.extend(cls.__subclasses__()) + classes.update(to_add) + return classes |