diff options
| author | Pablo Galindo <Pablogsal@gmail.com> | 2021-05-03 16:21:59 +0100 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-05-03 16:21:59 +0100 | 
| commit | 7719953b30430b351ba0f153c2b51b16cc68ee36 (patch) | |
| tree | 8014086b85a13ed79d45e29ab74a9a9f5c9c68eb /Lib/asyncio | |
| parent | 39494285e15dc2d291ec13de5045b930eaf0a3db (diff) | |
| download | cpython-git-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.gz | |
bpo-44011: Revert "New asyncio ssl implementation (GH-17975)" (GH-25848)
This reverts commit 5fb06edbbb769561e245d0fe13002bab50e2ae60 and all
subsequent dependent commits.
Diffstat (limited to 'Lib/asyncio')
| -rw-r--r-- | Lib/asyncio/base_events.py | 43 | ||||
| -rw-r--r-- | Lib/asyncio/constants.py | 7 | ||||
| -rw-r--r-- | Lib/asyncio/events.py | 19 | ||||
| -rw-r--r-- | Lib/asyncio/proactor_events.py | 12 | ||||
| -rw-r--r-- | Lib/asyncio/selector_events.py | 31 | ||||
| -rw-r--r-- | Lib/asyncio/sslproto.py | 1059 | ||||
| -rw-r--r-- | Lib/asyncio/unix_events.py | 17 | 
7 files changed, 467 insertions, 721 deletions
| diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index e54ee309e4..f789635e0f 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -273,7 +273,7 @@ class _SendfileFallbackProtocol(protocols.Protocol):  class Server(events.AbstractServer):      def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, -                 ssl_handshake_timeout, ssl_shutdown_timeout=None): +                 ssl_handshake_timeout):          self._loop = loop          self._sockets = sockets          self._active_count = 0 @@ -282,7 +282,6 @@ class Server(events.AbstractServer):          self._backlog = backlog          self._ssl_context = ssl_context          self._ssl_handshake_timeout = ssl_handshake_timeout -        self._ssl_shutdown_timeout = ssl_shutdown_timeout          self._serving = False          self._serving_forever_fut = None @@ -314,8 +313,7 @@ class Server(events.AbstractServer):              sock.listen(self._backlog)              self._loop._start_serving(                  self._protocol_factory, sock, self._ssl_context, -                self, self._backlog, self._ssl_handshake_timeout, -                self._ssl_shutdown_timeout) +                self, self._backlog, self._ssl_handshake_timeout)      def get_loop(self):          return self._loop @@ -469,7 +467,6 @@ class BaseEventLoop(events.AbstractEventLoop):              *, server_side=False, server_hostname=None,              extra=None, server=None,              ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None,              call_connection_made=True):          """Create SSL transport."""          raise NotImplementedError @@ -972,7 +969,6 @@ class BaseEventLoop(events.AbstractEventLoop):              proto=0, flags=0, sock=None,              local_addr=None, server_hostname=None,              ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None,              happy_eyeballs_delay=None, interleave=None):          """Connect to a TCP server. @@ -1008,10 +1004,6 @@ class BaseEventLoop(events.AbstractEventLoop):              raise ValueError(                  'ssl_handshake_timeout is only meaningful with ssl') -        if ssl_shutdown_timeout is not None and not ssl: -            raise ValueError( -                'ssl_shutdown_timeout is only meaningful with ssl') -          if happy_eyeballs_delay is not None and interleave is None:              # If using happy eyeballs, default to interleave addresses by family              interleave = 1 @@ -1087,8 +1079,7 @@ class BaseEventLoop(events.AbstractEventLoop):          transport, protocol = await self._create_connection_transport(              sock, protocol_factory, ssl, server_hostname, -            ssl_handshake_timeout=ssl_handshake_timeout, -            ssl_shutdown_timeout=ssl_shutdown_timeout) +            ssl_handshake_timeout=ssl_handshake_timeout)          if self._debug:              # Get the socket from the transport because SSL transport closes              # the old socket and creates a new SSL socket @@ -1100,8 +1091,7 @@ class BaseEventLoop(events.AbstractEventLoop):      async def _create_connection_transport(              self, sock, protocol_factory, ssl,              server_hostname, server_side=False, -            ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None): +            ssl_handshake_timeout=None):          sock.setblocking(False) @@ -1112,8 +1102,7 @@ class BaseEventLoop(events.AbstractEventLoop):              transport = self._make_ssl_transport(                  sock, protocol, sslcontext, waiter,                  server_side=server_side, server_hostname=server_hostname, -                ssl_handshake_timeout=ssl_handshake_timeout, -                ssl_shutdown_timeout=ssl_shutdown_timeout) +                ssl_handshake_timeout=ssl_handshake_timeout)          else:              transport = self._make_socket_transport(sock, protocol, waiter) @@ -1204,8 +1193,7 @@ class BaseEventLoop(events.AbstractEventLoop):      async def start_tls(self, transport, protocol, sslcontext, *,                          server_side=False,                          server_hostname=None, -                        ssl_handshake_timeout=None, -                        ssl_shutdown_timeout=None): +                        ssl_handshake_timeout=None):          """Upgrade transport to TLS.          Return a new transport that *protocol* should start using @@ -1228,7 +1216,6 @@ class BaseEventLoop(events.AbstractEventLoop):              self, protocol, sslcontext, waiter,              server_side, server_hostname,              ssl_handshake_timeout=ssl_handshake_timeout, -            ssl_shutdown_timeout=ssl_shutdown_timeout,              call_connection_made=False)          # Pause early so that "ssl_protocol.data_received()" doesn't @@ -1427,7 +1414,6 @@ class BaseEventLoop(events.AbstractEventLoop):              reuse_address=None,              reuse_port=None,              ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None,              start_serving=True):          """Create a TCP server. @@ -1451,10 +1437,6 @@ class BaseEventLoop(events.AbstractEventLoop):              raise ValueError(                  'ssl_handshake_timeout is only meaningful with ssl') -        if ssl_shutdown_timeout is not None and ssl is None: -            raise ValueError( -                'ssl_shutdown_timeout is only meaningful with ssl') -          if host is not None or port is not None:              if sock is not None:                  raise ValueError( @@ -1527,8 +1509,7 @@ class BaseEventLoop(events.AbstractEventLoop):              sock.setblocking(False)          server = Server(self, sockets, protocol_factory, -                        ssl, backlog, ssl_handshake_timeout, -                        ssl_shutdown_timeout) +                        ssl, backlog, ssl_handshake_timeout)          if start_serving:              server._start_serving()              # Skip one loop iteration so that all 'loop.add_reader' @@ -1542,8 +1523,7 @@ class BaseEventLoop(events.AbstractEventLoop):      async def connect_accepted_socket(              self, protocol_factory, sock,              *, ssl=None, -            ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None): +            ssl_handshake_timeout=None):          if sock.type != socket.SOCK_STREAM:              raise ValueError(f'A Stream Socket was expected, got {sock!r}') @@ -1551,14 +1531,9 @@ class BaseEventLoop(events.AbstractEventLoop):              raise ValueError(                  'ssl_handshake_timeout is only meaningful with ssl') -        if ssl_shutdown_timeout is not None and not ssl: -            raise ValueError( -                'ssl_shutdown_timeout is only meaningful with ssl') -          transport, protocol = await self._create_connection_transport(              sock, protocol_factory, ssl, '', server_side=True, -            ssl_handshake_timeout=ssl_handshake_timeout, -            ssl_shutdown_timeout=ssl_shutdown_timeout) +            ssl_handshake_timeout=ssl_handshake_timeout)          if self._debug:              # Get the socket from the transport because SSL transport closes              # the old socket and creates a new SSL socket diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py index f171ead28f..33feed60e5 100644 --- a/Lib/asyncio/constants.py +++ b/Lib/asyncio/constants.py @@ -15,17 +15,10 @@ DEBUG_STACK_DEPTH = 10  # The default timeout matches that of Nginx.  SSL_HANDSHAKE_TIMEOUT = 60.0 -# Number of seconds to wait for SSL shutdown to complete -# The default timeout mimics lingering_time -SSL_SHUTDOWN_TIMEOUT = 30.0 -  # Used in sendfile fallback code.  We use fallback for platforms  # that don't support sendfile, or for TLS connections.  SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256 -FLOW_CONTROL_HIGH_WATER_SSL_READ = 256  # KiB -FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512  # KiB -  # The enum should be here to break circular dependencies between  # base_events and sslproto  class _SendfileMode(enum.Enum): diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index d5254fa5e7..b966ad26bf 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -304,7 +304,6 @@ class AbstractEventLoop:              flags=0, sock=None, local_addr=None,              server_hostname=None,              ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None,              happy_eyeballs_delay=None, interleave=None):          raise NotImplementedError @@ -314,7 +313,6 @@ class AbstractEventLoop:              flags=socket.AI_PASSIVE, sock=None, backlog=100,              ssl=None, reuse_address=None, reuse_port=None,              ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None,              start_serving=True):          """A coroutine which creates a TCP server bound to host and port. @@ -355,10 +353,6 @@ class AbstractEventLoop:          will wait for completion of the SSL handshake before aborting the          connection. Default is 60s. -        ssl_shutdown_timeout is the time in seconds that an SSL server -        will wait for completion of the SSL shutdown procedure -        before aborting the connection. Default is 30s. -          start_serving set to True (default) causes the created server          to start accepting connections immediately.  When set to False,          the user should await Server.start_serving() or Server.serve_forever() @@ -377,8 +371,7 @@ class AbstractEventLoop:      async def start_tls(self, transport, protocol, sslcontext, *,                          server_side=False,                          server_hostname=None, -                        ssl_handshake_timeout=None, -                        ssl_shutdown_timeout=None): +                        ssl_handshake_timeout=None):          """Upgrade a transport to TLS.          Return a new transport that *protocol* should start using @@ -390,15 +383,13 @@ class AbstractEventLoop:              self, protocol_factory, path=None, *,              ssl=None, sock=None,              server_hostname=None, -            ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None): +            ssl_handshake_timeout=None):          raise NotImplementedError      async def create_unix_server(              self, protocol_factory, path=None, *,              sock=None, backlog=100, ssl=None,              ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None,              start_serving=True):          """A coroutine which creates a UNIX Domain Socket server. @@ -420,9 +411,6 @@ class AbstractEventLoop:          ssl_handshake_timeout is the time in seconds that an SSL server          will wait for the SSL handshake to complete (defaults to 60s). -        ssl_shutdown_timeout is the time in seconds that an SSL server -        will wait for the SSL shutdown to finish (defaults to 30s). -          start_serving set to True (default) causes the created server          to start accepting connections immediately.  When set to False,          the user should await Server.start_serving() or Server.serve_forever() @@ -433,8 +421,7 @@ class AbstractEventLoop:      async def connect_accepted_socket(              self, protocol_factory, sock,              *, ssl=None, -            ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None): +            ssl_handshake_timeout=None):          """Handle an accepted connection.          This is used by servers that accept connections outside of diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 10852afe2b..45c11ee4b4 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -642,13 +642,11 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):              self, rawsock, protocol, sslcontext, waiter=None,              *, server_side=False, server_hostname=None,              extra=None, server=None, -            ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None): +            ssl_handshake_timeout=None):          ssl_protocol = sslproto.SSLProtocol(                  self, protocol, sslcontext, waiter,                  server_side, server_hostname, -                ssl_handshake_timeout=ssl_handshake_timeout, -                ssl_shutdown_timeout=ssl_shutdown_timeout) +                ssl_handshake_timeout=ssl_handshake_timeout)          _ProactorSocketTransport(self, rawsock, ssl_protocol,                                   extra=extra, server=server)          return ssl_protocol._app_transport @@ -814,8 +812,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):      def _start_serving(self, protocol_factory, sock,                         sslcontext=None, server=None, backlog=100, -                       ssl_handshake_timeout=None, -                       ssl_shutdown_timeout=None): +                       ssl_handshake_timeout=None):          def loop(f=None):              try: @@ -829,8 +826,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):                          self._make_ssl_transport(                              conn, protocol, sslcontext, server_side=True,                              extra={'peername': addr}, server=server, -                            ssl_handshake_timeout=ssl_handshake_timeout, -                            ssl_shutdown_timeout=ssl_shutdown_timeout) +                            ssl_handshake_timeout=ssl_handshake_timeout)                      else:                          self._make_socket_transport(                              conn, protocol, diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 63ab15f30f..59cb6b1bab 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -70,15 +70,11 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):              self, rawsock, protocol, sslcontext, waiter=None,              *, server_side=False, server_hostname=None,              extra=None, server=None, -            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, -            ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, -    ): +            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):          ssl_protocol = sslproto.SSLProtocol( -            self, protocol, sslcontext, waiter, -            server_side, server_hostname, -            ssl_handshake_timeout=ssl_handshake_timeout, -            ssl_shutdown_timeout=ssl_shutdown_timeout -        ) +                self, protocol, sslcontext, waiter, +                server_side, server_hostname, +                ssl_handshake_timeout=ssl_handshake_timeout)          _SelectorSocketTransport(self, rawsock, ssl_protocol,                                   extra=extra, server=server)          return ssl_protocol._app_transport @@ -150,17 +146,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):      def _start_serving(self, protocol_factory, sock,                         sslcontext=None, server=None, backlog=100, -                       ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, -                       ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): +                       ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):          self._add_reader(sock.fileno(), self._accept_connection,                           protocol_factory, sock, sslcontext, server, backlog, -                         ssl_handshake_timeout, ssl_shutdown_timeout) +                         ssl_handshake_timeout)      def _accept_connection(              self, protocol_factory, sock,              sslcontext=None, server=None, backlog=100, -            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, -            ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): +            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):          # This method is only called once for each event loop tick where the          # listening socket has triggered an EVENT_READ. There may be multiple          # connections waiting for an .accept() so it is called in a loop. @@ -191,22 +185,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):                      self.call_later(constants.ACCEPT_RETRY_DELAY,                                      self._start_serving,                                      protocol_factory, sock, sslcontext, server, -                                    backlog, ssl_handshake_timeout, -                                    ssl_shutdown_timeout) +                                    backlog, ssl_handshake_timeout)                  else:                      raise  # The event loop will catch, log and ignore it.              else:                  extra = {'peername': addr}                  accept = self._accept_connection2(                      protocol_factory, conn, extra, sslcontext, server, -                    ssl_handshake_timeout, ssl_shutdown_timeout) +                    ssl_handshake_timeout)                  self.create_task(accept)      async def _accept_connection2(              self, protocol_factory, conn, extra,              sslcontext=None, server=None, -            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, -            ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): +            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):          protocol = None          transport = None          try: @@ -216,8 +208,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):                  transport = self._make_ssl_transport(                      conn, protocol, sslcontext, waiter=waiter,                      server_side=True, extra=extra, server=server, -                    ssl_handshake_timeout=ssl_handshake_timeout, -                    ssl_shutdown_timeout=ssl_shutdown_timeout) +                    ssl_handshake_timeout=ssl_handshake_timeout)              else:                  transport = self._make_socket_transport(                      conn, protocol, waiter=waiter, extra=extra, diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 79734ab63d..cad25b2653 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -1,5 +1,4 @@  import collections -import enum  import warnings  try:      import ssl @@ -7,38 +6,10 @@ except ImportError:  # pragma: no cover      ssl = None  from . import constants -from . import exceptions  from . import protocols  from . import transports  from .log import logger -if ssl is not None: -    SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError) - - -class SSLProtocolState(enum.Enum): -    UNWRAPPED = "UNWRAPPED" -    DO_HANDSHAKE = "DO_HANDSHAKE" -    WRAPPED = "WRAPPED" -    FLUSHING = "FLUSHING" -    SHUTDOWN = "SHUTDOWN" - - -class AppProtocolState(enum.Enum): -    # This tracks the state of app protocol (https://git.io/fj59P): -    # -    #     INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST -    # -    # * cm: connection_made() -    # * dr: data_received() -    # * er: eof_received() -    # * cl: connection_lost() - -    STATE_INIT = "STATE_INIT" -    STATE_CON_MADE = "STATE_CON_MADE" -    STATE_EOF = "STATE_EOF" -    STATE_CON_LOST = "STATE_CON_LOST" -  def _create_transport_context(server_side, server_hostname):      if server_side: @@ -54,35 +25,269 @@ def _create_transport_context(server_side, server_hostname):      return sslcontext -def add_flowcontrol_defaults(high, low, kb): -    if high is None: -        if low is None: -            hi = kb * 1024 -        else: -            lo = low -            hi = 4 * lo -    else: -        hi = high -    if low is None: -        lo = hi // 4 -    else: -        lo = low +# States of an _SSLPipe. +_UNWRAPPED = "UNWRAPPED" +_DO_HANDSHAKE = "DO_HANDSHAKE" +_WRAPPED = "WRAPPED" +_SHUTDOWN = "SHUTDOWN" + + +class _SSLPipe(object): +    """An SSL "Pipe". + +    An SSL pipe allows you to communicate with an SSL/TLS protocol instance +    through memory buffers. It can be used to implement a security layer for an +    existing connection where you don't have access to the connection's file +    descriptor, or for some reason you don't want to use it. + +    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, +    data is passed through untransformed. In wrapped mode, application level +    data is encrypted to SSL record level data and vice versa. The SSL record +    level is the lowest level in the SSL protocol suite and is what travels +    as-is over the wire. + +    An SslPipe initially is in "unwrapped" mode. To start SSL, call +    do_handshake(). To shutdown SSL again, call unwrap(). +    """ + +    max_size = 256 * 1024   # Buffer size passed to read() + +    def __init__(self, context, server_side, server_hostname=None): +        """ +        The *context* argument specifies the ssl.SSLContext to use. + +        The *server_side* argument indicates whether this is a server side or +        client side transport. + +        The optional *server_hostname* argument can be used to specify the +        hostname you are connecting to. You may only specify this parameter if +        the _ssl module supports Server Name Indication (SNI). +        """ +        self._context = context +        self._server_side = server_side +        self._server_hostname = server_hostname +        self._state = _UNWRAPPED +        self._incoming = ssl.MemoryBIO() +        self._outgoing = ssl.MemoryBIO() +        self._sslobj = None +        self._need_ssldata = False +        self._handshake_cb = None +        self._shutdown_cb = None + +    @property +    def context(self): +        """The SSL context passed to the constructor.""" +        return self._context + +    @property +    def ssl_object(self): +        """The internal ssl.SSLObject instance. + +        Return None if the pipe is not wrapped. +        """ +        return self._sslobj + +    @property +    def need_ssldata(self): +        """Whether more record level data is needed to complete a handshake +        that is currently in progress.""" +        return self._need_ssldata + +    @property +    def wrapped(self): +        """ +        Whether a security layer is currently in effect. + +        Return False during handshake. +        """ +        return self._state == _WRAPPED + +    def do_handshake(self, callback=None): +        """Start the SSL handshake. + +        Return a list of ssldata. A ssldata element is a list of buffers + +        The optional *callback* argument can be used to install a callback that +        will be called when the handshake is complete. The callback will be +        called with None if successful, else an exception instance. +        """ +        if self._state != _UNWRAPPED: +            raise RuntimeError('handshake in progress or completed') +        self._sslobj = self._context.wrap_bio( +            self._incoming, self._outgoing, +            server_side=self._server_side, +            server_hostname=self._server_hostname) +        self._state = _DO_HANDSHAKE +        self._handshake_cb = callback +        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) +        assert len(appdata) == 0 +        return ssldata + +    def shutdown(self, callback=None): +        """Start the SSL shutdown sequence. + +        Return a list of ssldata. A ssldata element is a list of buffers -    if not hi >= lo >= 0: -        raise ValueError('high (%r) must be >= low (%r) must be >= 0' % -                         (hi, lo)) +        The optional *callback* argument can be used to install a callback that +        will be called when the shutdown is complete. The callback will be +        called without arguments. +        """ +        if self._state == _UNWRAPPED: +            raise RuntimeError('no security layer present') +        if self._state == _SHUTDOWN: +            raise RuntimeError('shutdown in progress') +        assert self._state in (_WRAPPED, _DO_HANDSHAKE) +        self._state = _SHUTDOWN +        self._shutdown_cb = callback +        ssldata, appdata = self.feed_ssldata(b'') +        assert appdata == [] or appdata == [b''] +        return ssldata + +    def feed_eof(self): +        """Send a potentially "ragged" EOF. + +        This method will raise an SSL_ERROR_EOF exception if the EOF is +        unexpected. +        """ +        self._incoming.write_eof() +        ssldata, appdata = self.feed_ssldata(b'') +        assert appdata == [] or appdata == [b''] + +    def feed_ssldata(self, data, only_handshake=False): +        """Feed SSL record level data into the pipe. + +        The data must be a bytes instance. It is OK to send an empty bytes +        instance. This can be used to get ssldata for a handshake initiated by +        this endpoint. + +        Return a (ssldata, appdata) tuple. The ssldata element is a list of +        buffers containing SSL data that needs to be sent to the remote SSL. + +        The appdata element is a list of buffers containing plaintext data that +        needs to be forwarded to the application. The appdata list may contain +        an empty buffer indicating an SSL "close_notify" alert. This alert must +        be acknowledged by calling shutdown(). +        """ +        if self._state == _UNWRAPPED: +            # If unwrapped, pass plaintext data straight through. +            if data: +                appdata = [data] +            else: +                appdata = [] +            return ([], appdata) + +        self._need_ssldata = False +        if data: +            self._incoming.write(data) + +        ssldata = [] +        appdata = [] +        try: +            if self._state == _DO_HANDSHAKE: +                # Call do_handshake() until it doesn't raise anymore. +                self._sslobj.do_handshake() +                self._state = _WRAPPED +                if self._handshake_cb: +                    self._handshake_cb(None) +                if only_handshake: +                    return (ssldata, appdata) +                # Handshake done: execute the wrapped block + +            if self._state == _WRAPPED: +                # Main state: read data from SSL until close_notify +                while True: +                    chunk = self._sslobj.read(self.max_size) +                    appdata.append(chunk) +                    if not chunk:  # close_notify +                        break + +            elif self._state == _SHUTDOWN: +                # Call shutdown() until it doesn't raise anymore. +                self._sslobj.unwrap() +                self._sslobj = None +                self._state = _UNWRAPPED +                if self._shutdown_cb: +                    self._shutdown_cb() + +            elif self._state == _UNWRAPPED: +                # Drain possible plaintext data after close_notify. +                appdata.append(self._incoming.read()) +        except (ssl.SSLError, ssl.CertificateError) as exc: +            exc_errno = getattr(exc, 'errno', None) +            if exc_errno not in ( +                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, +                    ssl.SSL_ERROR_SYSCALL): +                if self._state == _DO_HANDSHAKE and self._handshake_cb: +                    self._handshake_cb(exc) +                raise +            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) + +        # Check for record level data that needs to be sent back. +        # Happens for the initial handshake and renegotiations. +        if self._outgoing.pending: +            ssldata.append(self._outgoing.read()) +        return (ssldata, appdata) + +    def feed_appdata(self, data, offset=0): +        """Feed plaintext data into the pipe. + +        Return an (ssldata, offset) tuple. The ssldata element is a list of +        buffers containing record level data that needs to be sent to the +        remote SSL instance. The offset is the number of plaintext bytes that +        were processed, which may be less than the length of data. + +        NOTE: In case of short writes, this call MUST be retried with the SAME +        buffer passed into the *data* argument (i.e. the id() must be the +        same). This is an OpenSSL requirement. A further particularity is that +        a short write will always have offset == 0, because the _ssl module +        does not enable partial writes. And even though the offset is zero, +        there will still be encrypted data in ssldata. +        """ +        assert 0 <= offset <= len(data) +        if self._state == _UNWRAPPED: +            # pass through data in unwrapped mode +            if offset < len(data): +                ssldata = [data[offset:]] +            else: +                ssldata = [] +            return (ssldata, len(data)) -    return hi, lo +        ssldata = [] +        view = memoryview(data) +        while True: +            self._need_ssldata = False +            try: +                if offset < len(view): +                    offset += self._sslobj.write(view[offset:]) +            except ssl.SSLError as exc: +                # It is not allowed to call write() after unwrap() until the +                # close_notify is acknowledged. We return the condition to the +                # caller as a short write. +                exc_errno = getattr(exc, 'errno', None) +                if exc.reason == 'PROTOCOL_IS_SHUTDOWN': +                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ +                if exc_errno not in (ssl.SSL_ERROR_WANT_READ, +                                     ssl.SSL_ERROR_WANT_WRITE, +                                     ssl.SSL_ERROR_SYSCALL): +                    raise +                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) + +            # See if there's any record level data back for us. +            if self._outgoing.pending: +                ssldata.append(self._outgoing.read()) +            if offset == len(view) or self._need_ssldata: +                break +        return (ssldata, offset)  class _SSLProtocolTransport(transports._FlowControlMixin,                              transports.Transport): -    _start_tls_compatible = True      _sendfile_compatible = constants._SendfileMode.FALLBACK      def __init__(self, loop, ssl_protocol):          self._loop = loop +        # SSLProtocol instance          self._ssl_protocol = ssl_protocol          self._closed = False @@ -110,15 +315,16 @@ class _SSLProtocolTransport(transports._FlowControlMixin,          self._closed = True          self._ssl_protocol._start_shutdown() -    def __del__(self, _warnings=warnings): +    def __del__(self, _warn=warnings.warn):          if not self._closed: -            self._closed = True -            _warnings.warn( -                "unclosed transport <asyncio._SSLProtocolTransport " -                "object>", ResourceWarning) +            _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) +            self.close()      def is_reading(self): -        return not self._ssl_protocol._app_reading_paused +        tr = self._ssl_protocol._transport +        if tr is None: +            raise RuntimeError('SSL transport has not been initialized yet') +        return tr.is_reading()      def pause_reading(self):          """Pause the receiving end. @@ -126,7 +332,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,          No data will be passed to the protocol's data_received()          method until resume_reading() is called.          """ -        self._ssl_protocol._pause_reading() +        self._ssl_protocol._transport.pause_reading()      def resume_reading(self):          """Resume the receiving end. @@ -134,7 +340,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,          Data received will once again be passed to the protocol's          data_received() method.          """ -        self._ssl_protocol._resume_reading() +        self._ssl_protocol._transport.resume_reading()      def set_write_buffer_limits(self, high=None, low=None):          """Set the high- and low-water limits for write flow control. @@ -155,51 +361,16 @@ class _SSLProtocolTransport(transports._FlowControlMixin,          reduces opportunities for doing I/O and computation          concurrently.          """ -        self._ssl_protocol._set_write_buffer_limits(high, low) -        self._ssl_protocol._control_app_writing() - -    def get_write_buffer_limits(self): -        return (self._ssl_protocol._outgoing_low_water, -                self._ssl_protocol._outgoing_high_water) +        self._ssl_protocol._transport.set_write_buffer_limits(high, low)      def get_write_buffer_size(self): -        """Return the current size of the write buffers.""" -        return self._ssl_protocol._get_write_buffer_size() - -    def set_read_buffer_limits(self, high=None, low=None): -        """Set the high- and low-water limits for read flow control. - -        These two values control when to call the upstream transport's -        pause_reading() and resume_reading() methods.  If specified, -        the low-water limit must be less than or equal to the -        high-water limit.  Neither value can be negative. - -        The defaults are implementation-specific.  If only the -        high-water limit is given, the low-water limit defaults to an -        implementation-specific value less than or equal to the -        high-water limit.  Setting high to zero forces low to zero as -        well, and causes pause_reading() to be called whenever the -        buffer becomes non-empty.  Setting low to zero causes -        resume_reading() to be called only once the buffer is empty. -        Use of zero for either limit is generally sub-optimal as it -        reduces opportunities for doing I/O and computation -        concurrently. -        """ -        self._ssl_protocol._set_read_buffer_limits(high, low) -        self._ssl_protocol._control_ssl_reading() - -    def get_read_buffer_limits(self): -        return (self._ssl_protocol._incoming_low_water, -                self._ssl_protocol._incoming_high_water) - -    def get_read_buffer_size(self): -        """Return the current size of the read buffer.""" -        return self._ssl_protocol._get_read_buffer_size() +        """Return the current size of the write buffer.""" +        return self._ssl_protocol._transport.get_write_buffer_size()      @property      def _protocol_paused(self):          # Required for sendfile fallback pause_writing/resume_writing logic -        return self._ssl_protocol._app_writing_paused +        return self._ssl_protocol._transport._protocol_paused      def write(self, data):          """Write some data bytes to the transport. @@ -212,22 +383,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,                              f"got {type(data).__name__}")          if not data:              return -        self._ssl_protocol._write_appdata((data,)) - -    def writelines(self, list_of_data): -        """Write a list (or any iterable) of data bytes to the transport. - -        The default implementation concatenates the arguments and -        calls write() on the result. -        """ -        self._ssl_protocol._write_appdata(list_of_data) - -    def write_eof(self): -        """Close the write end after flushing buffered data. - -        This raises :exc:`NotImplementedError` right now. -        """ -        raise NotImplementedError +        self._ssl_protocol._write_appdata(data)      def can_write_eof(self):          """Return True if this transport supports write_eof(), False if not.""" @@ -240,36 +396,23 @@ class _SSLProtocolTransport(transports._FlowControlMixin,          The protocol's connection_lost() method will (eventually) be          called with None as its argument.          """ -        self._closed = True          self._ssl_protocol._abort() - -    def _force_close(self, exc):          self._closed = True -        self._ssl_protocol._abort(exc) -    def _test__append_write_backlog(self, data): -        # for test only -        self._ssl_protocol._write_backlog.append(data) -        self._ssl_protocol._write_buffer_size += len(data) +class SSLProtocol(protocols.Protocol): +    """SSL protocol. -class SSLProtocol(protocols.BufferedProtocol): -    max_size = 256 * 1024   # Buffer size passed to read() - -    _handshake_start_time = None -    _handshake_timeout_handle = None -    _shutdown_timeout_handle = None +    Implementation of SSL on top of a socket using incoming and outgoing +    buffers which are ssl.MemoryBIO objects. +    """      def __init__(self, loop, app_protocol, sslcontext, waiter,                   server_side=False, server_hostname=None,                   call_connection_made=True, -                 ssl_handshake_timeout=None, -                 ssl_shutdown_timeout=None): +                 ssl_handshake_timeout=None):          if ssl is None: -            raise RuntimeError("stdlib ssl module not available") - -        self._ssl_buffer = bytearray(self.max_size) -        self._ssl_buffer_view = memoryview(self._ssl_buffer) +            raise RuntimeError('stdlib ssl module not available')          if ssl_handshake_timeout is None:              ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT @@ -277,12 +420,6 @@ class SSLProtocol(protocols.BufferedProtocol):              raise ValueError(                  f"ssl_handshake_timeout should be a positive number, "                  f"got {ssl_handshake_timeout}") -        if ssl_shutdown_timeout is None: -            ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT -        elif ssl_shutdown_timeout <= 0: -            raise ValueError( -                f"ssl_shutdown_timeout should be a positive number, " -                f"got {ssl_shutdown_timeout}")          if not sslcontext:              sslcontext = _create_transport_context( @@ -305,54 +442,21 @@ class SSLProtocol(protocols.BufferedProtocol):          self._waiter = waiter          self._loop = loop          self._set_app_protocol(app_protocol) -        self._app_transport = None -        self._app_transport_created = False +        self._app_transport = _SSLProtocolTransport(self._loop, self) +        # _SSLPipe instance (None until the connection is made) +        self._sslpipe = None +        self._session_established = False +        self._in_handshake = False +        self._in_shutdown = False          # transport, ex: SelectorSocketTransport          self._transport = None +        self._call_connection_made = call_connection_made          self._ssl_handshake_timeout = ssl_handshake_timeout -        self._ssl_shutdown_timeout = ssl_shutdown_timeout -        # SSL and state machine -        self._incoming = ssl.MemoryBIO() -        self._outgoing = ssl.MemoryBIO() -        self._state = SSLProtocolState.UNWRAPPED -        self._conn_lost = 0  # Set when connection_lost called -        if call_connection_made: -            self._app_state = AppProtocolState.STATE_INIT -        else: -            self._app_state = AppProtocolState.STATE_CON_MADE -        self._sslobj = self._sslcontext.wrap_bio( -            self._incoming, self._outgoing, -            server_side=self._server_side, -            server_hostname=self._server_hostname) - -        # Flow Control - -        self._ssl_writing_paused = False - -        self._app_reading_paused = False - -        self._ssl_reading_paused = False -        self._incoming_high_water = 0 -        self._incoming_low_water = 0 -        self._set_read_buffer_limits() -        self._eof_received = False - -        self._app_writing_paused = False -        self._outgoing_high_water = 0 -        self._outgoing_low_water = 0 -        self._set_write_buffer_limits() -        self._get_app_transport()      def _set_app_protocol(self, app_protocol):          self._app_protocol = app_protocol -        # Make fast hasattr check first -        if (hasattr(app_protocol, 'get_buffer') and -                isinstance(app_protocol, protocols.BufferedProtocol)): -            self._app_protocol_get_buffer = app_protocol.get_buffer -            self._app_protocol_buffer_updated = app_protocol.buffer_updated -            self._app_protocol_is_buffer = True -        else: -            self._app_protocol_is_buffer = False +        self._app_protocol_is_buffer = \ +            isinstance(app_protocol, protocols.BufferedProtocol)      def _wakeup_waiter(self, exc=None):          if self._waiter is None: @@ -364,20 +468,15 @@ class SSLProtocol(protocols.BufferedProtocol):                  self._waiter.set_result(None)          self._waiter = None -    def _get_app_transport(self): -        if self._app_transport is None: -            if self._app_transport_created: -                raise RuntimeError('Creating _SSLProtocolTransport twice') -            self._app_transport = _SSLProtocolTransport(self._loop, self) -            self._app_transport_created = True -        return self._app_transport -      def connection_made(self, transport):          """Called when the low-level connection is made.          Start the SSL handshake.          """          self._transport = transport +        self._sslpipe = _SSLPipe(self._sslcontext, +                                 self._server_side, +                                 self._server_hostname)          self._start_handshake()      def connection_lost(self, exc): @@ -387,58 +486,72 @@ class SSLProtocol(protocols.BufferedProtocol):          meaning a regular EOF is received or the connection was          aborted or closed).          """ -        self._write_backlog.clear() -        self._outgoing.read() -        self._conn_lost += 1 - -        # Just mark the app transport as closed so that its __dealloc__ -        # doesn't complain. -        if self._app_transport is not None: -            self._app_transport._closed = True - -        if self._state != SSLProtocolState.DO_HANDSHAKE: -            if ( -                self._app_state == AppProtocolState.STATE_CON_MADE or -                self._app_state == AppProtocolState.STATE_EOF -            ): -                self._app_state = AppProtocolState.STATE_CON_LOST -                self._loop.call_soon(self._app_protocol.connection_lost, exc) -        self._set_state(SSLProtocolState.UNWRAPPED) +        if self._session_established: +            self._session_established = False +            self._loop.call_soon(self._app_protocol.connection_lost, exc) +        else: +            # Most likely an exception occurred while in SSL handshake. +            # Just mark the app transport as closed so that its __del__ +            # doesn't complain. +            if self._app_transport is not None: +                self._app_transport._closed = True          self._transport = None          self._app_transport = None -        self._app_protocol = None +        if getattr(self, '_handshake_timeout_handle', None): +            self._handshake_timeout_handle.cancel()          self._wakeup_waiter(exc) +        self._app_protocol = None +        self._sslpipe = None -        if self._shutdown_timeout_handle: -            self._shutdown_timeout_handle.cancel() -            self._shutdown_timeout_handle = None -        if self._handshake_timeout_handle: -            self._handshake_timeout_handle.cancel() -            self._handshake_timeout_handle = None +    def pause_writing(self): +        """Called when the low-level transport's buffer goes over +        the high-water mark. +        """ +        self._app_protocol.pause_writing() -    def get_buffer(self, n): -        want = n -        if want <= 0 or want > self.max_size: -            want = self.max_size -        if len(self._ssl_buffer) < want: -            self._ssl_buffer = bytearray(want) -            self._ssl_buffer_view = memoryview(self._ssl_buffer) -        return self._ssl_buffer_view +    def resume_writing(self): +        """Called when the low-level transport's buffer drains below +        the low-water mark. +        """ +        self._app_protocol.resume_writing() -    def buffer_updated(self, nbytes): -        self._incoming.write(self._ssl_buffer_view[:nbytes]) +    def data_received(self, data): +        """Called when some SSL data is received. -        if self._state == SSLProtocolState.DO_HANDSHAKE: -            self._do_handshake() +        The argument is a bytes object. +        """ +        if self._sslpipe is None: +            # transport closing, sslpipe is destroyed +            return -        elif self._state == SSLProtocolState.WRAPPED: -            self._do_read() +        try: +            ssldata, appdata = self._sslpipe.feed_ssldata(data) +        except (SystemExit, KeyboardInterrupt): +            raise +        except BaseException as e: +            self._fatal_error(e, 'SSL error in data received') +            return -        elif self._state == SSLProtocolState.FLUSHING: -            self._do_flush() +        for chunk in ssldata: +            self._transport.write(chunk) -        elif self._state == SSLProtocolState.SHUTDOWN: -            self._do_shutdown() +        for chunk in appdata: +            if chunk: +                try: +                    if self._app_protocol_is_buffer: +                        protocols._feed_data_to_buffered_proto( +                            self._app_protocol, chunk) +                    else: +                        self._app_protocol.data_received(chunk) +                except (SystemExit, KeyboardInterrupt): +                    raise +                except BaseException as ex: +                    self._fatal_error( +                        ex, 'application protocol failed to receive SSL data') +                    return +            else: +                self._start_shutdown() +                break      def eof_received(self):          """Called when the other end of the low-level stream @@ -448,32 +561,19 @@ class SSLProtocol(protocols.BufferedProtocol):          will close itself.  If it returns a true value, closing the          transport is up to the protocol.          """ -        self._eof_received = True          try:              if self._loop.get_debug():                  logger.debug("%r received EOF", self) -            if self._state == SSLProtocolState.DO_HANDSHAKE: -                self._on_handshake_complete(ConnectionResetError) - -            elif self._state == SSLProtocolState.WRAPPED: -                self._set_state(SSLProtocolState.FLUSHING) -                if self._app_reading_paused: -                    return True -                else: -                    self._do_flush() - -            elif self._state == SSLProtocolState.FLUSHING: -                self._do_write() -                self._set_state(SSLProtocolState.SHUTDOWN) -                self._do_shutdown() +            self._wakeup_waiter(ConnectionResetError) -            elif self._state == SSLProtocolState.SHUTDOWN: -                self._do_shutdown() - -        except Exception: +            if not self._in_handshake: +                keep_open = self._app_protocol.eof_received() +                if keep_open: +                    logger.warning('returning true from eof_received() ' +                                   'has no effect when using ssl') +        finally:              self._transport.close() -            raise      def _get_extra_info(self, name, default=None):          if name in self._extra: @@ -483,45 +583,19 @@ class SSLProtocol(protocols.BufferedProtocol):          else:              return default -    def _set_state(self, new_state): -        allowed = False - -        if new_state == SSLProtocolState.UNWRAPPED: -            allowed = True - -        elif ( -            self._state == SSLProtocolState.UNWRAPPED and -            new_state == SSLProtocolState.DO_HANDSHAKE -        ): -            allowed = True - -        elif ( -            self._state == SSLProtocolState.DO_HANDSHAKE and -            new_state == SSLProtocolState.WRAPPED -        ): -            allowed = True - -        elif ( -            self._state == SSLProtocolState.WRAPPED and -            new_state == SSLProtocolState.FLUSHING -        ): -            allowed = True - -        elif ( -            self._state == SSLProtocolState.FLUSHING and -            new_state == SSLProtocolState.SHUTDOWN -        ): -            allowed = True - -        if allowed: -            self._state = new_state - +    def _start_shutdown(self): +        if self._in_shutdown: +            return +        if self._in_handshake: +            self._abort()          else: -            raise RuntimeError( -                'cannot switch state from {} to {}'.format( -                    self._state, new_state)) +            self._in_shutdown = True +            self._write_appdata(b'') -    # Handshake flow +    def _write_appdata(self, data): +        self._write_backlog.append((data, 0)) +        self._write_buffer_size += len(data) +        self._process_write_backlog()      def _start_handshake(self):          if self._loop.get_debug(): @@ -529,18 +603,17 @@ class SSLProtocol(protocols.BufferedProtocol):              self._handshake_start_time = self._loop.time()          else:              self._handshake_start_time = None - -        self._set_state(SSLProtocolState.DO_HANDSHAKE) - -        # start handshake timeout count down +        self._in_handshake = True +        # (b'', 1) is a special value in _process_write_backlog() to do +        # the SSL handshake +        self._write_backlog.append((b'', 1))          self._handshake_timeout_handle = \              self._loop.call_later(self._ssl_handshake_timeout, -                                  lambda: self._check_handshake_timeout()) - -        self._do_handshake() +                                  self._check_handshake_timeout) +        self._process_write_backlog()      def _check_handshake_timeout(self): -        if self._state == SSLProtocolState.DO_HANDSHAKE: +        if self._in_handshake is True:              msg = (                  f"SSL handshake is taking longer than "                  f"{self._ssl_handshake_timeout} seconds: " @@ -548,37 +621,24 @@ class SSLProtocol(protocols.BufferedProtocol):              )              self._fatal_error(ConnectionAbortedError(msg)) -    def _do_handshake(self): -        try: -            self._sslobj.do_handshake() -        except SSLAgainErrors: -            self._process_outgoing() -        except ssl.SSLError as exc: -            self._on_handshake_complete(exc) -        else: -            self._on_handshake_complete(None) -      def _on_handshake_complete(self, handshake_exc): -        if self._handshake_timeout_handle is not None: -            self._handshake_timeout_handle.cancel() -            self._handshake_timeout_handle = None +        self._in_handshake = False +        self._handshake_timeout_handle.cancel() -        sslobj = self._sslobj +        sslobj = self._sslpipe.ssl_object          try: -            if handshake_exc is None: -                self._set_state(SSLProtocolState.WRAPPED) -            else: +            if handshake_exc is not None:                  raise handshake_exc              peercert = sslobj.getpeercert() -        except Exception as exc: -            self._set_state(SSLProtocolState.UNWRAPPED) +        except (SystemExit, KeyboardInterrupt): +            raise +        except BaseException as exc:              if isinstance(exc, ssl.CertificateError):                  msg = 'SSL handshake failed on verifying the certificate'              else:                  msg = 'SSL handshake failed'              self._fatal_error(exc, msg) -            self._wakeup_waiter(exc)              return          if self._loop.get_debug(): @@ -589,330 +649,85 @@ class SSLProtocol(protocols.BufferedProtocol):          self._extra.update(peercert=peercert,                             cipher=sslobj.cipher(),                             compression=sslobj.compression(), -                           ssl_object=sslobj) -        if self._app_state == AppProtocolState.STATE_INIT: -            self._app_state = AppProtocolState.STATE_CON_MADE -            self._app_protocol.connection_made(self._get_app_transport()) +                           ssl_object=sslobj, +                           ) +        if self._call_connection_made: +            self._app_protocol.connection_made(self._app_transport)          self._wakeup_waiter() -        self._do_read() - -    # Shutdown flow - -    def _start_shutdown(self): -        if ( -            self._state in ( -                SSLProtocolState.FLUSHING, -                SSLProtocolState.SHUTDOWN, -                SSLProtocolState.UNWRAPPED -            ) -        ): -            return -        if self._app_transport is not None: -            self._app_transport._closed = True -        if self._state == SSLProtocolState.DO_HANDSHAKE: -            self._abort() -        else: -            self._set_state(SSLProtocolState.FLUSHING) -            self._shutdown_timeout_handle = self._loop.call_later( -                self._ssl_shutdown_timeout, -                lambda: self._check_shutdown_timeout() -            ) -            self._do_flush() - -    def _check_shutdown_timeout(self): -        if ( -            self._state in ( -                SSLProtocolState.FLUSHING, -                SSLProtocolState.SHUTDOWN -            ) -        ): -            self._transport._force_close( -                exceptions.TimeoutError('SSL shutdown timed out')) - -    def _do_flush(self): -        self._do_read() -        self._set_state(SSLProtocolState.SHUTDOWN) -        self._do_shutdown() - -    def _do_shutdown(self): -        try: -            if not self._eof_received: -                self._sslobj.unwrap() -        except SSLAgainErrors: -            self._process_outgoing() -        except ssl.SSLError as exc: -            self._on_shutdown_complete(exc) -        else: -            self._process_outgoing() -            self._call_eof_received() -            self._on_shutdown_complete(None) - -    def _on_shutdown_complete(self, shutdown_exc): -        if self._shutdown_timeout_handle is not None: -            self._shutdown_timeout_handle.cancel() -            self._shutdown_timeout_handle = None - -        if shutdown_exc: -            self._fatal_error(shutdown_exc) -        else: -            self._loop.call_soon(self._transport.close) - -    def _abort(self): -        self._set_state(SSLProtocolState.UNWRAPPED) -        if self._transport is not None: -            self._transport.abort() - -    # Outgoing flow - -    def _write_appdata(self, list_of_data): -        if ( -            self._state in ( -                SSLProtocolState.FLUSHING, -                SSLProtocolState.SHUTDOWN, -                SSLProtocolState.UNWRAPPED -            ) -        ): -            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: -                logger.warning('SSL connection is closed') -            self._conn_lost += 1 +        self._session_established = True +        # In case transport.write() was already called. Don't call +        # immediately _process_write_backlog(), but schedule it: +        # _on_handshake_complete() can be called indirectly from +        # _process_write_backlog(), and _process_write_backlog() is not +        # reentrant. +        self._loop.call_soon(self._process_write_backlog) + +    def _process_write_backlog(self): +        # Try to make progress on the write backlog. +        if self._transport is None or self._sslpipe is None:              return -        for data in list_of_data: -            self._write_backlog.append(data) -            self._write_buffer_size += len(data) -          try: -            if self._state == SSLProtocolState.WRAPPED: -                self._do_write() - -        except Exception as ex: -            self._fatal_error(ex, 'Fatal error on SSL protocol') - -    def _do_write(self): -        try: -            while self._write_backlog: -                data = self._write_backlog[0] -                count = self._sslobj.write(data) -                data_len = len(data) -                if count < data_len: -                    self._write_backlog[0] = data[count:] -                    self._write_buffer_size -= count +            for i in range(len(self._write_backlog)): +                data, offset = self._write_backlog[0] +                if data: +                    ssldata, offset = self._sslpipe.feed_appdata(data, offset) +                elif offset: +                    ssldata = self._sslpipe.do_handshake( +                        self._on_handshake_complete) +                    offset = 1                  else: -                    del self._write_backlog[0] -                    self._write_buffer_size -= data_len -        except SSLAgainErrors: -            pass -        self._process_outgoing() - -    def _process_outgoing(self): -        if not self._ssl_writing_paused: -            data = self._outgoing.read() -            if len(data): -                self._transport.write(data) -        self._control_app_writing() - -    # Incoming flow - -    def _do_read(self): -        if ( -            self._state not in ( -                SSLProtocolState.WRAPPED, -                SSLProtocolState.FLUSHING, -            ) -        ): -            return -        try: -            if not self._app_reading_paused: -                if self._app_protocol_is_buffer: -                    self._do_read__buffered() -                else: -                    self._do_read__copied() -                if self._write_backlog: -                    self._do_write() -                else: -                    self._process_outgoing() -            self._control_ssl_reading() -        except Exception as ex: -            self._fatal_error(ex, 'Fatal error on SSL protocol') - -    def _do_read__buffered(self): -        offset = 0 -        count = 1 - -        buf = self._app_protocol_get_buffer(self._get_read_buffer_size()) -        wants = len(buf) - -        try: -            count = self._sslobj.read(wants, buf) - -            if count > 0: -                offset = count -                while offset < wants: -                    count = self._sslobj.read(wants - offset, buf[offset:]) -                    if count > 0: -                        offset += count -                    else: -                        break -                else: -                    self._loop.call_soon(lambda: self._do_read()) -        except SSLAgainErrors: -            pass -        if offset > 0: -            self._app_protocol_buffer_updated(offset) -        if not count: -            # close_notify -            self._call_eof_received() -            self._start_shutdown() - -    def _do_read__copied(self): -        chunk = b'1' -        zero = True -        one = False - -        try: -            while True: -                chunk = self._sslobj.read(self.max_size) -                if not chunk: +                    ssldata = self._sslpipe.shutdown(self._finalize) +                    offset = 1 + +                for chunk in ssldata: +                    self._transport.write(chunk) + +                if offset < len(data): +                    self._write_backlog[0] = (data, offset) +                    # A short write means that a write is blocked on a read +                    # We need to enable reading if it is paused! +                    assert self._sslpipe.need_ssldata +                    if self._transport._paused: +                        self._transport.resume_reading()                      break -                if zero: -                    zero = False -                    one = True -                    first = chunk -                elif one: -                    one = False -                    data = [first, chunk] -                else: -                    data.append(chunk) -        except SSLAgainErrors: -            pass -        if one: -            self._app_protocol.data_received(first) -        elif not zero: -            self._app_protocol.data_received(b''.join(data)) -        if not chunk: -            # close_notify -            self._call_eof_received() -            self._start_shutdown() - -    def _call_eof_received(self): -        try: -            if self._app_state == AppProtocolState.STATE_CON_MADE: -                self._app_state = AppProtocolState.STATE_EOF -                keep_open = self._app_protocol.eof_received() -                if keep_open: -                    logger.warning('returning true from eof_received() ' -                                   'has no effect when using ssl') -        except (KeyboardInterrupt, SystemExit): -            raise -        except BaseException as ex: -            self._fatal_error(ex, 'Error calling eof_received()') - -    # Flow control for writes from APP socket -    def _control_app_writing(self): -        size = self._get_write_buffer_size() -        if size >= self._outgoing_high_water and not self._app_writing_paused: -            self._app_writing_paused = True -            try: -                self._app_protocol.pause_writing() -            except (KeyboardInterrupt, SystemExit): -                raise -            except BaseException as exc: -                self._loop.call_exception_handler({ -                    'message': 'protocol.pause_writing() failed', -                    'exception': exc, -                    'transport': self._app_transport, -                    'protocol': self, -                }) -        elif size <= self._outgoing_low_water and self._app_writing_paused: -            self._app_writing_paused = False -            try: -                self._app_protocol.resume_writing() -            except (KeyboardInterrupt, SystemExit): -                raise -            except BaseException as exc: -                self._loop.call_exception_handler({ -                    'message': 'protocol.resume_writing() failed', -                    'exception': exc, -                    'transport': self._app_transport, -                    'protocol': self, -                }) - -    def _get_write_buffer_size(self): -        return self._outgoing.pending + self._write_buffer_size - -    def _set_write_buffer_limits(self, high=None, low=None): -        high, low = add_flowcontrol_defaults( -            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE) -        self._outgoing_high_water = high -        self._outgoing_low_water = low - -    # Flow control for reads to APP socket - -    def _pause_reading(self): -        self._app_reading_paused = True - -    def _resume_reading(self): -        if self._app_reading_paused: -            self._app_reading_paused = False - -            def resume(): -                if self._state == SSLProtocolState.WRAPPED: -                    self._do_read() -                elif self._state == SSLProtocolState.FLUSHING: -                    self._do_flush() -                elif self._state == SSLProtocolState.SHUTDOWN: -                    self._do_shutdown() -            self._loop.call_soon(resume) - -    # Flow control for reads from SSL socket - -    def _control_ssl_reading(self): -        size = self._get_read_buffer_size() -        if size >= self._incoming_high_water and not self._ssl_reading_paused: -            self._ssl_reading_paused = True -            self._transport.pause_reading() -        elif size <= self._incoming_low_water and self._ssl_reading_paused: -            self._ssl_reading_paused = False -            self._transport.resume_reading() - -    def _set_read_buffer_limits(self, high=None, low=None): -        high, low = add_flowcontrol_defaults( -            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ) -        self._incoming_high_water = high -        self._incoming_low_water = low - -    def _get_read_buffer_size(self): -        return self._incoming.pending - -    # Flow control for writes to SSL socket - -    def pause_writing(self): -        """Called when the low-level transport's buffer goes over -        the high-water mark. -        """ -        assert not self._ssl_writing_paused -        self._ssl_writing_paused = True - -    def resume_writing(self): -        """Called when the low-level transport's buffer drains below -        the low-water mark. -        """ -        assert self._ssl_writing_paused -        self._ssl_writing_paused = False -        self._process_outgoing() +                # An entire chunk from the backlog was processed. We can +                # delete it and reduce the outstanding buffer size. +                del self._write_backlog[0] +                self._write_buffer_size -= len(data) +        except (SystemExit, KeyboardInterrupt): +            raise +        except BaseException as exc: +            if self._in_handshake: +                # Exceptions will be re-raised in _on_handshake_complete. +                self._on_handshake_complete(exc) +            else: +                self._fatal_error(exc, 'Fatal error on SSL transport')      def _fatal_error(self, exc, message='Fatal error on transport'): -        if self._transport: -            self._transport._force_close(exc) -          if isinstance(exc, OSError):              if self._loop.get_debug():                  logger.debug("%r: %s", self, message, exc_info=True) -        elif not isinstance(exc, exceptions.CancelledError): +        else:              self._loop.call_exception_handler({                  'message': message,                  'exception': exc,                  'transport': self._transport,                  'protocol': self,              }) +        if self._transport: +            self._transport._force_close(exc) + +    def _finalize(self): +        self._sslpipe = None + +        if self._transport is not None: +            self._transport.close() + +    def _abort(self): +        try: +            if self._transport is not None: +                self._transport.abort() +        finally: +            self._finalize() diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 181e188515..a55b3a375f 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -229,8 +229,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):              self, protocol_factory, path=None, *,              ssl=None, sock=None,              server_hostname=None, -            ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None): +            ssl_handshake_timeout=None):          assert server_hostname is None or isinstance(server_hostname, str)          if ssl:              if server_hostname is None: @@ -242,9 +241,6 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):              if ssl_handshake_timeout is not None:                  raise ValueError(                      'ssl_handshake_timeout is only meaningful with ssl') -            if ssl_shutdown_timeout is not None: -                raise ValueError( -                    'ssl_shutdown_timeout is only meaningful with ssl')          if path is not None:              if sock is not None: @@ -271,15 +267,13 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):          transport, protocol = await self._create_connection_transport(              sock, protocol_factory, ssl, server_hostname, -            ssl_handshake_timeout=ssl_handshake_timeout, -            ssl_shutdown_timeout=ssl_shutdown_timeout) +            ssl_handshake_timeout=ssl_handshake_timeout)          return transport, protocol      async def create_unix_server(              self, protocol_factory, path=None, *,              sock=None, backlog=100, ssl=None,              ssl_handshake_timeout=None, -            ssl_shutdown_timeout=None,              start_serving=True):          if isinstance(ssl, bool):              raise TypeError('ssl argument must be an SSLContext or None') @@ -288,10 +282,6 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):              raise ValueError(                  'ssl_handshake_timeout is only meaningful with ssl') -        if ssl_shutdown_timeout is not None and not ssl: -            raise ValueError( -                'ssl_shutdown_timeout is only meaningful with ssl') -          if path is not None:              if sock is not None:                  raise ValueError( @@ -338,8 +328,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):          sock.setblocking(False)          server = base_events.Server(self, [sock], protocol_factory, -                                    ssl, backlog, ssl_handshake_timeout, -                                    ssl_shutdown_timeout) +                                    ssl, backlog, ssl_handshake_timeout)          if start_serving:              server._start_serving()              # Skip one loop iteration so that all 'loop.add_reader' | 
