summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--django/core/servers/basehttp.py26
-rw-r--r--tests/servers/tests.py66
-rw-r--r--tests/servers/urls.py1
-rw-r--r--tests/servers/views.py6
4 files changed, 78 insertions, 21 deletions
diff --git a/django/core/servers/basehttp.py b/django/core/servers/basehttp.py
index 89306683b8..8a36b67eef 100644
--- a/django/core/servers/basehttp.py
+++ b/django/core/servers/basehttp.py
@@ -74,12 +74,24 @@ class WSGIServer(simple_server.WSGIServer):
class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer):
"""A threaded version of the WSGIServer"""
- pass
+ daemon_threads = True
class ServerHandler(simple_server.ServerHandler):
http_version = '1.1'
+ def cleanup_headers(self):
+ super().cleanup_headers()
+ # HTTP/1.1 requires us to support persistent connections, so
+ # explicitly send close if we do not know the content length to
+ # prevent clients from reusing the connection.
+ if 'Content-Length' not in self.headers:
+ self.headers['Connection'] = 'close'
+ # Mark the connection for closing if we set it as such above or
+ # if the application sent the header.
+ if self.headers.get('Connection') == 'close':
+ self.request_handler.close_connection = True
+
def handle_error(self):
# Ignore broken pipe errors, otherwise pass on
if not is_broken_pipe_error():
@@ -135,6 +147,16 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler):
return super().get_environ()
def handle(self):
+ self.close_connection = True
+ self.handle_one_request()
+ while not self.close_connection:
+ self.handle_one_request()
+ try:
+ self.connection.shutdown(socket.SHUT_WR)
+ except (socket.error, AttributeError):
+ pass
+
+ def handle_one_request(self):
"""Copy of WSGIRequestHandler.handle() but with different ServerHandler"""
self.raw_requestline = self.rfile.readline(65537)
if len(self.raw_requestline) > 65536:
@@ -150,7 +172,7 @@ class WSGIRequestHandler(simple_server.WSGIRequestHandler):
handler = ServerHandler(
self.rfile, self.wfile, self.get_stderr(), self.get_environ()
)
- handler.request_handler = self # backpointer for logging
+ handler.request_handler = self # backpointer for logging & connection closing
handler.run(self.server.get_app())
diff --git a/tests/servers/tests.py b/tests/servers/tests.py
index ce08eb4a3f..e38cb5eb07 100644
--- a/tests/servers/tests.py
+++ b/tests/servers/tests.py
@@ -4,8 +4,7 @@ Tests for django.core.servers.
import errno
import os
import socket
-import sys
-from http.client import HTTPConnection, RemoteDisconnected
+from http.client import HTTPConnection
from urllib.error import HTTPError
from urllib.parse import urlencode
from urllib.request import urlopen
@@ -57,29 +56,60 @@ class LiveServerViews(LiveServerBase):
with self.urlopen('/example_view/') as f:
self.assertEqual(f.version, 11)
- @override_settings(MIDDLEWARE=[])
def test_closes_connection_without_content_length(self):
"""
- The server doesn't support keep-alive because Python's http.server
- module that it uses hangs if a Content-Length header isn't set (for
- example, if CommonMiddleware isn't enabled or if the response is a
- StreamingHttpResponse) (#28440 / https://bugs.python.org/issue31076).
+ A HTTP 1.1 server is supposed to support keep-alive. Since our
+ development server is rather simple we support it only in cases where
+ we can detect a content length from the response. This should be doable
+ for all simple views and streaming responses where an iterable with
+ length of one is passed. The latter follows as result of `set_content_length`
+ from https://github.com/python/cpython/blob/master/Lib/wsgiref/handlers.py.
+
+ If we cannot detect a content length we explicitly set the `Connection`
+ header to `close` to notify the client that we do not actually support
+ it.
"""
conn = HTTPConnection(LiveServerViews.server_thread.host, LiveServerViews.server_thread.port, timeout=1)
try:
- conn.request('GET', '/example_view/', headers={'Connection': 'keep-alive'})
- response = conn.getresponse().read()
- conn.request('GET', '/example_view/', headers={'Connection': 'close'})
- # macOS may give ConnectionResetError.
- with self.assertRaises((RemoteDisconnected, ConnectionResetError)):
- try:
- conn.getresponse()
- except ConnectionAbortedError:
- if sys.platform == 'win32':
- self.skipTest('Ignore nondeterministic failure on Windows.')
+ conn.request('GET', '/streaming_example_view/', headers={'Connection': 'keep-alive'})
+ response = conn.getresponse()
+ self.assertTrue(response.will_close)
+ self.assertEqual(response.read(), b'Iamastream')
+ self.assertEqual(response.status, 200)
+ self.assertEqual(response.getheader('Connection'), 'close')
+
+ conn.request('GET', '/streaming_example_view/', headers={'Connection': 'close'})
+ response = conn.getresponse()
+ self.assertTrue(response.will_close)
+ self.assertEqual(response.read(), b'Iamastream')
+ self.assertEqual(response.status, 200)
+ self.assertEqual(response.getheader('Connection'), 'close')
+ finally:
+ conn.close()
+
+ def test_keep_alive_on_connection_with_content_length(self):
+ """
+ See `test_closes_connection_without_content_length` for details. This
+ is a follow up test, which ensure that we do not close the connection
+ if not needed, hence allowing us to take advantage of keep-alive.
+ """
+ conn = HTTPConnection(LiveServerViews.server_thread.host, LiveServerViews.server_thread.port)
+ try:
+ conn.request('GET', '/example_view/', headers={"Connection": "keep-alive"})
+ response = conn.getresponse()
+ self.assertFalse(response.will_close)
+ self.assertEqual(response.read(), b'example view')
+ self.assertEqual(response.status, 200)
+ self.assertIsNone(response.getheader('Connection'))
+
+ conn.request('GET', '/example_view/', headers={"Connection": "close"})
+ response = conn.getresponse()
+ self.assertFalse(response.will_close)
+ self.assertEqual(response.read(), b'example view')
+ self.assertEqual(response.status, 200)
+ self.assertIsNone(response.getheader('Connection'))
finally:
conn.close()
- self.assertEqual(response, b'example view')
def test_404(self):
with self.assertRaises(HTTPError) as err:
diff --git a/tests/servers/urls.py b/tests/servers/urls.py
index 4963bde357..9017161808 100644
--- a/tests/servers/urls.py
+++ b/tests/servers/urls.py
@@ -4,6 +4,7 @@ from . import views
urlpatterns = [
url(r'^example_view/$', views.example_view),
+ url(r'^streaming_example_view/$', views.streaming_example_view),
url(r'^model_view/$', views.model_view),
url(r'^create_model_instance/$', views.create_model_instance),
url(r'^environ_view/$', views.environ_view),
diff --git a/tests/servers/views.py b/tests/servers/views.py
index 3bae0834ab..078be67f46 100644
--- a/tests/servers/views.py
+++ b/tests/servers/views.py
@@ -1,6 +1,6 @@
from urllib.request import urlopen
-from django.http import HttpResponse
+from django.http import HttpResponse, StreamingHttpResponse
from .models import Person
@@ -9,6 +9,10 @@ def example_view(request):
return HttpResponse('example view')
+def streaming_example_view(request):
+ return StreamingHttpResponse((b'I', b'am', b'a', b'stream'))
+
+
def model_view(request):
people = Person.objects.all()
return HttpResponse('\n'.join(person.name for person in people))