diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/http_protocol_test.py | 814 |
1 files changed, 638 insertions, 176 deletions
diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index c095228..408bfc7 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -203,151 +203,6 @@ class HttpStreamReaderTests(LogTrackingTestCase): self.assertIn("limit request headers fields size", str(cm.exception)) - def test_read_payload_unknown_encoding(self): - self.assertRaises( - ValueError, self.stream.read_length_payload, encoding='unknown') - - def test_read_payload(self): - self.stream.feed_data(b'da') - self.stream.feed_data(b't') - self.stream.feed_data(b'ali') - self.stream.feed_data(b'ne') - - stream = self.stream.read_length_payload(4) - self.assertIsInstance(stream, tulip.StreamReader) - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'data', data) - self.assertEqual(b'line', b''.join(self.stream.buffer)) - - def test_read_payload_eof(self): - self.stream.feed_data(b'da') - self.stream.feed_eof() - stream = self.stream.read_length_payload(4) - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_payload_eof_exc(self): - self.stream.feed_data(b'da') - stream = self.stream.read_length_payload(4) - - def eof(): - yield from [] - self.stream.feed_eof() - - t1 = tulip.Task(stream.read()) - t2 = tulip.Task(eof()) - - self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) - self.assertRaises(http.client.IncompleteRead, t1.result) - self.assertIsNone(self.stream._reader) - - def test_read_payload_deflate(self): - comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - - data = b''.join([comp.compress(b'data'), comp.flush()]) - stream = self.stream.read_length_payload(len(data), encoding='deflate') - - self.stream.feed_data(data) - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'data', data) - - def _test_read_payload_compress_error(self): - data = b'123123123datadatadata' - reader = protocol.length_reader(4) - self.stream.feed_data(data) - stream = self.stream.read_payload(reader, 'deflate') - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_chunked_payload(self): - stream = self.stream.read_chunked_payload() - self.stream.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'dataline', data) - - def test_read_chunked_payload_chunks(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data(b'4\r\ndata\r') - self.stream.feed_data(b'\n4') - self.stream.feed_data(b'\r') - self.stream.feed_data(b'\n') - self.stream.feed_data(b'line\r\n0\r\n') - self.stream.feed_data(b'test\r\n\r\n') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'dataline', data) - - def test_read_chunked_payload_incomplete(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data(b'4\r\ndata\r\n') - self.stream.feed_eof() - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_chunked_payload_extension(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data( - b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'dataline', data) - - def test_read_chunked_payload_size_error(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data(b'blah\r\n') - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_length_payload(self): - stream = self.stream.read_length_payload(8) - - self.stream.feed_data(b'data') - self.stream.feed_data(b'data') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'datadata', data) - - def test_read_length_payload_zero(self): - stream = self.stream.read_length_payload(0) - - self.stream.feed_data(b'data') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'', data) - - def test_read_length_payload_incomplete(self): - stream = self.stream.read_length_payload(8) - - self.stream.feed_data(b'data') - self.stream.feed_eof() - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_eof_payload(self): - stream = self.stream.read_eof_payload() - - self.stream.feed_data(b'data') - self.stream.feed_eof() - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'data', data) - def test_read_message_should_close(self): self.stream.feed_data( b'Host: example.com\r\nConnection: close\r\n\r\n') @@ -477,7 +332,7 @@ class HttpStreamReaderTests(LogTrackingTestCase): payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) self.assertEqual(b'dataline', payload) - def test_read_message_readall(self): + def test_read_message_readall_eof(self): self.stream.feed_data( b'Host: example.com\r\n\r\n') self.stream.feed_data(b'data') @@ -490,46 +345,653 @@ class HttpStreamReaderTests(LogTrackingTestCase): payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) self.assertEqual(b'dataline', payload) + def test_read_message_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'datadata', data) + + def test_read_message_payload_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(msg.payload.read())) + + def test_read_message_length_payload_zero(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'', data) + + def test_read_message_length_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(coro())) -class HttpStreamWriterTests(unittest.TestCase): + def test_read_message_eof_payload(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'data', data) + + def test_read_message_length_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + self.assertIsInstance(msg.payload, tulip.StreamReader) + + data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_message_length_payload_extra(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_parse_length_payload_eof_exc(self): + parser = self.stream._parse_length_payload(4) + next(parser) + + stream = tulip.StreamReader() + parser.send(stream) + self.stream._parser = parser + self.stream.feed_data(b'da') + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._parser) + + def test_read_message_deflate_payload(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Encoding: deflate\r\n' + + ('Content-Length: %s\r\n\r\n' % len(data)).encode()) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + def coro(): + self.stream.feed_data(data) + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'data', data) + + def test_read_message_chunked_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data( + b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_chunks(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(coro())) + + def test_read_message_chunked_payload_extension(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_size_error(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'blah\r\n') + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(coro())) + + def test_deflate_stream_set_exception(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.set_exception(exc) + self.assertIs(exc, stream.exception()) + + def test_deflate_stream_feed_data(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.return_value = b'line' + + dstream.feed_data(b'data') + self.assertEqual([b'line'], list(stream.buffer)) + + def test_deflate_stream_feed_data_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.side_effect = exc + + dstream.feed_data(b'data') + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + def test_deflate_stream_feed_eof(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + + dstream.feed_eof() + self.assertEqual([b'line'], list(stream.buffer)) + self.assertTrue(stream.eof) + + def test_deflate_stream_feed_eof_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + dstream.zlib.eof = False + + dstream.feed_eof() + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + +class HttpMessageTests(unittest.TestCase): def setUp(self): self.transport = unittest.mock.Mock() - self.writer = protocol.HttpStreamWriter(self.transport) - def test_ctor(self): - transport = unittest.mock.Mock() - writer = protocol.HttpStreamWriter(transport, 'latin-1') - self.assertIs(writer.transport, transport) - self.assertEqual(writer.encoding, 'latin-1') + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) - def test_encode(self): - self.assertEqual(b'test', self.writer.encode('test')) - self.assertEqual(b'test', self.writer.encode(b'test')) + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) - def test_decode(self): - self.assertEqual('test', self.writer.decode('test')) - self.assertEqual('test', self.writer.decode(b'test')) + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) - def test_write(self): - self.writer.write(b'test') - self.assertTrue(self.transport.write.called) - self.assertEqual((b'test',), self.transport.write.call_args[0]) + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) - def test_write_str(self): - self.writer.write_str('test') - self.assertTrue(self.transport.write.called) - self.assertEqual((b'test',), self.transport.write.call_args[0]) + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) - def test_write_cunked(self): - self.writer.write_chunked('') - self.assertFalse(self.transport.write.called) + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) - self.writer.write_chunked('data') + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], msg.headers) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], msg.headers) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], msg.headers) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg.force_chunked() + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) self.assertEqual( - [(b'4\r\n',), (b'data',), (b'\r\n',)], - [c[0] for c in self.transport.write.call_args_list]) + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() - def test_write_eof(self): - self.writer.write_chunked_eof() - self.assertEqual((b'0\r\n\r\n',), self.transport.write.call_args[0]) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) |