import asyncio import asyncio.sslproto import contextlib import gc import logging import os import select import socket import ssl import tempfile import threading import time import weakref from test import support from test.test_asyncio import utils as test_utils def tearDownModule(): asyncio.set_event_loop_policy(None) class MyBaseProto(asyncio.Protocol): connected = None done = None def __init__(self, loop=None): self.transport = None self.state = 'INITIAL' self.nbytes = 0 if loop is not None: self.connected = asyncio.Future(loop=loop) self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' if self.connected: self.connected.set_result(None) def data_received(self, data): assert self.state == 'CONNECTED', self.state self.nbytes += len(data) def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' if self.done: self.done.set_result(None) class TestSSL(test_utils.TestCase): PAYLOAD_SIZE = 1024 * 100 TIMEOUT = 60 def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) self.addCleanup(self.loop.close) def tearDown(self): # just in case if we have transport close callbacks if not self.loop.is_closed(): test_utils.run_briefly(self.loop) self.doCleanups() support.gc_collect() super().tearDown() def tcp_server(self, server_prog, *, family=socket.AF_INET, addr=None, timeout=5, backlog=1, max_clients=10): if addr is None: if family == getattr(socket, "AF_UNIX", None): with tempfile.NamedTemporaryFile() as tmp: addr = tmp.name else: addr = ('127.0.0.1', 0) sock = socket.socket(family, socket.SOCK_STREAM) if timeout is None: raise RuntimeError('timeout is required') if timeout <= 0: raise RuntimeError('only blocking sockets are supported') sock.settimeout(timeout) try: sock.bind(addr) sock.listen(backlog) except OSError as ex: sock.close() raise ex return TestThreadedServer( self, sock, server_prog, timeout, max_clients) def tcp_client(self, client_prog, family=socket.AF_INET, timeout=10): sock = socket.socket(family, socket.SOCK_STREAM) if timeout is None: raise RuntimeError('timeout is required') if timeout <= 0: raise RuntimeError('only blocking sockets are supported') sock.settimeout(timeout) return TestThreadedClient( self, sock, client_prog, timeout) def unix_server(self, *args, **kwargs): return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) def unix_client(self, *args, **kwargs): return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) def _create_server_ssl_context(self, certfile, keyfile=None): sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.load_cert_chain(certfile, keyfile) return sslcontext def _create_client_ssl_context(self, *, disable_verify=True): sslcontext = ssl.create_default_context() sslcontext.check_hostname = False if disable_verify: sslcontext.verify_mode = ssl.CERT_NONE return sslcontext @contextlib.contextmanager def _silence_eof_received_warning(self): # TODO This warning has to be fixed in asyncio. logger = logging.getLogger('asyncio') filter = logging.Filter('has no effect when using ssl') logger.addFilter(filter) try: yield finally: logger.removeFilter(filter) def _abort_socket_test(self, ex): try: self.loop.stop() finally: self.fail(ex) def new_loop(self): return asyncio.new_event_loop() def new_policy(self): return asyncio.DefaultEventLoopPolicy() async def wait_closed(self, obj): if not isinstance(obj, asyncio.StreamWriter): return try: await obj.wait_closed() except (BrokenPipeError, ConnectionError): pass def test_create_server_ssl_1(self): CNT = 0 # number of clients that were successful TOTAL_CNT = 25 # total number of clients that test will create TIMEOUT = 10.0 # timeout for this test A_DATA = b'A' * 1024 * 1024 B_DATA = b'B' * 1024 * 1024 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context() clients = [] async def handle_client(reader, writer): nonlocal CNT data = await reader.readexactly(len(A_DATA)) self.assertEqual(data, A_DATA) writer.write(b'OK') data = await reader.readexactly(len(B_DATA)) self.assertEqual(data, B_DATA) writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) await writer.drain() writer.close() CNT += 1 async def test_client(addr): fut = asyncio.Future() def prog(sock): try: sock.starttls(client_sslctx) sock.connect(addr) sock.send(A_DATA) data = sock.recv_all(2) self.assertEqual(data, b'OK') sock.send(B_DATA) data = sock.recv_all(4) self.assertEqual(data, b'SPAM') sock.close() except Exception as ex: self.loop.call_soon_threadsafe(fut.set_exception, ex) else: self.loop.call_soon_threadsafe(fut.set_result, None) client = self.tcp_client(prog) client.start() clients.append(client) await fut async def start_server(): extras = {} extras = dict(ssl_handshake_timeout=10.0) srv = await asyncio.start_server( handle_client, '127.0.0.1', 0, family=socket.AF_INET, ssl=sslctx, **extras) try: srv_socks = srv.sockets self.assertTrue(srv_socks) addr = srv_socks[0].getsockname() tasks = [] for _ in range(TOTAL_CNT): tasks.append(test_client(addr)) await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) finally: self.loop.call_soon(srv.close) await srv.wait_closed() with self._silence_eof_received_warning(): self.loop.run_until_complete(start_server()) self.assertEqual(CNT, TOTAL_CNT) for client in clients: client.stop() def test_create_connection_ssl_1(self): self.loop.set_exception_handler(None) CNT = 0 TOTAL_CNT = 25 A_DATA = b'A' * 1024 * 1024 B_DATA = b'B' * 1024 * 1024 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context() def server(sock): sock.starttls( sslctx, server_side=True) data = sock.recv_all(len(A_DATA)) self.assertEqual(data, A_DATA) sock.send(b'OK') data = sock.recv_all(len(B_DATA)) self.assertEqual(data, B_DATA) sock.send(b'SPAM') sock.close() async def client(addr): extras = {} extras = dict(ssl_handshake_timeout=10.0) reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', **extras) writer.write(A_DATA) self.assertEqual(await reader.readexactly(2), b'OK') writer.write(B_DATA) self.assertEqual(await reader.readexactly(4), b'SPAM') nonlocal CNT CNT += 1 writer.close() await self.wait_closed(writer) async def client_sock(addr): sock = socket.socket() sock.connect(addr) reader, writer = await asyncio.open_connection( sock=sock, ssl=client_sslctx, server_hostname='') writer.write(A_DATA) self.assertEqual(await reader.readexactly(2), b'OK') writer.write(B_DATA) self.assertEqual(await reader.readexactly(4), b'SPAM') nonlocal CNT CNT += 1 writer.close() await self.wait_closed(writer) sock.close() def run(coro): nonlocal CNT CNT = 0 async def _gather(*tasks): # trampoline return await asyncio.gather(*tasks) with self.tcp_server(server, max_clients=TOTAL_CNT, backlog=TOTAL_CNT) as srv: tasks = [] for _ in range(TOTAL_CNT): tasks.append(coro(srv.addr)) self.loop.run_until_complete(_gather(*tasks)) self.assertEqual(CNT, TOTAL_CNT) with self._silence_eof_received_warning(): run(client) with self._silence_eof_received_warning(): run(client_sock) def test_create_connection_ssl_slow_handshake(self): client_sslctx = self._create_client_ssl_context() # silence error logger self.loop.set_exception_handler(lambda *args: None) def server(sock): try: sock.recv_all(1024 * 1024) except ConnectionAbortedError: pass finally: sock.close() async def client(addr): reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', ssl_handshake_timeout=1.0) writer.close() await self.wait_closed(writer) with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaisesRegex( ConnectionAbortedError, r'SSL handshake.*is taking longer'): self.loop.run_until_complete(client(srv.addr)) def test_create_connection_ssl_failed_certificate(self): # silence error logger self.loop.set_exception_handler(lambda *args: None) sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context(disable_verify=False) def server(sock): try: sock.starttls( sslctx, server_side=True) sock.connect() except (ssl.SSLError, OSError): pass finally: sock.close() async def client(addr): reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', ssl_handshake_timeout=1.0) writer.close() await self.wait_closed(writer) with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaises(ssl.SSLCertVerificationError): self.loop.run_until_complete(client(srv.addr)) def test_ssl_handshake_timeout(self): # bpo-29970: Check that a connection is aborted if handshake is not # completed in timeout period, instead of remaining open indefinitely client_sslctx = test_utils.simple_client_sslcontext() # silence error logger messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) server_side_aborted = False def server(sock): nonlocal server_side_aborted try: sock.recv_all(1024 * 1024) except ConnectionAbortedError: server_side_aborted = True finally: sock.close() async def client(addr): await asyncio.wait_for( self.loop.create_connection( asyncio.Protocol, *addr, ssl=client_sslctx, server_hostname='', ssl_handshake_timeout=10.0), 0.5) with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(client(srv.addr)) self.assertTrue(server_side_aborted) # Python issue #23197: cancelling a handshake must not raise an # exception or log an error, even if the handshake failed self.assertEqual(messages, []) def test_ssl_handshake_connection_lost(self): # #246: make sure that no connection_lost() is called before # connection_made() is called first client_sslctx = test_utils.simple_client_sslcontext() # silence error logger self.loop.set_exception_handler(lambda loop, ctx: None) connection_made_called = False connection_lost_called = False def server(sock): sock.recv(1024) # break the connection during handshake sock.close() class ClientProto(asyncio.Protocol): def connection_made(self, transport): nonlocal connection_made_called connection_made_called = True def connection_lost(self, exc): nonlocal connection_lost_called connection_lost_called = True async def client(addr): await self.loop.create_connection( ClientProto, *addr, ssl=client_sslctx, server_hostname=''), with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaises(ConnectionResetError): self.loop.run_until_complete(client(srv.addr)) if connection_lost_called: if connection_made_called: self.fail("unexpected call to connection_lost()") else: self.fail("unexpected call to connection_lost() without" "calling connection_made()") elif connection_made_called: self.fail("unexpected call to connection_made()") def test_ssl_connect_accepted_socket(self): proto = ssl.PROTOCOL_TLS_SERVER server_context = ssl.SSLContext(proto) server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY) if hasattr(server_context, 'check_hostname'): server_context.check_hostname = False server_context.verify_mode = ssl.CERT_NONE client_context = ssl.SSLContext(proto) if hasattr(server_context, 'check_hostname'): client_context.check_hostname = False client_context.verify_mode = ssl.CERT_NONE def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): loop = self.loop class MyProto(MyBaseProto): def connection_lost(self, exc): super().connection_lost(exc) loop.call_soon(loop.stop) def data_received(self, data): super().data_received(data) self.transport.write(expected_response) lsock = socket.socket(socket.AF_INET) lsock.bind(('127.0.0.1', 0)) lsock.listen(1) addr = lsock.getsockname() message = b'test data' response = None expected_response = b'roger' def client(): nonlocal response try: csock = socket.socket(socket.AF_INET) if client_ssl is not None: csock = client_ssl.wrap_socket(csock) csock.connect(addr) csock.sendall(message) response = csock.recv(99) csock.close() except Exception as exc: print( "Failure in client thread in test_connect_accepted_socket", exc) thread = threading.Thread(target=client, daemon=True) thread.start() conn, _ = lsock.accept() proto = MyProto(loop=loop) proto.loop = loop extras = {} if server_ssl: extras = dict(ssl_handshake_timeout=10.0) f = loop.create_task( loop.connect_accepted_socket( (lambda: proto), conn, ssl=server_ssl, **extras)) loop.run_forever() conn.close() lsock.close() thread.join(1) self.assertFalse(thread.is_alive()) self.assertEqual(proto.state, 'CLOSED') self.assertEqual(proto.nbytes, len(message)) self.assertEqual(response, expected_response) tr, _ = f.result() if server_ssl: self.assertIn('SSL', tr.__class__.__name__) tr.close() # let it close self.loop.run_until_complete(asyncio.sleep(0.1)) def test_start_tls_client_corrupted_ssl(self): self.loop.set_exception_handler(lambda loop, ctx: None) sslctx = test_utils.simple_server_sslcontext() client_sslctx = test_utils.simple_client_sslcontext() def server(sock): orig_sock = sock.dup() try: sock.starttls( sslctx, server_side=True) sock.sendall(b'A\n') sock.recv_all(1) orig_sock.send(b'please corrupt the SSL connection') except ssl.SSLError: pass finally: sock.close() orig_sock.close() async def client(addr): reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='') self.assertEqual(await reader.readline(), b'A\n') writer.write(b'B') with self.assertRaises(ssl.SSLError): await reader.readline() writer.close() try: await self.wait_closed(writer) except ssl.SSLError: pass return 'OK' with self.tcp_server(server, max_clients=1, backlog=1) as srv: res = self.loop.run_until_complete(client(srv.addr)) self.assertEqual(res, 'OK') def test_start_tls_client_reg_proto_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() def serve(sock): sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.starttls(server_context, server_side=True) sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.unwrap() sock.close() class ClientProto(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(proto, tr): proto.con_made_cnt += 1 # Ensure connection_made gets called only once. self.assertEqual(proto.con_made_cnt, 1) def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProto(on_data, on_eof), *addr) tr.write(HELLO_MSG) new_tr = await self.loop.start_tls(tr, proto, client_context) self.assertEqual(await on_data, b'O') new_tr.write(HELLO_MSG) await on_eof new_tr.close() with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=10)) def test_create_connection_memory_leak(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_context = self._create_client_ssl_context() def serve(sock): sock.settimeout(self.TIMEOUT) sock.starttls(server_context, server_side=True) sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.unwrap() sock.close() class ClientProto(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(proto, tr): # XXX: We assume user stores the transport in protocol proto.tr = tr proto.con_made_cnt += 1 # Ensure connection_made gets called only once. self.assertEqual(proto.con_made_cnt, 1) def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProto(on_data, on_eof), *addr, ssl=client_context) self.assertEqual(await on_data, b'O') tr.write(HELLO_MSG) await on_eof tr.close() with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=10)) # No garbage is left for SSL client from loop.create_connection, even # if user stores the SSLTransport in corresponding protocol instance client_context = weakref.ref(client_context) self.assertIsNone(client_context()) def test_start_tls_client_buf_proto_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() client_con_made_calls = 0 def serve(sock): sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.starttls(server_context, server_side=True) sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.sendall(b'2') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.unwrap() sock.close() class ClientProtoFirst(asyncio.BufferedProtocol): def __init__(self, on_data): self.on_data = on_data self.buf = bytearray(1) def connection_made(self, tr): nonlocal client_con_made_calls client_con_made_calls += 1 def get_buffer(self, sizehint): return self.buf def buffer_updated(self, nsize): assert nsize == 1 self.on_data.set_result(bytes(self.buf[:nsize])) def eof_received(self): pass class ClientProtoSecond(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(self, tr): nonlocal client_con_made_calls client_con_made_calls += 1 def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data1 = self.loop.create_future() on_data2 = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProtoFirst(on_data1), *addr) tr.write(HELLO_MSG) new_tr = await self.loop.start_tls(tr, proto, client_context) self.assertEqual(await on_data1, b'O') new_tr.write(HELLO_MSG) new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) self.assertEqual(await on_data2, b'2') new_tr.write(HELLO_MSG) await on_eof new_tr.close() # connection_made() should be called only once -- when # we establish connection for the first time. Start TLS # doesn't call connection_made() on application protocols. self.assertEqual(client_con_made_calls, 1) with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT)) def test_start_tls_slow_client_cancel(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE client_context = test_utils.simple_client_sslcontext() server_waits_on_handshake = self.loop.create_future() def serve(sock): sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) try: self.loop.call_soon_threadsafe( server_waits_on_handshake.set_result, None) data = sock.recv_all(1024 * 1024) except ConnectionAbortedError: pass finally: sock.close() class ClientProto(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(proto, tr): proto.con_made_cnt += 1 # Ensure connection_made gets called only once. self.assertEqual(proto.con_made_cnt, 1) def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProto(on_data, on_eof), *addr) tr.write(HELLO_MSG) await server_waits_on_handshake with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( self.loop.start_tls(tr, proto, client_context), 0.5) with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=10)) def test_start_tls_server_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() def client(sock, addr): sock.settimeout(self.TIMEOUT) sock.connect(addr) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.starttls(client_context) sock.sendall(HELLO_MSG) sock.unwrap() sock.close() class ServerProto(asyncio.Protocol): def __init__(self, on_con, on_eof, on_con_lost): self.on_con = on_con self.on_eof = on_eof self.on_con_lost = on_con_lost self.data = b'' def connection_made(self, tr): self.on_con.set_result(tr) def data_received(self, data): self.data += data def eof_received(self): self.on_eof.set_result(1) def connection_lost(self, exc): if exc is None: self.on_con_lost.set_result(None) else: self.on_con_lost.set_exception(exc) async def main(proto, on_con, on_eof, on_con_lost): tr = await on_con tr.write(HELLO_MSG) self.assertEqual(proto.data, b'') new_tr = await self.loop.start_tls( tr, proto, server_context, server_side=True, ssl_handshake_timeout=self.TIMEOUT) await on_eof await on_con_lost self.assertEqual(proto.data, HELLO_MSG) new_tr.close() async def run_main(): on_con = self.loop.create_future() on_eof = self.loop.create_future() on_con_lost = self.loop.create_future() proto = ServerProto(on_con, on_eof, on_con_lost) server = await self.loop.create_server( lambda: proto, '127.0.0.1', 0) addr = server.sockets[0].getsockname() with self.tcp_client(lambda sock: client(sock, addr), timeout=self.TIMEOUT): await asyncio.wait_for( main(proto, on_con, on_eof, on_con_lost), timeout=self.TIMEOUT) server.close() await server.wait_closed() self.loop.run_until_complete(run_main()) def test_create_server_ssl_over_ssl(self): CNT = 0 # number of clients that were successful TOTAL_CNT = 25 # total number of clients that test will create TIMEOUT = 10.0 # timeout for this test A_DATA = b'A' * 1024 * 1024 B_DATA = b'B' * 1024 * 1024 sslctx_1 = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx_1 = self._create_client_ssl_context() sslctx_2 = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx_2 = self._create_client_ssl_context() clients = [] async def handle_client(reader, writer): nonlocal CNT data = await reader.readexactly(len(A_DATA)) self.assertEqual(data, A_DATA) writer.write(b'OK') data = await reader.readexactly(len(B_DATA)) self.assertEqual(data, B_DATA) writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) await writer.drain() writer.close() CNT += 1 class ServerProtocol(asyncio.StreamReaderProtocol): def connection_made(self, transport): super_ = super() transport.pause_reading() fut = self._loop.create_task(self._loop.start_tls( transport, self, sslctx_2, server_side=True)) def cb(_): try: tr = fut.result() except Exception as ex: super_.connection_lost(ex) else: super_.connection_made(tr) fut.add_done_callback(cb) def server_protocol_factory(): reader = asyncio.StreamReader() protocol = ServerProtocol(reader, handle_client) return protocol async def test_client(addr): fut = asyncio.Future() def prog(sock): try: sock.connect(addr) sock.starttls(client_sslctx_1) # because wrap_socket() doesn't work correctly on # SSLSocket, we have to do the 2nd level SSL manually incoming = ssl.MemoryBIO() outgoing = ssl.MemoryBIO() sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) def do(func, *args): while True: try: rv = func(*args) break except ssl.SSLWantReadError: if outgoing.pending: sock.send(outgoing.read()) incoming.write(sock.recv(65536)) if outgoing.pending: sock.send(outgoing.read()) return rv do(sslobj.do_handshake) do(sslobj.write, A_DATA) data = do(sslobj.read, 2) self.assertEqual(data, b'OK') do(sslobj.write, B_DATA) data = b'' while True: chunk = do(sslobj.read, 4) if not chunk: break data += chunk self.assertEqual(data, b'SPAM') do(sslobj.unwrap) sock.close() except Exception as ex: self.loop.call_soon_threadsafe(fut.set_exception, ex) sock.close() else: self.loop.call_soon_threadsafe(fut.set_result, None) client = self.tcp_client(prog) client.start() clients.append(client) await fut async def start_server(): extras = {} srv = await self.loop.create_server( server_protocol_factory, '127.0.0.1', 0, family=socket.AF_INET, ssl=sslctx_1, **extras) try: srv_socks = srv.sockets self.assertTrue(srv_socks) addr = srv_socks[0].getsockname() tasks = [] for _ in range(TOTAL_CNT): tasks.append(test_client(addr)) await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) finally: self.loop.call_soon(srv.close) await srv.wait_closed() with self._silence_eof_received_warning(): self.loop.run_until_complete(start_server()) self.assertEqual(CNT, TOTAL_CNT) for client in clients: client.stop() def test_shutdown_cleanly(self): CNT = 0 TOTAL_CNT = 25 A_DATA = b'A' * 1024 * 1024 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx = self._create_client_ssl_context() def server(sock): sock.starttls( sslctx, server_side=True) data = sock.recv_all(len(A_DATA)) self.assertEqual(data, A_DATA) sock.send(b'OK') sock.unwrap() sock.close() async def client(addr): extras = {} extras = dict(ssl_handshake_timeout=10.0) reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', **extras) writer.write(A_DATA) self.assertEqual(await reader.readexactly(2), b'OK') self.assertEqual(await reader.read(), b'') nonlocal CNT CNT += 1 writer.close() await self.wait_closed(writer) def run(coro): nonlocal CNT CNT = 0 async def _gather(*tasks): return await asyncio.gather(*tasks) with self.tcp_server(server, max_clients=TOTAL_CNT, backlog=TOTAL_CNT) as srv: tasks = [] for _ in range(TOTAL_CNT): tasks.append(coro(srv.addr)) self.loop.run_until_complete( _gather(*tasks)) self.assertEqual(CNT, TOTAL_CNT) with self._silence_eof_received_warning(): run(client) def test_flush_before_shutdown(self): CHUNK = 1024 * 128 SIZE = 32 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx = self._create_client_ssl_context() if hasattr(ssl, 'OP_NO_TLSv1_3'): client_sslctx.options |= ssl.OP_NO_TLSv1_3 future = None def server(sock): sock.starttls(sslctx, server_side=True) self.assertEqual(sock.recv_all(4), b'ping') sock.send(b'pong') time.sleep(0.5) # hopefully stuck the TCP buffer data = sock.recv_all(CHUNK * SIZE) self.assertEqual(len(data), CHUNK * SIZE) sock.close() def run(meth): def wrapper(sock): try: meth(sock) except Exception as ex: self.loop.call_soon_threadsafe(future.set_exception, ex) else: self.loop.call_soon_threadsafe(future.set_result, None) return wrapper async def client(addr): nonlocal future future = self.loop.create_future() reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='') sslprotocol = writer.transport._ssl_protocol writer.write(b'ping') data = await reader.readexactly(4) self.assertEqual(data, b'pong') sslprotocol.pause_writing() for _ in range(SIZE): writer.write(b'x' * CHUNK) writer.close() sslprotocol.resume_writing() await self.wait_closed(writer) try: data = await reader.read() self.assertEqual(data, b'') except ConnectionResetError: pass await future with self.tcp_server(run(server)) as srv: self.loop.run_until_complete(client(srv.addr)) def test_remote_shutdown_receives_trailing_data(self): CHUNK = 1024 * 128 SIZE = 32 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context() future = None def server(sock): incoming = ssl.MemoryBIO() outgoing = ssl.MemoryBIO() sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) while True: try: sslobj.do_handshake() except ssl.SSLWantReadError: if outgoing.pending: sock.send(outgoing.read()) incoming.write(sock.recv(16384)) else: if outgoing.pending: sock.send(outgoing.read()) break while True: try: data = sslobj.read(4) except ssl.SSLWantReadError: incoming.write(sock.recv(16384)) else: break self.assertEqual(data, b'ping') sslobj.write(b'pong') sock.send(outgoing.read()) time.sleep(0.2) # wait for the peer to fill its backlog # send close_notify but don't wait for response with self.assertRaises(ssl.SSLWantReadError): sslobj.unwrap() sock.send(outgoing.read()) # should receive all data data_len = 0 while True: try: chunk = len(sslobj.read(16384)) data_len += chunk except ssl.SSLWantReadError: incoming.write(sock.recv(16384)) except ssl.SSLZeroReturnError: break self.assertEqual(data_len, CHUNK * SIZE) # verify that close_notify is received sslobj.unwrap() sock.close() def eof_server(sock): sock.starttls(sslctx, server_side=True) self.assertEqual(sock.recv_all(4), b'ping') sock.send(b'pong') time.sleep(0.2) # wait for the peer to fill its backlog # send EOF sock.shutdown(socket.SHUT_WR) # should receive all data data = sock.recv_all(CHUNK * SIZE) self.assertEqual(len(data), CHUNK * SIZE) sock.close() async def client(addr): nonlocal future future = self.loop.create_future() reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='') writer.write(b'ping') data = await reader.readexactly(4) self.assertEqual(data, b'pong') # fill write backlog in a hacky way - renegotiation won't help for _ in range(SIZE): writer.transport._test__append_write_backlog(b'x' * CHUNK) try: data = await reader.read() self.assertEqual(data, b'') except (BrokenPipeError, ConnectionResetError): pass await future writer.close() await self.wait_closed(writer) def run(meth): def wrapper(sock): try: meth(sock) except Exception as ex: self.loop.call_soon_threadsafe(future.set_exception, ex) else: self.loop.call_soon_threadsafe(future.set_result, None) return wrapper with self.tcp_server(run(server)) as srv: self.loop.run_until_complete(client(srv.addr)) with self.tcp_server(run(eof_server)) as srv: self.loop.run_until_complete(client(srv.addr)) def test_connect_timeout_warning(self): s = socket.socket(socket.AF_INET) s.bind(('127.0.0.1', 0)) addr = s.getsockname() async def test(): try: await asyncio.wait_for( self.loop.create_connection(asyncio.Protocol, *addr, ssl=True), 0.1) except (ConnectionRefusedError, asyncio.TimeoutError): pass else: self.fail('TimeoutError is not raised') with s: try: with self.assertWarns(ResourceWarning) as cm: self.loop.run_until_complete(test()) gc.collect() gc.collect() gc.collect() except AssertionError as e: self.assertEqual(str(e), 'ResourceWarning not triggered') else: self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) def test_handshake_timeout_handler_leak(self): s = socket.socket(socket.AF_INET) s.bind(('127.0.0.1', 0)) s.listen(1) addr = s.getsockname() async def test(ctx): try: await asyncio.wait_for( self.loop.create_connection(asyncio.Protocol, *addr, ssl=ctx), 0.1) except (ConnectionRefusedError, asyncio.TimeoutError): pass else: self.fail('TimeoutError is not raised') with s: ctx = ssl.create_default_context() self.loop.run_until_complete(test(ctx)) ctx = weakref.ref(ctx) # SSLProtocol should be DECREF to 0 self.assertIsNone(ctx()) def test_shutdown_timeout_handler_leak(self): loop = self.loop def server(sock): sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) sock = sslctx.wrap_socket(sock, server_side=True) sock.recv(32) sock.close() class Protocol(asyncio.Protocol): def __init__(self): self.fut = asyncio.Future(loop=loop) def connection_lost(self, exc): self.fut.set_result(None) async def client(addr, ctx): tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) tr.close() await pr.fut with self.tcp_server(server) as srv: ctx = self._create_client_ssl_context() loop.run_until_complete(client(srv.addr, ctx)) ctx = weakref.ref(ctx) # asyncio has no shutdown timeout, but it ends up with a circular # reference loop - not ideal (introduces gc glitches), but at least # not leaking gc.collect() gc.collect() gc.collect() # SSLProtocol should be DECREF to 0 self.assertIsNone(ctx()) def test_shutdown_timeout_handler_not_set(self): loop = self.loop eof = asyncio.Event() extra = None def server(sock): sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) sock = sslctx.wrap_socket(sock, server_side=True) sock.send(b'hello') assert sock.recv(1024) == b'world' sock.send(b'extra bytes') # sending EOF here sock.shutdown(socket.SHUT_WR) loop.call_soon_threadsafe(eof.set) # make sure we have enough time to reproduce the issue assert sock.recv(1024) == b'' sock.close() class Protocol(asyncio.Protocol): def __init__(self): self.fut = asyncio.Future(loop=loop) self.transport = None def connection_made(self, transport): self.transport = transport def data_received(self, data): if data == b'hello': self.transport.write(b'world') # pause reading would make incoming data stay in the sslobj self.transport.pause_reading() else: nonlocal extra extra = data def connection_lost(self, exc): if exc is None: self.fut.set_result(None) else: self.fut.set_exception(exc) async def client(addr): ctx = self._create_client_ssl_context() tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) await eof.wait() tr.resume_reading() await pr.fut tr.close() assert extra == b'extra bytes' with self.tcp_server(server) as srv: loop.run_until_complete(client(srv.addr)) ############################################################################### # Socket Testing Utilities ############################################################################### class TestSocketWrapper: def __init__(self, sock): self.__sock = sock def recv_all(self, n): buf = b'' while len(buf) < n: data = self.recv(n - len(buf)) if data == b'': raise ConnectionAbortedError buf += data return buf def starttls(self, ssl_context, *, server_side=False, server_hostname=None, do_handshake_on_connect=True): assert isinstance(ssl_context, ssl.SSLContext) ssl_sock = ssl_context.wrap_socket( self.__sock, server_side=server_side, server_hostname=server_hostname, do_handshake_on_connect=do_handshake_on_connect) if server_side: ssl_sock.do_handshake() self.__sock.close() self.__sock = ssl_sock def __getattr__(self, name): return getattr(self.__sock, name) def __repr__(self): return '<{} {!r}>'.format(type(self).__name__, self.__sock) class SocketThread(threading.Thread): def stop(self): self._active = False self.join() def __enter__(self): self.start() return self def __exit__(self, *exc): self.stop() class TestThreadedClient(SocketThread): def __init__(self, test, sock, prog, timeout): threading.Thread.__init__(self, None, None, 'test-client') self.daemon = True self._timeout = timeout self._sock = sock self._active = True self._prog = prog self._test = test def run(self): try: self._prog(TestSocketWrapper(self._sock)) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: self._test._abort_socket_test(ex) class TestThreadedServer(SocketThread): def __init__(self, test, sock, prog, timeout, max_clients): threading.Thread.__init__(self, None, None, 'test-server') self.daemon = True self._clients = 0 self._finished_clients = 0 self._max_clients = max_clients self._timeout = timeout self._sock = sock self._active = True self._prog = prog self._s1, self._s2 = socket.socketpair() self._s1.setblocking(False) self._test = test def stop(self): try: if self._s2 and self._s2.fileno() != -1: try: self._s2.send(b'stop') except OSError: pass finally: super().stop() def run(self): try: with self._sock: self._sock.setblocking(0) self._run() finally: self._s1.close() self._s2.close() def _run(self): while self._active: if self._clients >= self._max_clients: return r, w, x = select.select( [self._sock, self._s1], [], [], self._timeout) if self._s1 in r: return if self._sock in r: try: conn, addr = self._sock.accept() except BlockingIOError: continue except socket.timeout: if not self._active: return else: raise else: self._clients += 1 conn.settimeout(self._timeout) try: with conn: self._handle_client(conn) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: self._active = False try: raise finally: self._test._abort_socket_test(ex) def _handle_client(self, sock): self._prog(TestSocketWrapper(sock)) @property def addr(self): return self._sock.getsockname()