summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/http_protocol_test.py814
1 files changed, 638 insertions, 176 deletions
diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py
index c095228..408bfc7 100644
--- a/tests/http_protocol_test.py
+++ b/tests/http_protocol_test.py
@@ -203,151 +203,6 @@ class HttpStreamReaderTests(LogTrackingTestCase):
self.assertIn("limit request headers fields size", str(cm.exception))
- def test_read_payload_unknown_encoding(self):
- self.assertRaises(
- ValueError, self.stream.read_length_payload, encoding='unknown')
-
- def test_read_payload(self):
- self.stream.feed_data(b'da')
- self.stream.feed_data(b't')
- self.stream.feed_data(b'ali')
- self.stream.feed_data(b'ne')
-
- stream = self.stream.read_length_payload(4)
- self.assertIsInstance(stream, tulip.StreamReader)
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'data', data)
- self.assertEqual(b'line', b''.join(self.stream.buffer))
-
- def test_read_payload_eof(self):
- self.stream.feed_data(b'da')
- self.stream.feed_eof()
- stream = self.stream.read_length_payload(4)
-
- self.assertRaises(
- http.client.IncompleteRead,
- self.loop.run_until_complete, tulip.Task(stream.read()))
-
- def test_read_payload_eof_exc(self):
- self.stream.feed_data(b'da')
- stream = self.stream.read_length_payload(4)
-
- def eof():
- yield from []
- self.stream.feed_eof()
-
- t1 = tulip.Task(stream.read())
- t2 = tulip.Task(eof())
-
- self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2])))
- self.assertRaises(http.client.IncompleteRead, t1.result)
- self.assertIsNone(self.stream._reader)
-
- def test_read_payload_deflate(self):
- comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
-
- data = b''.join([comp.compress(b'data'), comp.flush()])
- stream = self.stream.read_length_payload(len(data), encoding='deflate')
-
- self.stream.feed_data(data)
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'data', data)
-
- def _test_read_payload_compress_error(self):
- data = b'123123123datadatadata'
- reader = protocol.length_reader(4)
- self.stream.feed_data(data)
- stream = self.stream.read_payload(reader, 'deflate')
-
- self.assertRaises(
- http.client.IncompleteRead,
- self.loop.run_until_complete, tulip.Task(stream.read()))
-
- def test_read_chunked_payload(self):
- stream = self.stream.read_chunked_payload()
- self.stream.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n')
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'dataline', data)
-
- def test_read_chunked_payload_chunks(self):
- stream = self.stream.read_chunked_payload()
-
- self.stream.feed_data(b'4\r\ndata\r')
- self.stream.feed_data(b'\n4')
- self.stream.feed_data(b'\r')
- self.stream.feed_data(b'\n')
- self.stream.feed_data(b'line\r\n0\r\n')
- self.stream.feed_data(b'test\r\n\r\n')
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'dataline', data)
-
- def test_read_chunked_payload_incomplete(self):
- stream = self.stream.read_chunked_payload()
-
- self.stream.feed_data(b'4\r\ndata\r\n')
- self.stream.feed_eof()
-
- self.assertRaises(
- http.client.IncompleteRead,
- self.loop.run_until_complete, tulip.Task(stream.read()))
-
- def test_read_chunked_payload_extension(self):
- stream = self.stream.read_chunked_payload()
-
- self.stream.feed_data(
- b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n')
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'dataline', data)
-
- def test_read_chunked_payload_size_error(self):
- stream = self.stream.read_chunked_payload()
-
- self.stream.feed_data(b'blah\r\n')
- self.assertRaises(
- http.client.IncompleteRead,
- self.loop.run_until_complete, tulip.Task(stream.read()))
-
- def test_read_length_payload(self):
- stream = self.stream.read_length_payload(8)
-
- self.stream.feed_data(b'data')
- self.stream.feed_data(b'data')
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'datadata', data)
-
- def test_read_length_payload_zero(self):
- stream = self.stream.read_length_payload(0)
-
- self.stream.feed_data(b'data')
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'', data)
-
- def test_read_length_payload_incomplete(self):
- stream = self.stream.read_length_payload(8)
-
- self.stream.feed_data(b'data')
- self.stream.feed_eof()
-
- self.assertRaises(
- http.client.IncompleteRead,
- self.loop.run_until_complete, tulip.Task(stream.read()))
-
- def test_read_eof_payload(self):
- stream = self.stream.read_eof_payload()
-
- self.stream.feed_data(b'data')
- self.stream.feed_eof()
-
- data = self.loop.run_until_complete(tulip.Task(stream.read()))
- self.assertEqual(b'data', data)
-
def test_read_message_should_close(self):
self.stream.feed_data(
b'Host: example.com\r\nConnection: close\r\n\r\n')
@@ -477,7 +332,7 @@ class HttpStreamReaderTests(LogTrackingTestCase):
payload = self.loop.run_until_complete(tulip.Task(msg.payload.read()))
self.assertEqual(b'dataline', payload)
- def test_read_message_readall(self):
+ def test_read_message_readall_eof(self):
self.stream.feed_data(
b'Host: example.com\r\n\r\n')
self.stream.feed_data(b'data')
@@ -490,46 +345,653 @@ class HttpStreamReaderTests(LogTrackingTestCase):
payload = self.loop.run_until_complete(tulip.Task(msg.payload.read()))
self.assertEqual(b'dataline', payload)
+ def test_read_message_payload(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 8\r\n\r\n')
+ self.stream.feed_data(b'data')
+ self.stream.feed_data(b'data')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message(readall=True)))
+
+ data = self.loop.run_until_complete(tulip.Task(msg.payload.read()))
+ self.assertEqual(b'datadata', data)
+
+ def test_read_message_payload_eof(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 4\r\n\r\n')
+ self.stream.feed_data(b'da')
+ self.stream.feed_eof()
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message(readall=True)))
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, tulip.Task(msg.payload.read()))
+
+ def test_read_message_length_payload_zero(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 0\r\n\r\n')
+ self.stream.feed_data(b'data')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ data = self.loop.run_until_complete(tulip.Task(msg.payload.read()))
+ self.assertEqual(b'', data)
+
+ def test_read_message_length_payload_incomplete(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 8\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ def coro():
+ self.stream.feed_data(b'data')
+ self.stream.feed_eof()
+ return (yield from msg.payload.read())
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, tulip.Task(coro()))
-class HttpStreamWriterTests(unittest.TestCase):
+ def test_read_message_eof_payload(self):
+ self.stream.feed_data(b'Host: example.com\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message(readall=True)))
+
+ def coro():
+ self.stream.feed_data(b'data')
+ self.stream.feed_eof()
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(tulip.Task(coro()))
+ self.assertEqual(b'data', data)
+
+ def test_read_message_length_payload(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 4\r\n\r\n')
+ self.stream.feed_data(b'da')
+ self.stream.feed_data(b't')
+ self.stream.feed_data(b'ali')
+ self.stream.feed_data(b'ne')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message(readall=True)))
+
+ self.assertIsInstance(msg.payload, tulip.StreamReader)
+
+ data = self.loop.run_until_complete(tulip.Task(msg.payload.read()))
+ self.assertEqual(b'data', data)
+ self.assertEqual(b'line', b''.join(self.stream.buffer))
+
+ def test_read_message_length_payload_extra(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Length: 4\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ def coro():
+ self.stream.feed_data(b'da')
+ self.stream.feed_data(b't')
+ self.stream.feed_data(b'ali')
+ self.stream.feed_data(b'ne')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(tulip.Task(coro()))
+ self.assertEqual(b'data', data)
+ self.assertEqual(b'line', b''.join(self.stream.buffer))
+
+ def test_parse_length_payload_eof_exc(self):
+ parser = self.stream._parse_length_payload(4)
+ next(parser)
+
+ stream = tulip.StreamReader()
+ parser.send(stream)
+ self.stream._parser = parser
+ self.stream.feed_data(b'da')
+
+ def eof():
+ yield from []
+ self.stream.feed_eof()
+
+ t1 = tulip.Task(stream.read())
+ t2 = tulip.Task(eof())
+
+ self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2])))
+ self.assertRaises(http.client.IncompleteRead, t1.result)
+ self.assertIsNone(self.stream._parser)
+
+ def test_read_message_deflate_payload(self):
+ comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
+
+ data = b''.join([comp.compress(b'data'), comp.flush()])
+
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Content-Encoding: deflate\r\n' +
+ ('Content-Length: %s\r\n\r\n' % len(data)).encode())
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message(readall=True)))
+
+ def coro():
+ self.stream.feed_data(data)
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(tulip.Task(coro()))
+ self.assertEqual(b'data', data)
+
+ def test_read_message_chunked_payload(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ def coro():
+ self.stream.feed_data(
+ b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(tulip.Task(coro()))
+ self.assertEqual(b'dataline', data)
+
+ def test_read_message_chunked_payload_chunks(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ def coro():
+ self.stream.feed_data(b'4\r\ndata\r')
+ self.stream.feed_data(b'\n4')
+ self.stream.feed_data(b'\r')
+ self.stream.feed_data(b'\n')
+ self.stream.feed_data(b'line\r\n0\r\n')
+ self.stream.feed_data(b'test\r\n\r\n')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(tulip.Task(coro()))
+ self.assertEqual(b'dataline', data)
+
+ def test_read_message_chunked_payload_incomplete(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ def coro():
+ self.stream.feed_data(b'4\r\ndata\r\n')
+ self.stream.feed_eof()
+ return (yield from msg.payload.read())
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, tulip.Task(coro()))
+
+ def test_read_message_chunked_payload_extension(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ def coro():
+ self.stream.feed_data(
+ b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n')
+ return (yield from msg.payload.read())
+
+ data = self.loop.run_until_complete(tulip.Task(coro()))
+ self.assertEqual(b'dataline', data)
+
+ def test_read_message_chunked_payload_size_error(self):
+ self.stream.feed_data(
+ b'Host: example.com\r\n'
+ b'Transfer-Encoding: chunked\r\n\r\n')
+
+ msg = self.loop.run_until_complete(
+ tulip.Task(self.stream.read_message()))
+
+ def coro():
+ self.stream.feed_data(b'blah\r\n')
+ return (yield from msg.payload.read())
+
+ self.assertRaises(
+ http.client.IncompleteRead,
+ self.loop.run_until_complete, tulip.Task(coro()))
+
+ def test_deflate_stream_set_exception(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ exc = ValueError()
+ dstream.set_exception(exc)
+ self.assertIs(exc, stream.exception())
+
+ def test_deflate_stream_feed_data(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.decompress.return_value = b'line'
+
+ dstream.feed_data(b'data')
+ self.assertEqual([b'line'], list(stream.buffer))
+
+ def test_deflate_stream_feed_data_err(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ exc = ValueError()
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.decompress.side_effect = exc
+
+ dstream.feed_data(b'data')
+ self.assertIsInstance(stream.exception(), http.client.IncompleteRead)
+
+ def test_deflate_stream_feed_eof(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.flush.return_value = b'line'
+
+ dstream.feed_eof()
+ self.assertEqual([b'line'], list(stream.buffer))
+ self.assertTrue(stream.eof)
+
+ def test_deflate_stream_feed_eof_err(self):
+ stream = tulip.StreamReader()
+ dstream = protocol.DeflateStream(stream, 'deflate')
+
+ dstream.zlib = unittest.mock.Mock()
+ dstream.zlib.flush.return_value = b'line'
+ dstream.zlib.eof = False
+
+ dstream.feed_eof()
+ self.assertIsInstance(stream.exception(), http.client.IncompleteRead)
+
+
+class HttpMessageTests(unittest.TestCase):
def setUp(self):
self.transport = unittest.mock.Mock()
- self.writer = protocol.HttpStreamWriter(self.transport)
- def test_ctor(self):
- transport = unittest.mock.Mock()
- writer = protocol.HttpStreamWriter(transport, 'latin-1')
- self.assertIs(writer.transport, transport)
- self.assertEqual(writer.encoding, 'latin-1')
+ def test_start_request(self):
+ msg = protocol.Request(
+ self.transport, 'GET', '/index.html', close=True)
+
+ self.assertIs(msg.transport, self.transport)
+ self.assertIsNone(msg.status)
+ self.assertTrue(msg.closing)
+ self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n')
+
+ def test_start_response(self):
+ msg = protocol.Response(self.transport, 200, close=True)
+
+ self.assertIs(msg.transport, self.transport)
+ self.assertEqual(msg.status, 200)
+ self.assertTrue(msg.closing)
+ self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n')
+
+ def test_force_close(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.closing)
+ msg.force_close()
+ self.assertTrue(msg.closing)
+
+ def test_force_chunked(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.chunked)
+ msg.force_chunked()
+ self.assertTrue(msg.chunked)
+
+ def test_keep_alive(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.keep_alive())
+ msg.keepalive = True
+ self.assertTrue(msg.keep_alive())
+
+ msg.force_close()
+ self.assertFalse(msg.keep_alive())
+
+ def test_add_header(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertEqual([], msg.headers)
- def test_encode(self):
- self.assertEqual(b'test', self.writer.encode('test'))
- self.assertEqual(b'test', self.writer.encode(b'test'))
+ msg.add_header('content-type', 'plain/html')
+ self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers)
- def test_decode(self):
- self.assertEqual('test', self.writer.decode('test'))
- self.assertEqual('test', self.writer.decode(b'test'))
+ def test_add_headers(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertEqual([], msg.headers)
- def test_write(self):
- self.writer.write(b'test')
- self.assertTrue(self.transport.write.called)
- self.assertEqual((b'test',), self.transport.write.call_args[0])
+ msg.add_headers(('content-type', 'plain/html'))
+ self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers)
- def test_write_str(self):
- self.writer.write_str('test')
- self.assertTrue(self.transport.write.called)
- self.assertEqual((b'test',), self.transport.write.call_args[0])
+ def test_add_headers_length(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertIsNone(msg.length)
- def test_write_cunked(self):
- self.writer.write_chunked('')
- self.assertFalse(self.transport.write.called)
+ msg.add_headers(('content-length', '200'))
+ self.assertEqual(200, msg.length)
- self.writer.write_chunked('data')
+ def test_add_headers_upgrade(self):
+ msg = protocol.Response(self.transport, 200)
+ self.assertFalse(msg.upgrade)
+
+ msg.add_headers(('connection', 'upgrade'))
+ self.assertTrue(msg.upgrade)
+
+ def test_add_headers_upgrade_websocket(self):
+ msg = protocol.Response(self.transport, 200)
+
+ msg.add_headers(('upgrade', 'test'))
+ self.assertEqual([], msg.headers)
+
+ msg.add_headers(('upgrade', 'websocket'))
+ self.assertEqual([('UPGRADE', 'websocket')], msg.headers)
+
+ def test_add_headers_connection_keepalive(self):
+ msg = protocol.Response(self.transport, 200)
+
+ msg.add_headers(('connection', 'keep-alive'))
+ self.assertEqual([], msg.headers)
+ self.assertTrue(msg.keepalive)
+
+ msg.add_headers(('connection', 'close'))
+ self.assertFalse(msg.keepalive)
+
+ def test_add_headers_hop_headers(self):
+ msg = protocol.Response(self.transport, 200)
+
+ msg.add_headers(('connection', 'test'), ('transfer-encoding', 't'))
+ self.assertEqual([], msg.headers)
+
+ def test_default_headers(self):
+ msg = protocol.Response(self.transport, 200)
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertIn('DATE', headers)
+ self.assertIn('CONNECTION', headers)
+
+ def test_default_headers_server(self):
+ msg = protocol.Response(self.transport, 200)
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertIn('SERVER', headers)
+
+ def test_default_headers_useragent(self):
+ msg = protocol.Request(self.transport, 'GET', '/')
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertNotIn('SERVER', headers)
+ self.assertIn('USER-AGENT', headers)
+
+ def test_default_headers_chunked(self):
+ msg = protocol.Response(self.transport, 200)
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertNotIn('TRANSFER-ENCODING', headers)
+
+ msg.force_chunked()
+
+ headers = [r for r, _ in msg._default_headers()]
+ self.assertIn('TRANSFER-ENCODING', headers)
+
+ def test_default_headers_connection_upgrade(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.upgrade = True
+
+ headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION']
+ self.assertEqual([('CONNECTION', 'upgrade')], headers)
+
+ def test_default_headers_connection_close(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.force_close()
+
+ headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION']
+ self.assertEqual([('CONNECTION', 'close')], headers)
+
+ def test_default_headers_connection_keep_alive(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.keepalive = True
+
+ headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION']
+ self.assertEqual([('CONNECTION', 'keep-alive')], headers)
+
+ def test_send_headers(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-type', 'plain/html'))
+ self.assertFalse(msg.is_headers_sent())
+
+ msg.send_headers()
+
+ content = b''.join([arg[1][0] for arg in list(write.mock_calls)])
+
+ self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n'))
+ self.assertIn(b'CONTENT-TYPE: plain/html', content)
+ self.assertTrue(msg.headers_sent)
+ self.assertTrue(msg.is_headers_sent())
+
+ def test_send_headers_nomore_add(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-type', 'plain/html'))
+ msg.send_headers()
+
+ self.assertRaises(AssertionError,
+ msg.add_header, 'content-type', 'plain/html')
+
+ def test_prepare_length(self):
+ msg = protocol.Response(self.transport, 200)
+ length = msg._write_length_payload = unittest.mock.Mock()
+ length.return_value = iter([1, 2, 3])
+
+ msg.add_headers(('content-length', '200'))
+ msg.send_headers()
+
+ self.assertTrue(length.called)
+ self.assertTrue((200,), length.call_args[0])
+
+ def test_prepare_chunked_force(self):
+ msg = protocol.Response(self.transport, 200)
+ msg.force_chunked()
+
+ chunked = msg._write_chunked_payload = unittest.mock.Mock()
+ chunked.return_value = iter([1, 2, 3])
+
+ msg.add_headers(('content-length', '200'))
+ msg.send_headers()
+ self.assertTrue(chunked.called)
+
+ def test_prepare_chunked_no_length(self):
+ msg = protocol.Response(self.transport, 200)
+
+ chunked = msg._write_chunked_payload = unittest.mock.Mock()
+ chunked.return_value = iter([1, 2, 3])
+
+ msg.send_headers()
+ self.assertTrue(chunked.called)
+
+ def test_prepare_eof(self):
+ msg = protocol.Response(self.transport, 200, http_version=(1, 0))
+
+ eof = msg._write_eof_payload = unittest.mock.Mock()
+ eof.return_value = iter([1, 2, 3])
+
+ msg.send_headers()
+ self.assertTrue(eof.called)
+
+ def test_write_auto_send_headers(self):
+ msg = protocol.Response(self.transport, 200, http_version=(1, 0))
+ msg._send_headers = True
+
+ msg.write(b'data1')
+ self.assertTrue(msg.headers_sent)
+
+ def test_write_payload_eof(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200, http_version=(1, 0))
+ msg.send_headers()
+
+ msg.write(b'data1')
+ self.assertTrue(msg.headers_sent)
+
+ msg.write(b'data2')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'data1data2', content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_chunked(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.force_chunked()
+ msg.send_headers()
+
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'4\r\ndata\r\n0\r\n\r\n',
+ content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_chunked_multiple(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.force_chunked()
+ msg.send_headers()
+
+ msg.write(b'data1')
+ msg.write(b'data2')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n',
+ content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_length(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-length', '2'))
+ msg.send_headers()
+
+ msg.write(b'd')
+ msg.write(b'ata')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ b'da', content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_chunked_filter(self):
+ write = self.transport.write = unittest.mock.Mock()
+
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_chunking_filter(2)
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n'))
+
+ def test_write_payload_chunked_filter_mutiple_chunks(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_chunking_filter(2)
+ msg.write(b'data1')
+ msg.write(b'data2')
+ msg.write_eof()
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertTrue(content.endswith(
+ b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n'
+ b'2\r\na2\r\n0\r\n\r\n'))
+
+ def test_write_payload_chunked_large_chunk(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_chunking_filter(1024)
+ msg.write(b'data')
+ msg.write_eof()
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n'))
+
+ _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS)
+ _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()])
+
+ def test_write_payload_deflate_filter(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-length', '%s' % len(self._COMPRESSED)))
+ msg.send_headers()
+
+ msg.add_compression_filter('deflate')
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_deflate_and_chunked(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.send_headers()
+
+ msg.add_compression_filter('deflate')
+ msg.add_chunking_filter(2)
+
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
self.assertEqual(
- [(b'4\r\n',), (b'data',), (b'\r\n',)],
- [c[0] for c in self.transport.write.call_args_list])
+ b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n',
+ content.split(b'\r\n\r\n', 1)[-1])
+
+ def test_write_payload_chunked_and_deflate(self):
+ write = self.transport.write = unittest.mock.Mock()
+ msg = protocol.Response(self.transport, 200)
+ msg.add_headers(('content-length', '%s' % len(self._COMPRESSED)))
+
+ msg.add_chunking_filter(2)
+ msg.add_compression_filter('deflate')
+ msg.send_headers()
- def test_write_eof(self):
- self.writer.write_chunked_eof()
- self.assertEqual((b'0\r\n\r\n',), self.transport.write.call_args[0])
+ msg.write(b'data')
+ msg.write_eof()
+
+ content = b''.join([c[1][0] for c in list(write.mock_calls)])
+ self.assertEqual(
+ self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1])