diff options
Diffstat (limited to 'tests/lib/server.py')
-rw-r--r-- | tests/lib/server.py | 135 |
1 files changed, 37 insertions, 98 deletions
diff --git a/tests/lib/server.py b/tests/lib/server.py index 6db46d166..4cc18452c 100644 --- a/tests/lib/server.py +++ b/tests/lib/server.py @@ -1,59 +1,29 @@ -import os -import signal +import pathlib import ssl import threading from base64 import b64encode from contextlib import contextmanager from textwrap import dedent -from types import TracebackType -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator from unittest.mock import Mock from werkzeug.serving import BaseWSGIServer, WSGIRequestHandler from werkzeug.serving import make_server as _make_server -from .compat import nullcontext +from .compat import blocked_signals -Environ = Dict[str, str] -Status = str -Headers = Iterable[Tuple[str, str]] -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] -Write = Callable[[bytes], None] -StartResponse = Callable[[Status, Headers, Optional[ExcInfo]], Write] -Body = List[bytes] -Responder = Callable[[Environ, StartResponse], Body] +if TYPE_CHECKING: + from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment +Body = Iterable[bytes] -class MockServer(BaseWSGIServer): - mock = Mock() # type: Mock - - -# Applies on Python 2 and Windows. -if not hasattr(signal, "pthread_sigmask"): - # We're not relying on this behavior anywhere currently, it's just best - # practice. - blocked_signals = nullcontext -else: - - @contextmanager - def blocked_signals(): - """Block all signals for e.g. starting a worker thread.""" - # valid_signals() was added in Python 3.8 (and not using it results - # in a warning on pthread_sigmask() call) - try: - mask = signal.valid_signals() - except AttributeError: - mask = set(range(1, signal.NSIG)) - old_mask = signal.pthread_sigmask(signal.SIG_SETMASK, mask) - try: - yield - finally: - signal.pthread_sigmask(signal.SIG_SETMASK, old_mask) +class MockServer(BaseWSGIServer): + mock: Mock = Mock() class _RequestHandler(WSGIRequestHandler): - def make_environ(self): + def make_environ(self) -> Dict[str, Any]: environ = super().make_environ() # From pallets/werkzeug#1469, will probably be in release after @@ -77,14 +47,14 @@ class _RequestHandler(WSGIRequestHandler): return environ -def _mock_wsgi_adapter(mock): - # type: (Callable[[Environ, StartResponse], Responder]) -> Responder +def _mock_wsgi_adapter( + mock: Callable[["WSGIEnvironment", "StartResponse"], "WSGIApplication"] +) -> "WSGIApplication": """Uses a mock to record function arguments and provide the actual function that should respond. """ - def adapter(environ, start_response): - # type: (Environ, StartResponse) -> Body + def adapter(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body: try: responder = mock(environ, start_response) except StopIteration: @@ -94,8 +64,7 @@ def _mock_wsgi_adapter(mock): return adapter -def make_mock_server(**kwargs): - # type: (Any) -> MockServer +def make_mock_server(**kwargs: Any) -> MockServer: """Creates a mock HTTP(S) server listening on a random port on localhost. The `mock` property of the returned server provides and records all WSGI @@ -135,8 +104,7 @@ def make_mock_server(**kwargs): @contextmanager -def server_running(server): - # type: (BaseWSGIServer) -> None +def server_running(server: BaseWSGIServer) -> Iterator[None]: """Context manager for running the provided server in a separate thread.""" thread = threading.Thread(target=server.serve_forever) thread.daemon = True @@ -152,10 +120,8 @@ def server_running(server): # Helper functions for making responses in a declarative way. -def text_html_response(text): - # type: (str) -> Responder - def responder(environ, start_response): - # type: (Environ, StartResponse) -> Body +def text_html_response(text: str) -> "WSGIApplication": + def responder(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body: start_response( "200 OK", [ @@ -167,8 +133,7 @@ def text_html_response(text): return responder -def html5_page(text): - # type: (str) -> str +def html5_page(text: str) -> str: return ( dedent( """ @@ -185,68 +150,42 @@ def html5_page(text): ) -def index_page(spec): - # type: (Dict[str, str]) -> Responder - def link(name, value): +def package_page(spec: Dict[str, str]) -> "WSGIApplication": + def link(name: str, value: str) -> str: return '<a href="{}">{}</a>'.format(value, name) links = "".join(link(*kv) for kv in spec.items()) return text_html_response(html5_page(links)) -def package_page(spec): - # type: (Dict[str, str]) -> Responder - def link(name, value): - return '<a href="{}">{}</a>'.format(value, name) - - links = "".join(link(*kv) for kv in spec.items()) - return text_html_response(html5_page(links)) - - -def file_response(path): - # type: (str) -> Responder - def responder(environ, start_response): - # type: (Environ, StartResponse) -> Body - size = os.stat(path).st_size +def file_response(path: pathlib.Path) -> "WSGIApplication": + def responder(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body: start_response( "200 OK", [ ("Content-Type", "application/octet-stream"), - ("Content-Length", str(size)), + ("Content-Length", str(path.stat().st_size)), ], ) - - with open(path, "rb") as f: - return [f.read()] + return [path.read_bytes()] return responder -def authorization_response(path): - # type: (str) -> Responder +def authorization_response(path: pathlib.Path) -> "WSGIApplication": correct_auth = "Basic " + b64encode(b"USERNAME:PASSWORD").decode("ascii") - def responder(environ, start_response): - # type: (Environ, StartResponse) -> Body - - if environ.get("HTTP_AUTHORIZATION") == correct_auth: - size = os.stat(path).st_size - start_response( - "200 OK", - [ - ("Content-Type", "application/octet-stream"), - ("Content-Length", str(size)), - ], - ) - else: - start_response( - "401 Unauthorized", - [ - ("WWW-Authenticate", "Basic"), - ], - ) - - with open(path, "rb") as f: - return [f.read()] + def responder(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body: + if environ.get("HTTP_AUTHORIZATION") != correct_auth: + start_response("401 Unauthorized", [("WWW-Authenticate", "Basic")]) + return () + start_response( + "200 OK", + [ + ("Content-Type", "application/octet-stream"), + ("Content-Length", str(path.stat().st_size)), + ], + ) + return [path.read_bytes()] return responder |