summaryrefslogtreecommitdiff
path: root/Lib/test/test_asyncio/test_ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio/test_ssl.py')
-rw-r--r--Lib/test/test_asyncio/test_ssl.py1718
1 files changed, 1718 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py
new file mode 100644
index 0000000000..38235c63e0
--- /dev/null
+++ b/Lib/test/test_asyncio/test_ssl.py
@@ -0,0 +1,1718 @@
+import asyncio
+import asyncio.sslproto
+import contextlib
+import gc
+import logging
+import os
+import select
+import socket
+import ssl
+import tempfile
+import threading
+import time
+import weakref
+
+from test import support
+from test.test_asyncio import utils as test_utils
+
+
+def tearDownModule():
+ asyncio.set_event_loop_policy(None)
+
+
+class MyBaseProto(asyncio.Protocol):
+ connected = None
+ done = None
+
+ def __init__(self, loop=None):
+ self.transport = None
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.connected = asyncio.Future(loop=loop)
+ self.done = asyncio.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ if self.connected:
+ self.connected.set_result(None)
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class TestSSL(test_utils.TestCase):
+
+ PAYLOAD_SIZE = 1024 * 100
+ TIMEOUT = 60
+
+ def setUp(self):
+ super().setUp()
+ self.loop = asyncio.new_event_loop()
+ self.set_event_loop(self.loop)
+ self.addCleanup(self.loop.close)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ if not self.loop.is_closed():
+ test_utils.run_briefly(self.loop)
+
+ self.doCleanups()
+ support.gc_collect()
+ super().tearDown()
+
+ def tcp_server(self, server_prog, *,
+ family=socket.AF_INET,
+ addr=None,
+ timeout=5,
+ backlog=1,
+ max_clients=10):
+
+ if addr is None:
+ if family == getattr(socket, "AF_UNIX", None):
+ with tempfile.NamedTemporaryFile() as tmp:
+ addr = tmp.name
+ else:
+ addr = ('127.0.0.1', 0)
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ try:
+ sock.bind(addr)
+ sock.listen(backlog)
+ except OSError as ex:
+ sock.close()
+ raise ex
+
+ return TestThreadedServer(
+ self, sock, server_prog, timeout, max_clients)
+
+ def tcp_client(self, client_prog,
+ family=socket.AF_INET,
+ timeout=10):
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ return TestThreadedClient(
+ self, sock, client_prog, timeout)
+
+ def unix_server(self, *args, **kwargs):
+ return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
+
+ def unix_client(self, *args, **kwargs):
+ return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
+
+ def _create_server_ssl_context(self, certfile, keyfile=None):
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.load_cert_chain(certfile, keyfile)
+ return sslcontext
+
+ def _create_client_ssl_context(self, *, disable_verify=True):
+ sslcontext = ssl.create_default_context()
+ sslcontext.check_hostname = False
+ if disable_verify:
+ sslcontext.verify_mode = ssl.CERT_NONE
+ return sslcontext
+
+ @contextlib.contextmanager
+ def _silence_eof_received_warning(self):
+ # TODO This warning has to be fixed in asyncio.
+ logger = logging.getLogger('asyncio')
+ filter = logging.Filter('has no effect when using ssl')
+ logger.addFilter(filter)
+ try:
+ yield
+ finally:
+ logger.removeFilter(filter)
+
+ def _abort_socket_test(self, ex):
+ try:
+ self.loop.stop()
+ finally:
+ self.fail(ex)
+
+ def new_loop(self):
+ return asyncio.new_event_loop()
+
+ def new_policy(self):
+ return asyncio.DefaultEventLoopPolicy()
+
+ async def wait_closed(self, obj):
+ if not isinstance(obj, asyncio.StreamWriter):
+ return
+ try:
+ await obj.wait_closed()
+ except (BrokenPipeError, ConnectionError):
+ pass
+
+ def test_create_server_ssl_1(self):
+ CNT = 0 # number of clients that were successful
+ TOTAL_CNT = 25 # total number of clients that test will create
+ TIMEOUT = 10.0 # timeout for this test
+
+ A_DATA = b'A' * 1024 * 1024
+ B_DATA = b'B' * 1024 * 1024
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context()
+
+ clients = []
+
+ async def handle_client(reader, writer):
+ nonlocal CNT
+
+ data = await reader.readexactly(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ writer.write(b'OK')
+
+ data = await reader.readexactly(len(B_DATA))
+ self.assertEqual(data, B_DATA)
+ writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
+
+ await writer.drain()
+ writer.close()
+
+ CNT += 1
+
+ async def test_client(addr):
+ fut = asyncio.Future()
+
+ def prog(sock):
+ try:
+ sock.starttls(client_sslctx)
+ sock.connect(addr)
+ sock.send(A_DATA)
+
+ data = sock.recv_all(2)
+ self.assertEqual(data, b'OK')
+
+ sock.send(B_DATA)
+ data = sock.recv_all(4)
+ self.assertEqual(data, b'SPAM')
+
+ sock.close()
+
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(fut.set_exception, ex)
+ else:
+ self.loop.call_soon_threadsafe(fut.set_result, None)
+
+ client = self.tcp_client(prog)
+ client.start()
+ clients.append(client)
+
+ await fut
+
+ async def start_server():
+ extras = {}
+ extras = dict(ssl_handshake_timeout=10.0)
+
+ srv = await asyncio.start_server(
+ handle_client,
+ '127.0.0.1', 0,
+ family=socket.AF_INET,
+ ssl=sslctx,
+ **extras)
+
+ try:
+ srv_socks = srv.sockets
+ self.assertTrue(srv_socks)
+
+ addr = srv_socks[0].getsockname()
+
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(test_client(addr))
+
+ await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
+
+ finally:
+ self.loop.call_soon(srv.close)
+ await srv.wait_closed()
+
+ with self._silence_eof_received_warning():
+ self.loop.run_until_complete(start_server())
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ for client in clients:
+ client.stop()
+
+ def test_create_connection_ssl_1(self):
+ self.loop.set_exception_handler(None)
+
+ CNT = 0
+ TOTAL_CNT = 25
+
+ A_DATA = b'A' * 1024 * 1024
+ B_DATA = b'B' * 1024 * 1024
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context()
+
+ def server(sock):
+ sock.starttls(
+ sslctx,
+ server_side=True)
+
+ data = sock.recv_all(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ sock.send(b'OK')
+
+ data = sock.recv_all(len(B_DATA))
+ self.assertEqual(data, B_DATA)
+ sock.send(b'SPAM')
+
+ sock.close()
+
+ async def client(addr):
+ extras = {}
+ extras = dict(ssl_handshake_timeout=10.0)
+
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ **extras)
+
+ writer.write(A_DATA)
+ self.assertEqual(await reader.readexactly(2), b'OK')
+
+ writer.write(B_DATA)
+ self.assertEqual(await reader.readexactly(4), b'SPAM')
+
+ nonlocal CNT
+ CNT += 1
+
+ writer.close()
+ await self.wait_closed(writer)
+
+ async def client_sock(addr):
+ sock = socket.socket()
+ sock.connect(addr)
+ reader, writer = await asyncio.open_connection(
+ sock=sock,
+ ssl=client_sslctx,
+ server_hostname='')
+
+ writer.write(A_DATA)
+ self.assertEqual(await reader.readexactly(2), b'OK')
+
+ writer.write(B_DATA)
+ self.assertEqual(await reader.readexactly(4), b'SPAM')
+
+ nonlocal CNT
+ CNT += 1
+
+ writer.close()
+ await self.wait_closed(writer)
+ sock.close()
+
+ def run(coro):
+ nonlocal CNT
+ CNT = 0
+
+ async def _gather(*tasks):
+ # trampoline
+ return await asyncio.gather(*tasks)
+
+ with self.tcp_server(server,
+ max_clients=TOTAL_CNT,
+ backlog=TOTAL_CNT) as srv:
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(coro(srv.addr))
+
+ self.loop.run_until_complete(_gather(*tasks))
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ with self._silence_eof_received_warning():
+ run(client)
+
+ with self._silence_eof_received_warning():
+ run(client_sock)
+
+ def test_create_connection_ssl_slow_handshake(self):
+ client_sslctx = self._create_client_ssl_context()
+
+ # silence error logger
+ self.loop.set_exception_handler(lambda *args: None)
+
+ def server(sock):
+ try:
+ sock.recv_all(1024 * 1024)
+ except ConnectionAbortedError:
+ pass
+ finally:
+ sock.close()
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ ssl_handshake_timeout=1.0)
+ writer.close()
+ await self.wait_closed(writer)
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaisesRegex(
+ ConnectionAbortedError,
+ r'SSL handshake.*is taking longer'):
+
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_create_connection_ssl_failed_certificate(self):
+ # silence error logger
+ self.loop.set_exception_handler(lambda *args: None)
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context(disable_verify=False)
+
+ def server(sock):
+ try:
+ sock.starttls(
+ sslctx,
+ server_side=True)
+ sock.connect()
+ except (ssl.SSLError, OSError):
+ pass
+ finally:
+ sock.close()
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ ssl_handshake_timeout=1.0)
+ writer.close()
+ await self.wait_closed(writer)
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaises(ssl.SSLCertVerificationError):
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_ssl_handshake_timeout(self):
+ # bpo-29970: Check that a connection is aborted if handshake is not
+ # completed in timeout period, instead of remaining open indefinitely
+ client_sslctx = test_utils.simple_client_sslcontext()
+
+ # silence error logger
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ server_side_aborted = False
+
+ def server(sock):
+ nonlocal server_side_aborted
+ try:
+ sock.recv_all(1024 * 1024)
+ except ConnectionAbortedError:
+ server_side_aborted = True
+ finally:
+ sock.close()
+
+ async def client(addr):
+ await asyncio.wait_for(
+ self.loop.create_connection(
+ asyncio.Protocol,
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ ssl_handshake_timeout=10.0),
+ 0.5)
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaises(asyncio.TimeoutError):
+ self.loop.run_until_complete(client(srv.addr))
+
+ self.assertTrue(server_side_aborted)
+
+ # Python issue #23197: cancelling a handshake must not raise an
+ # exception or log an error, even if the handshake failed
+ self.assertEqual(messages, [])
+
+ def test_ssl_handshake_connection_lost(self):
+ # #246: make sure that no connection_lost() is called before
+ # connection_made() is called first
+
+ client_sslctx = test_utils.simple_client_sslcontext()
+
+ # silence error logger
+ self.loop.set_exception_handler(lambda loop, ctx: None)
+
+ connection_made_called = False
+ connection_lost_called = False
+
+ def server(sock):
+ sock.recv(1024)
+ # break the connection during handshake
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def connection_made(self, transport):
+ nonlocal connection_made_called
+ connection_made_called = True
+
+ def connection_lost(self, exc):
+ nonlocal connection_lost_called
+ connection_lost_called = True
+
+ async def client(addr):
+ await self.loop.create_connection(
+ ClientProto,
+ *addr,
+ ssl=client_sslctx,
+ server_hostname=''),
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ with self.assertRaises(ConnectionResetError):
+ self.loop.run_until_complete(client(srv.addr))
+
+ if connection_lost_called:
+ if connection_made_called:
+ self.fail("unexpected call to connection_lost()")
+ else:
+ self.fail("unexpected call to connection_lost() without"
+ "calling connection_made()")
+ elif connection_made_called:
+ self.fail("unexpected call to connection_made()")
+
+ def test_ssl_connect_accepted_socket(self):
+ proto = ssl.PROTOCOL_TLS_SERVER
+ server_context = ssl.SSLContext(proto)
+ server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
+ if hasattr(server_context, 'check_hostname'):
+ server_context.check_hostname = False
+ server_context.verify_mode = ssl.CERT_NONE
+
+ client_context = ssl.SSLContext(proto)
+ if hasattr(server_context, 'check_hostname'):
+ client_context.check_hostname = False
+ client_context.verify_mode = ssl.CERT_NONE
+
+ def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
+ loop = self.loop
+
+ class MyProto(MyBaseProto):
+
+ def connection_lost(self, exc):
+ super().connection_lost(exc)
+ loop.call_soon(loop.stop)
+
+ def data_received(self, data):
+ super().data_received(data)
+ self.transport.write(expected_response)
+
+ lsock = socket.socket(socket.AF_INET)
+ lsock.bind(('127.0.0.1', 0))
+ lsock.listen(1)
+ addr = lsock.getsockname()
+
+ message = b'test data'
+ response = None
+ expected_response = b'roger'
+
+ def client():
+ nonlocal response
+ try:
+ csock = socket.socket(socket.AF_INET)
+ if client_ssl is not None:
+ csock = client_ssl.wrap_socket(csock)
+ csock.connect(addr)
+ csock.sendall(message)
+ response = csock.recv(99)
+ csock.close()
+ except Exception as exc:
+ print(
+ "Failure in client thread in test_connect_accepted_socket",
+ exc)
+
+ thread = threading.Thread(target=client, daemon=True)
+ thread.start()
+
+ conn, _ = lsock.accept()
+ proto = MyProto(loop=loop)
+ proto.loop = loop
+
+ extras = {}
+ if server_ssl:
+ extras = dict(ssl_handshake_timeout=10.0)
+
+ f = loop.create_task(
+ loop.connect_accepted_socket(
+ (lambda: proto), conn, ssl=server_ssl,
+ **extras))
+ loop.run_forever()
+ conn.close()
+ lsock.close()
+
+ thread.join(1)
+ self.assertFalse(thread.is_alive())
+ self.assertEqual(proto.state, 'CLOSED')
+ self.assertEqual(proto.nbytes, len(message))
+ self.assertEqual(response, expected_response)
+ tr, _ = f.result()
+
+ if server_ssl:
+ self.assertIn('SSL', tr.__class__.__name__)
+
+ tr.close()
+ # let it close
+ self.loop.run_until_complete(asyncio.sleep(0.1))
+
+ def test_start_tls_client_corrupted_ssl(self):
+ self.loop.set_exception_handler(lambda loop, ctx: None)
+
+ sslctx = test_utils.simple_server_sslcontext()
+ client_sslctx = test_utils.simple_client_sslcontext()
+
+ def server(sock):
+ orig_sock = sock.dup()
+ try:
+ sock.starttls(
+ sslctx,
+ server_side=True)
+ sock.sendall(b'A\n')
+ sock.recv_all(1)
+ orig_sock.send(b'please corrupt the SSL connection')
+ except ssl.SSLError:
+ pass
+ finally:
+ sock.close()
+ orig_sock.close()
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='')
+
+ self.assertEqual(await reader.readline(), b'A\n')
+ writer.write(b'B')
+ with self.assertRaises(ssl.SSLError):
+ await reader.readline()
+ writer.close()
+ try:
+ await self.wait_closed(writer)
+ except ssl.SSLError:
+ pass
+ return 'OK'
+
+ with self.tcp_server(server,
+ max_clients=1,
+ backlog=1) as srv:
+
+ res = self.loop.run_until_complete(client(srv.addr))
+
+ self.assertEqual(res, 'OK')
+
+ def test_start_tls_client_reg_proto_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.starttls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.unwrap()
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(proto, tr):
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr)
+
+ tr.write(HELLO_MSG)
+ new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+ self.assertEqual(await on_data, b'O')
+ new_tr.write(HELLO_MSG)
+ await on_eof
+
+ new_tr.close()
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr), timeout=10))
+
+ def test_create_connection_memory_leak(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_context = self._create_client_ssl_context()
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ sock.starttls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.unwrap()
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(proto, tr):
+ # XXX: We assume user stores the transport in protocol
+ proto.tr = tr
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr,
+ ssl=client_context)
+
+ self.assertEqual(await on_data, b'O')
+ tr.write(HELLO_MSG)
+ await on_eof
+
+ tr.close()
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr), timeout=10))
+
+ # No garbage is left for SSL client from loop.create_connection, even
+ # if user stores the SSLTransport in corresponding protocol instance
+ client_context = weakref.ref(client_context)
+ self.assertIsNone(client_context())
+
+ def test_start_tls_client_buf_proto_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ client_con_made_calls = 0
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.starttls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.sendall(b'2')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.unwrap()
+ sock.close()
+
+ class ClientProtoFirst(asyncio.BufferedProtocol):
+ def __init__(self, on_data):
+ self.on_data = on_data
+ self.buf = bytearray(1)
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def get_buffer(self, sizehint):
+ return self.buf
+
+ def buffer_updated(self, nsize):
+ assert nsize == 1
+ self.on_data.set_result(bytes(self.buf[:nsize]))
+
+ def eof_received(self):
+ pass
+
+ class ClientProtoSecond(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data1 = self.loop.create_future()
+ on_data2 = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProtoFirst(on_data1), *addr)
+
+ tr.write(HELLO_MSG)
+ new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+ self.assertEqual(await on_data1, b'O')
+ new_tr.write(HELLO_MSG)
+
+ new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
+ self.assertEqual(await on_data2, b'2')
+ new_tr.write(HELLO_MSG)
+ await on_eof
+
+ new_tr.close()
+
+ # connection_made() should be called only once -- when
+ # we establish connection for the first time. Start TLS
+ # doesn't call connection_made() on application protocols.
+ self.assertEqual(client_con_made_calls, 1)
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr),
+ timeout=self.TIMEOUT))
+
+ def test_start_tls_slow_client_cancel(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ client_context = test_utils.simple_client_sslcontext()
+ server_waits_on_handshake = self.loop.create_future()
+
+ def serve(sock):
+ sock.settimeout(self.TIMEOUT)
+
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ try:
+ self.loop.call_soon_threadsafe(
+ server_waits_on_handshake.set_result, None)
+ data = sock.recv_all(1024 * 1024)
+ except ConnectionAbortedError:
+ pass
+ finally:
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(proto, tr):
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ await asyncio.sleep(0.5)
+
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr)
+
+ tr.write(HELLO_MSG)
+
+ await server_waits_on_handshake
+
+ with self.assertRaises(asyncio.TimeoutError):
+ await asyncio.wait_for(
+ self.loop.start_tls(tr, proto, client_context),
+ 0.5)
+
+ with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr), timeout=10))
+
+ def test_start_tls_server_1(self):
+ HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ def client(sock, addr):
+ sock.settimeout(self.TIMEOUT)
+
+ sock.connect(addr)
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.starttls(client_context)
+ sock.sendall(HELLO_MSG)
+
+ sock.unwrap()
+ sock.close()
+
+ class ServerProto(asyncio.Protocol):
+ def __init__(self, on_con, on_eof, on_con_lost):
+ self.on_con = on_con
+ self.on_eof = on_eof
+ self.on_con_lost = on_con_lost
+ self.data = b''
+
+ def connection_made(self, tr):
+ self.on_con.set_result(tr)
+
+ def data_received(self, data):
+ self.data += data
+
+ def eof_received(self):
+ self.on_eof.set_result(1)
+
+ def connection_lost(self, exc):
+ if exc is None:
+ self.on_con_lost.set_result(None)
+ else:
+ self.on_con_lost.set_exception(exc)
+
+ async def main(proto, on_con, on_eof, on_con_lost):
+ tr = await on_con
+ tr.write(HELLO_MSG)
+
+ self.assertEqual(proto.data, b'')
+
+ new_tr = await self.loop.start_tls(
+ tr, proto, server_context,
+ server_side=True,
+ ssl_handshake_timeout=self.TIMEOUT)
+
+ await on_eof
+ await on_con_lost
+ self.assertEqual(proto.data, HELLO_MSG)
+ new_tr.close()
+
+ async def run_main():
+ on_con = self.loop.create_future()
+ on_eof = self.loop.create_future()
+ on_con_lost = self.loop.create_future()
+ proto = ServerProto(on_con, on_eof, on_con_lost)
+
+ server = await self.loop.create_server(
+ lambda: proto, '127.0.0.1', 0)
+ addr = server.sockets[0].getsockname()
+
+ with self.tcp_client(lambda sock: client(sock, addr),
+ timeout=self.TIMEOUT):
+ await asyncio.wait_for(
+ main(proto, on_con, on_eof, on_con_lost),
+ timeout=self.TIMEOUT)
+
+ server.close()
+ await server.wait_closed()
+
+ self.loop.run_until_complete(run_main())
+
+ def test_create_server_ssl_over_ssl(self):
+ CNT = 0 # number of clients that were successful
+ TOTAL_CNT = 25 # total number of clients that test will create
+ TIMEOUT = 10.0 # timeout for this test
+
+ A_DATA = b'A' * 1024 * 1024
+ B_DATA = b'B' * 1024 * 1024
+
+ sslctx_1 = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx_1 = self._create_client_ssl_context()
+ sslctx_2 = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx_2 = self._create_client_ssl_context()
+
+ clients = []
+
+ async def handle_client(reader, writer):
+ nonlocal CNT
+
+ data = await reader.readexactly(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ writer.write(b'OK')
+
+ data = await reader.readexactly(len(B_DATA))
+ self.assertEqual(data, B_DATA)
+ writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
+
+ await writer.drain()
+ writer.close()
+
+ CNT += 1
+
+ class ServerProtocol(asyncio.StreamReaderProtocol):
+ def connection_made(self, transport):
+ super_ = super()
+ transport.pause_reading()
+ fut = self._loop.create_task(self._loop.start_tls(
+ transport, self, sslctx_2, server_side=True))
+
+ def cb(_):
+ try:
+ tr = fut.result()
+ except Exception as ex:
+ super_.connection_lost(ex)
+ else:
+ super_.connection_made(tr)
+ fut.add_done_callback(cb)
+
+ def server_protocol_factory():
+ reader = asyncio.StreamReader()
+ protocol = ServerProtocol(reader, handle_client)
+ return protocol
+
+ async def test_client(addr):
+ fut = asyncio.Future()
+
+ def prog(sock):
+ try:
+ sock.connect(addr)
+ sock.starttls(client_sslctx_1)
+
+ # because wrap_socket() doesn't work correctly on
+ # SSLSocket, we have to do the 2nd level SSL manually
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)
+
+ def do(func, *args):
+ while True:
+ try:
+ rv = func(*args)
+ break
+ except ssl.SSLWantReadError:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ incoming.write(sock.recv(65536))
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ return rv
+
+ do(sslobj.do_handshake)
+
+ do(sslobj.write, A_DATA)
+ data = do(sslobj.read, 2)
+ self.assertEqual(data, b'OK')
+
+ do(sslobj.write, B_DATA)
+ data = b''
+ while True:
+ chunk = do(sslobj.read, 4)
+ if not chunk:
+ break
+ data += chunk
+ self.assertEqual(data, b'SPAM')
+
+ do(sslobj.unwrap)
+ sock.close()
+
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(fut.set_exception, ex)
+ sock.close()
+ else:
+ self.loop.call_soon_threadsafe(fut.set_result, None)
+
+ client = self.tcp_client(prog)
+ client.start()
+ clients.append(client)
+
+ await fut
+
+ async def start_server():
+ extras = {}
+
+ srv = await self.loop.create_server(
+ server_protocol_factory,
+ '127.0.0.1', 0,
+ family=socket.AF_INET,
+ ssl=sslctx_1,
+ **extras)
+
+ try:
+ srv_socks = srv.sockets
+ self.assertTrue(srv_socks)
+
+ addr = srv_socks[0].getsockname()
+
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(test_client(addr))
+
+ await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
+
+ finally:
+ self.loop.call_soon(srv.close)
+ await srv.wait_closed()
+
+ with self._silence_eof_received_warning():
+ self.loop.run_until_complete(start_server())
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ for client in clients:
+ client.stop()
+
+ def test_shutdown_cleanly(self):
+ CNT = 0
+ TOTAL_CNT = 25
+
+ A_DATA = b'A' * 1024 * 1024
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx = self._create_client_ssl_context()
+
+ def server(sock):
+ sock.starttls(
+ sslctx,
+ server_side=True)
+
+ data = sock.recv_all(len(A_DATA))
+ self.assertEqual(data, A_DATA)
+ sock.send(b'OK')
+
+ sock.unwrap()
+
+ sock.close()
+
+ async def client(addr):
+ extras = {}
+ extras = dict(ssl_handshake_timeout=10.0)
+
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ **extras)
+
+ writer.write(A_DATA)
+ self.assertEqual(await reader.readexactly(2), b'OK')
+
+ self.assertEqual(await reader.read(), b'')
+
+ nonlocal CNT
+ CNT += 1
+
+ writer.close()
+ await self.wait_closed(writer)
+
+ def run(coro):
+ nonlocal CNT
+ CNT = 0
+
+ async def _gather(*tasks):
+ return await asyncio.gather(*tasks)
+
+ with self.tcp_server(server,
+ max_clients=TOTAL_CNT,
+ backlog=TOTAL_CNT) as srv:
+ tasks = []
+ for _ in range(TOTAL_CNT):
+ tasks.append(coro(srv.addr))
+
+ self.loop.run_until_complete(
+ _gather(*tasks))
+
+ self.assertEqual(CNT, TOTAL_CNT)
+
+ with self._silence_eof_received_warning():
+ run(client)
+
+ def test_flush_before_shutdown(self):
+ CHUNK = 1024 * 128
+ SIZE = 32
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT, test_utils.ONLYKEY)
+ client_sslctx = self._create_client_ssl_context()
+ if hasattr(ssl, 'OP_NO_TLSv1_3'):
+ client_sslctx.options |= ssl.OP_NO_TLSv1_3
+
+ future = None
+
+ def server(sock):
+ sock.starttls(sslctx, server_side=True)
+ self.assertEqual(sock.recv_all(4), b'ping')
+ sock.send(b'pong')
+ time.sleep(0.5) # hopefully stuck the TCP buffer
+ data = sock.recv_all(CHUNK * SIZE)
+ self.assertEqual(len(data), CHUNK * SIZE)
+ sock.close()
+
+ def run(meth):
+ def wrapper(sock):
+ try:
+ meth(sock)
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(future.set_exception, ex)
+ else:
+ self.loop.call_soon_threadsafe(future.set_result, None)
+ return wrapper
+
+ async def client(addr):
+ nonlocal future
+ future = self.loop.create_future()
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='')
+ sslprotocol = writer.transport._ssl_protocol
+ writer.write(b'ping')
+ data = await reader.readexactly(4)
+ self.assertEqual(data, b'pong')
+
+ sslprotocol.pause_writing()
+ for _ in range(SIZE):
+ writer.write(b'x' * CHUNK)
+
+ writer.close()
+ sslprotocol.resume_writing()
+
+ await self.wait_closed(writer)
+ try:
+ data = await reader.read()
+ self.assertEqual(data, b'')
+ except ConnectionResetError:
+ pass
+ await future
+
+ with self.tcp_server(run(server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_remote_shutdown_receives_trailing_data(self):
+ CHUNK = 1024 * 128
+ SIZE = 32
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context()
+ future = None
+
+ def server(sock):
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
+
+ while True:
+ try:
+ sslobj.do_handshake()
+ except ssl.SSLWantReadError:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ incoming.write(sock.recv(16384))
+ else:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ break
+
+ while True:
+ try:
+ data = sslobj.read(4)
+ except ssl.SSLWantReadError:
+ incoming.write(sock.recv(16384))
+ else:
+ break
+
+ self.assertEqual(data, b'ping')
+ sslobj.write(b'pong')
+ sock.send(outgoing.read())
+
+ time.sleep(0.2) # wait for the peer to fill its backlog
+
+ # send close_notify but don't wait for response
+ with self.assertRaises(ssl.SSLWantReadError):
+ sslobj.unwrap()
+ sock.send(outgoing.read())
+
+ # should receive all data
+ data_len = 0
+ while True:
+ try:
+ chunk = len(sslobj.read(16384))
+ data_len += chunk
+ except ssl.SSLWantReadError:
+ incoming.write(sock.recv(16384))
+ except ssl.SSLZeroReturnError:
+ break
+
+ self.assertEqual(data_len, CHUNK * SIZE)
+
+ # verify that close_notify is received
+ sslobj.unwrap()
+
+ sock.close()
+
+ def eof_server(sock):
+ sock.starttls(sslctx, server_side=True)
+ self.assertEqual(sock.recv_all(4), b'ping')
+ sock.send(b'pong')
+
+ time.sleep(0.2) # wait for the peer to fill its backlog
+
+ # send EOF
+ sock.shutdown(socket.SHUT_WR)
+
+ # should receive all data
+ data = sock.recv_all(CHUNK * SIZE)
+ self.assertEqual(len(data), CHUNK * SIZE)
+
+ sock.close()
+
+ async def client(addr):
+ nonlocal future
+ future = self.loop.create_future()
+
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='')
+ writer.write(b'ping')
+ data = await reader.readexactly(4)
+ self.assertEqual(data, b'pong')
+
+ # fill write backlog in a hacky way - renegotiation won't help
+ for _ in range(SIZE):
+ writer.transport._test__append_write_backlog(b'x' * CHUNK)
+
+ try:
+ data = await reader.read()
+ self.assertEqual(data, b'')
+ except (BrokenPipeError, ConnectionResetError):
+ pass
+
+ await future
+
+ writer.close()
+ await self.wait_closed(writer)
+
+ def run(meth):
+ def wrapper(sock):
+ try:
+ meth(sock)
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(future.set_exception, ex)
+ else:
+ self.loop.call_soon_threadsafe(future.set_result, None)
+ return wrapper
+
+ with self.tcp_server(run(server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
+ with self.tcp_server(run(eof_server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
+ def test_connect_timeout_warning(self):
+ s = socket.socket(socket.AF_INET)
+ s.bind(('127.0.0.1', 0))
+ addr = s.getsockname()
+
+ async def test():
+ try:
+ await asyncio.wait_for(
+ self.loop.create_connection(asyncio.Protocol,
+ *addr, ssl=True),
+ 0.1)
+ except (ConnectionRefusedError, asyncio.TimeoutError):
+ pass
+ else:
+ self.fail('TimeoutError is not raised')
+
+ with s:
+ try:
+ with self.assertWarns(ResourceWarning) as cm:
+ self.loop.run_until_complete(test())
+ gc.collect()
+ gc.collect()
+ gc.collect()
+ except AssertionError as e:
+ self.assertEqual(str(e), 'ResourceWarning not triggered')
+ else:
+ self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
+
+ def test_handshake_timeout_handler_leak(self):
+ s = socket.socket(socket.AF_INET)
+ s.bind(('127.0.0.1', 0))
+ s.listen(1)
+ addr = s.getsockname()
+
+ async def test(ctx):
+ try:
+ await asyncio.wait_for(
+ self.loop.create_connection(asyncio.Protocol, *addr,
+ ssl=ctx),
+ 0.1)
+ except (ConnectionRefusedError, asyncio.TimeoutError):
+ pass
+ else:
+ self.fail('TimeoutError is not raised')
+
+ with s:
+ ctx = ssl.create_default_context()
+ self.loop.run_until_complete(test(ctx))
+ ctx = weakref.ref(ctx)
+
+ # SSLProtocol should be DECREF to 0
+ self.assertIsNone(ctx())
+
+ def test_shutdown_timeout_handler_leak(self):
+ loop = self.loop
+
+ def server(sock):
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ sock = sslctx.wrap_socket(sock, server_side=True)
+ sock.recv(32)
+ sock.close()
+
+ class Protocol(asyncio.Protocol):
+ def __init__(self):
+ self.fut = asyncio.Future(loop=loop)
+
+ def connection_lost(self, exc):
+ self.fut.set_result(None)
+
+ async def client(addr, ctx):
+ tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
+ tr.close()
+ await pr.fut
+
+ with self.tcp_server(server) as srv:
+ ctx = self._create_client_ssl_context()
+ loop.run_until_complete(client(srv.addr, ctx))
+ ctx = weakref.ref(ctx)
+
+ # asyncio has no shutdown timeout, but it ends up with a circular
+ # reference loop - not ideal (introduces gc glitches), but at least
+ # not leaking
+ gc.collect()
+ gc.collect()
+ gc.collect()
+
+ # SSLProtocol should be DECREF to 0
+ self.assertIsNone(ctx())
+
+ def test_shutdown_timeout_handler_not_set(self):
+ loop = self.loop
+ eof = asyncio.Event()
+ extra = None
+
+ def server(sock):
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ sock = sslctx.wrap_socket(sock, server_side=True)
+ sock.send(b'hello')
+ assert sock.recv(1024) == b'world'
+ sock.send(b'extra bytes')
+ # sending EOF here
+ sock.shutdown(socket.SHUT_WR)
+ loop.call_soon_threadsafe(eof.set)
+ # make sure we have enough time to reproduce the issue
+ assert sock.recv(1024) == b''
+ sock.close()
+
+ class Protocol(asyncio.Protocol):
+ def __init__(self):
+ self.fut = asyncio.Future(loop=loop)
+ self.transport = None
+
+ def connection_made(self, transport):
+ self.transport = transport
+
+ def data_received(self, data):
+ if data == b'hello':
+ self.transport.write(b'world')
+ # pause reading would make incoming data stay in the sslobj
+ self.transport.pause_reading()
+ else:
+ nonlocal extra
+ extra = data
+
+ def connection_lost(self, exc):
+ if exc is None:
+ self.fut.set_result(None)
+ else:
+ self.fut.set_exception(exc)
+
+ async def client(addr):
+ ctx = self._create_client_ssl_context()
+ tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
+ await eof.wait()
+ tr.resume_reading()
+ await pr.fut
+ tr.close()
+ assert extra == b'extra bytes'
+
+ with self.tcp_server(server) as srv:
+ loop.run_until_complete(client(srv.addr))
+
+
+###############################################################################
+# Socket Testing Utilities
+###############################################################################
+
+
+class TestSocketWrapper:
+
+ def __init__(self, sock):
+ self.__sock = sock
+
+ def recv_all(self, n):
+ buf = b''
+ while len(buf) < n:
+ data = self.recv(n - len(buf))
+ if data == b'':
+ raise ConnectionAbortedError
+ buf += data
+ return buf
+
+ def starttls(self, ssl_context, *,
+ server_side=False,
+ server_hostname=None,
+ do_handshake_on_connect=True):
+
+ assert isinstance(ssl_context, ssl.SSLContext)
+
+ ssl_sock = ssl_context.wrap_socket(
+ self.__sock, server_side=server_side,
+ server_hostname=server_hostname,
+ do_handshake_on_connect=do_handshake_on_connect)
+
+ if server_side:
+ ssl_sock.do_handshake()
+
+ self.__sock.close()
+ self.__sock = ssl_sock
+
+ def __getattr__(self, name):
+ return getattr(self.__sock, name)
+
+ def __repr__(self):
+ return '<{} {!r}>'.format(type(self).__name__, self.__sock)
+
+
+class SocketThread(threading.Thread):
+
+ def stop(self):
+ self._active = False
+ self.join()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, *exc):
+ self.stop()
+
+
+class TestThreadedClient(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout):
+ threading.Thread.__init__(self, None, None, 'test-client')
+ self.daemon = True
+
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+ self._prog = prog
+ self._test = test
+
+ def run(self):
+ try:
+ self._prog(TestSocketWrapper(self._sock))
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as ex:
+ self._test._abort_socket_test(ex)
+
+
+class TestThreadedServer(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout, max_clients):
+ threading.Thread.__init__(self, None, None, 'test-server')
+ self.daemon = True
+
+ self._clients = 0
+ self._finished_clients = 0
+ self._max_clients = max_clients
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+
+ self._prog = prog
+
+ self._s1, self._s2 = socket.socketpair()
+ self._s1.setblocking(False)
+
+ self._test = test
+
+ def stop(self):
+ try:
+ if self._s2 and self._s2.fileno() != -1:
+ try:
+ self._s2.send(b'stop')
+ except OSError:
+ pass
+ finally:
+ super().stop()
+
+ def run(self):
+ try:
+ with self._sock:
+ self._sock.setblocking(0)
+ self._run()
+ finally:
+ self._s1.close()
+ self._s2.close()
+
+ def _run(self):
+ while self._active:
+ if self._clients >= self._max_clients:
+ return
+
+ r, w, x = select.select(
+ [self._sock, self._s1], [], [], self._timeout)
+
+ if self._s1 in r:
+ return
+
+ if self._sock in r:
+ try:
+ conn, addr = self._sock.accept()
+ except BlockingIOError:
+ continue
+ except socket.timeout:
+ if not self._active:
+ return
+ else:
+ raise
+ else:
+ self._clients += 1
+ conn.settimeout(self._timeout)
+ try:
+ with conn:
+ self._handle_client(conn)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as ex:
+ self._active = False
+ try:
+ raise
+ finally:
+ self._test._abort_socket_test(ex)
+
+ def _handle_client(self, sock):
+ self._prog(TestSocketWrapper(sock))
+
+ @property
+ def addr(self):
+ return self._sock.getsockname()