diff options
author | Andrew Svetlov <andrew.svetlov@gmail.com> | 2018-09-12 11:43:04 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-12 11:43:04 -0700 |
commit | a5d1eb8d8b7add31b5f5d9bbb31cee1a491b2c08 (patch) | |
tree | 8ffce2f8bcedaea78a0f0eb9c7e1c25f0a32707a /Lib/test/test_asyncio/test_streams.py | |
parent | aca819fb494d4801b3e5b5b507b17cab772c1b40 (diff) | |
download | cpython-git-a5d1eb8d8b7add31b5f5d9bbb31cee1a491b2c08.tar.gz |
bpo-34638: Store a weak reference to stream reader to break strong references loop (GH-9201)
Store a weak reference to stream readerfor breaking strong references
It breaks the strong reference loop between reader and protocol and allows to detect and close the socket if the stream is deleted (garbage collected)
Diffstat (limited to 'Lib/test/test_asyncio/test_streams.py')
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 66d18738b3..67ac9d91a0 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -46,6 +46,8 @@ class StreamTests(test_utils.TestCase): self.assertIs(stream._loop, m_events.get_event_loop.return_value) def _basetest_open_connection(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) reader, writer = self.loop.run_until_complete(open_connection_fut) writer.write(b'GET / HTTP/1.0\r\n\r\n') f = reader.readline() @@ -55,6 +57,7 @@ class StreamTests(test_utils.TestCase): data = self.loop.run_until_complete(f) self.assertTrue(data.endswith(b'\r\n\r\nTest message')) writer.close() + self.assertEqual(messages, []) def test_open_connection(self): with test_utils.run_test_server() as httpd: @@ -70,6 +73,8 @@ class StreamTests(test_utils.TestCase): self._basetest_open_connection(conn_fut) def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) try: reader, writer = self.loop.run_until_complete(open_connection_fut) finally: @@ -80,6 +85,7 @@ class StreamTests(test_utils.TestCase): self.assertTrue(data.endswith(b'\r\n\r\nTest message')) writer.close() + self.assertEqual(messages, []) @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): @@ -104,6 +110,8 @@ class StreamTests(test_utils.TestCase): self._basetest_open_connection_no_loop_ssl(conn_fut) def _basetest_open_connection_error(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) reader, writer = self.loop.run_until_complete(open_connection_fut) writer._protocol.connection_lost(ZeroDivisionError()) f = reader.read() @@ -111,6 +119,7 @@ class StreamTests(test_utils.TestCase): self.loop.run_until_complete(f) writer.close() test_utils.run_briefly(self.loop) + self.assertEqual(messages, []) def test_open_connection_error(self): with test_utils.run_test_server() as httpd: @@ -621,6 +630,9 @@ class StreamTests(test_utils.TestCase): writer.close() return msgback + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + # test the server variant with a coroutine as client handler server = MyServer(self.loop) addr = server.start() @@ -637,6 +649,8 @@ class StreamTests(test_utils.TestCase): server.stop() self.assertEqual(msg, b"hello world!\n") + self.assertEqual(messages, []) + @support.skip_unless_bind_unix_socket def test_start_unix_server(self): @@ -685,6 +699,9 @@ class StreamTests(test_utils.TestCase): writer.close() return msgback + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + # test the server variant with a coroutine as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) @@ -703,6 +720,8 @@ class StreamTests(test_utils.TestCase): server.stop() self.assertEqual(msg, b"hello world!\n") + self.assertEqual(messages, []) + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example @@ -893,6 +912,58 @@ os.close(fd) wr.close() self.loop.run_until_complete(wr.wait_closed()) + def test_del_stream_before_sock_closing(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) + sock = wr.get_extra_info('socket') + self.assertNotEqual(sock.fileno(), -1) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + f = rd.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + + # drop refs to reader/writer + del rd + del wr + gc.collect() + # make a chance to close the socket + test_utils.run_briefly(self.loop) + + self.assertEqual(1, len(messages)) + self.assertEqual(sock.fileno(), -1) + + self.assertEqual(1, len(messages)) + self.assertEqual('An open stream object is being garbage ' + 'collected; call "stream.close()" explicitly.', + messages[0]['message']) + + def test_del_stream_before_connection_made(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + rd = asyncio.StreamReader(loop=self.loop) + pr = asyncio.StreamReaderProtocol(rd, loop=self.loop) + del rd + gc.collect() + tr, _ = self.loop.run_until_complete( + self.loop.create_connection( + lambda: pr, *httpd.address)) + + sock = tr.get_extra_info('socket') + self.assertEqual(sock.fileno(), -1) + + self.assertEqual(1, len(messages)) + self.assertEqual('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.', + messages[0]['message']) + if __name__ == '__main__': unittest.main() |