summaryrefslogtreecommitdiff
path: root/tests/lib/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/lib/server.py')
-rw-r--r--tests/lib/server.py135
1 files changed, 37 insertions, 98 deletions
diff --git a/tests/lib/server.py b/tests/lib/server.py
index 6db46d166..4cc18452c 100644
--- a/tests/lib/server.py
+++ b/tests/lib/server.py
@@ -1,59 +1,29 @@
-import os
-import signal
+import pathlib
import ssl
import threading
from base64 import b64encode
from contextlib import contextmanager
from textwrap import dedent
-from types import TracebackType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator
from unittest.mock import Mock
from werkzeug.serving import BaseWSGIServer, WSGIRequestHandler
from werkzeug.serving import make_server as _make_server
-from .compat import nullcontext
+from .compat import blocked_signals
-Environ = Dict[str, str]
-Status = str
-Headers = Iterable[Tuple[str, str]]
-ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
-Write = Callable[[bytes], None]
-StartResponse = Callable[[Status, Headers, Optional[ExcInfo]], Write]
-Body = List[bytes]
-Responder = Callable[[Environ, StartResponse], Body]
+if TYPE_CHECKING:
+ from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment
+Body = Iterable[bytes]
-class MockServer(BaseWSGIServer):
- mock = Mock() # type: Mock
-
-
-# Applies on Python 2 and Windows.
-if not hasattr(signal, "pthread_sigmask"):
- # We're not relying on this behavior anywhere currently, it's just best
- # practice.
- blocked_signals = nullcontext
-else:
-
- @contextmanager
- def blocked_signals():
- """Block all signals for e.g. starting a worker thread."""
- # valid_signals() was added in Python 3.8 (and not using it results
- # in a warning on pthread_sigmask() call)
- try:
- mask = signal.valid_signals()
- except AttributeError:
- mask = set(range(1, signal.NSIG))
- old_mask = signal.pthread_sigmask(signal.SIG_SETMASK, mask)
- try:
- yield
- finally:
- signal.pthread_sigmask(signal.SIG_SETMASK, old_mask)
+class MockServer(BaseWSGIServer):
+ mock: Mock = Mock()
class _RequestHandler(WSGIRequestHandler):
- def make_environ(self):
+ def make_environ(self) -> Dict[str, Any]:
environ = super().make_environ()
# From pallets/werkzeug#1469, will probably be in release after
@@ -77,14 +47,14 @@ class _RequestHandler(WSGIRequestHandler):
return environ
-def _mock_wsgi_adapter(mock):
- # type: (Callable[[Environ, StartResponse], Responder]) -> Responder
+def _mock_wsgi_adapter(
+ mock: Callable[["WSGIEnvironment", "StartResponse"], "WSGIApplication"]
+) -> "WSGIApplication":
"""Uses a mock to record function arguments and provide
the actual function that should respond.
"""
- def adapter(environ, start_response):
- # type: (Environ, StartResponse) -> Body
+ def adapter(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body:
try:
responder = mock(environ, start_response)
except StopIteration:
@@ -94,8 +64,7 @@ def _mock_wsgi_adapter(mock):
return adapter
-def make_mock_server(**kwargs):
- # type: (Any) -> MockServer
+def make_mock_server(**kwargs: Any) -> MockServer:
"""Creates a mock HTTP(S) server listening on a random port on localhost.
The `mock` property of the returned server provides and records all WSGI
@@ -135,8 +104,7 @@ def make_mock_server(**kwargs):
@contextmanager
-def server_running(server):
- # type: (BaseWSGIServer) -> None
+def server_running(server: BaseWSGIServer) -> Iterator[None]:
"""Context manager for running the provided server in a separate thread."""
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
@@ -152,10 +120,8 @@ def server_running(server):
# Helper functions for making responses in a declarative way.
-def text_html_response(text):
- # type: (str) -> Responder
- def responder(environ, start_response):
- # type: (Environ, StartResponse) -> Body
+def text_html_response(text: str) -> "WSGIApplication":
+ def responder(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body:
start_response(
"200 OK",
[
@@ -167,8 +133,7 @@ def text_html_response(text):
return responder
-def html5_page(text):
- # type: (str) -> str
+def html5_page(text: str) -> str:
return (
dedent(
"""
@@ -185,68 +150,42 @@ def html5_page(text):
)
-def index_page(spec):
- # type: (Dict[str, str]) -> Responder
- def link(name, value):
+def package_page(spec: Dict[str, str]) -> "WSGIApplication":
+ def link(name: str, value: str) -> str:
return '<a href="{}">{}</a>'.format(value, name)
links = "".join(link(*kv) for kv in spec.items())
return text_html_response(html5_page(links))
-def package_page(spec):
- # type: (Dict[str, str]) -> Responder
- def link(name, value):
- return '<a href="{}">{}</a>'.format(value, name)
-
- links = "".join(link(*kv) for kv in spec.items())
- return text_html_response(html5_page(links))
-
-
-def file_response(path):
- # type: (str) -> Responder
- def responder(environ, start_response):
- # type: (Environ, StartResponse) -> Body
- size = os.stat(path).st_size
+def file_response(path: pathlib.Path) -> "WSGIApplication":
+ def responder(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body:
start_response(
"200 OK",
[
("Content-Type", "application/octet-stream"),
- ("Content-Length", str(size)),
+ ("Content-Length", str(path.stat().st_size)),
],
)
-
- with open(path, "rb") as f:
- return [f.read()]
+ return [path.read_bytes()]
return responder
-def authorization_response(path):
- # type: (str) -> Responder
+def authorization_response(path: pathlib.Path) -> "WSGIApplication":
correct_auth = "Basic " + b64encode(b"USERNAME:PASSWORD").decode("ascii")
- def responder(environ, start_response):
- # type: (Environ, StartResponse) -> Body
-
- if environ.get("HTTP_AUTHORIZATION") == correct_auth:
- size = os.stat(path).st_size
- start_response(
- "200 OK",
- [
- ("Content-Type", "application/octet-stream"),
- ("Content-Length", str(size)),
- ],
- )
- else:
- start_response(
- "401 Unauthorized",
- [
- ("WWW-Authenticate", "Basic"),
- ],
- )
-
- with open(path, "rb") as f:
- return [f.read()]
+ def responder(environ: "WSGIEnvironment", start_response: "StartResponse") -> Body:
+ if environ.get("HTTP_AUTHORIZATION") != correct_auth:
+ start_response("401 Unauthorized", [("WWW-Authenticate", "Basic")])
+ return ()
+ start_response(
+ "200 OK",
+ [
+ ("Content-Type", "application/octet-stream"),
+ ("Content-Length", str(path.stat().st_size)),
+ ],
+ )
+ return [path.read_bytes()]
return responder