author     Victor Stinner <vstinner@redhat.com>  2015-07-06 22:11:55 +0200
committer  Victor Stinner <vstinner@redhat.com>  2015-07-06 22:30:54 +0200
commit     1975461f4ccf0d9ffd9e20e4dd4a4650ad6a0c18 (patch)
tree       d185d0cf9c4bd65aebf93fcdd24eb15dac4383db /trollius
parent     728a91239123913b9161e8ca174cf86288f8c7d7 (diff)
download   trollius-git-1975461f4ccf0d9ffd9e20e4dd4a4650ad6a0c18.tar.gz
rename asyncio/ directory to trollius/
Diffstat (limited to 'trollius')
-rw-r--r--  trollius/__init__.py          50
-rw-r--r--  trollius/base_events.py     1240
-rw-r--r--  trollius/base_subprocess.py  275
-rw-r--r--  trollius/constants.py          7
-rw-r--r--  trollius/coroutines.py       301
-rw-r--r--  trollius/events.py           611
-rw-r--r--  trollius/futures.py          413
-rw-r--r--  trollius/locks.py            470
-rw-r--r--  trollius/log.py                7
-rw-r--r--  trollius/proactor_events.py  547
-rw-r--r--  trollius/protocols.py        134
-rw-r--r--  trollius/queues.py           293
-rw-r--r--  trollius/selector_events.py 1068
-rw-r--r--  trollius/selectors.py        594
-rw-r--r--  trollius/sslproto.py         668
-rw-r--r--  trollius/streams.py          501
-rw-r--r--  trollius/subprocess.py       215
-rw-r--r--  trollius/tasks.py            681
-rw-r--r--  trollius/test_support.py     308
-rw-r--r--  trollius/test_utils.py       446
-rw-r--r--  trollius/transports.py       300
-rw-r--r--  trollius/unix_events.py      998
-rw-r--r--  trollius/windows_events.py   774
-rw-r--r--  trollius/windows_utils.py    223
24 files changed, 11124 insertions, 0 deletions
diff --git a/trollius/__init__.py b/trollius/__init__.py
new file mode 100644
index 0000000..011466b
--- /dev/null
+++ b/trollius/__init__.py
@@ -0,0 +1,50 @@
+"""The asyncio package, tracking PEP 3156."""
+
+import sys
+
+# The selectors module is in the stdlib in Python 3.4 but not in 3.3.
+# Do this first, so the other submodules can use "from . import selectors".
+# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer.
+try:
+ from . import selectors
+except ImportError:
+ import selectors # Will also be exported.
+
+if sys.platform == 'win32':
+ # Similar thing for _overlapped.
+ try:
+ from . import _overlapped
+ except ImportError:
+ import _overlapped # Will also be exported.
+
+# This relies on each of the submodules having an __all__ variable.
+from .base_events import *
+from .coroutines import *
+from .events import *
+from .futures import *
+from .locks import *
+from .protocols import *
+from .queues import *
+from .streams import *
+from .subprocess import *
+from .tasks import *
+from .transports import *
+
+__all__ = (base_events.__all__ +
+ coroutines.__all__ +
+ events.__all__ +
+ futures.__all__ +
+ locks.__all__ +
+ protocols.__all__ +
+ queues.__all__ +
+ streams.__all__ +
+ subprocess.__all__ +
+ tasks.__all__ +
+ transports.__all__)
+
+if sys.platform == 'win32': # pragma: no cover
+ from .windows_events import *
+ __all__ += windows_events.__all__
+else:
+ from .unix_events import * # pragma: no cover
+ __all__ += unix_events.__all__
diff --git a/trollius/base_events.py b/trollius/base_events.py
new file mode 100644
index 0000000..5a536a2
--- /dev/null
+++ b/trollius/base_events.py
@@ -0,0 +1,1240 @@
+"""Base implementation of event loop.
+
+The event loop can be broken up into a multiplexer (the part
+responsible for notifying us of I/O events) and the event loop proper,
+which wraps a multiplexer with functionality for scheduling callbacks,
+immediately or at a given time in the future.
+
+Whenever a public API takes a callback, subsequent positional
+arguments will be passed to the callback if/when it is called. This
+avoids the proliferation of trivial lambdas implementing closures.
+Keyword arguments for the callback are not supported; this is a
+conscious design decision, leaving the door open for keyword arguments
+to modify the meaning of the API call itself.
+"""
+
+
+import collections
+import concurrent.futures
+import heapq
+import inspect
+import logging
+import os
+import socket
+import subprocess
+import threading
+import time
+import traceback
+import sys
+import warnings
+
+from . import coroutines
+from . import events
+from . import futures
+from . import tasks
+from .coroutines import coroutine
+from .log import logger
+
+
+__all__ = ['BaseEventLoop']
+
+
+# Argument for default thread pool executor creation.
+_MAX_WORKERS = 5
+
+# Minimum number of _scheduled timer handles before cleanup of
+# cancelled handles is performed.
+_MIN_SCHEDULED_TIMER_HANDLES = 100
+
+# Minimum fraction of _scheduled timer handles that are cancelled
+# before cleanup of cancelled handles is performed.
+_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5
+
+def _format_handle(handle):
+ cb = handle._callback
+ if inspect.ismethod(cb) and isinstance(cb.__self__, tasks.Task):
+ # format the task
+ return repr(cb.__self__)
+ else:
+ return str(handle)
+
+
+def _format_pipe(fd):
+ if fd == subprocess.PIPE:
+ return '<pipe>'
+ elif fd == subprocess.STDOUT:
+ return '<stdout>'
+ else:
+ return repr(fd)
+
+
+class _StopError(BaseException):
+ """Raised to stop the event loop."""
+
+
+def _check_resolved_address(sock, address):
+ # Ensure that the address is already resolved to avoid the trap of hanging
+ # the entire event loop when the address requires doing a DNS lookup.
+ #
+ # getaddrinfo() is slow (around 10 us per call): this function should only
+ # be called in debug mode
+ family = sock.family
+
+ if family == socket.AF_INET:
+ host, port = address
+ elif family == socket.AF_INET6:
+ host, port = address[:2]
+ else:
+ return
+
+ # On Windows, socket.inet_pton() is only available since Python 3.4
+ if hasattr(socket, 'inet_pton'):
+        # getaddrinfo() is slow and has a known issue: prefer inet_pton()
+ # if available
+ try:
+ socket.inet_pton(family, host)
+ except OSError as exc:
+ raise ValueError("address must be resolved (IP address), "
+ "got host %r: %s"
+ % (host, exc))
+ else:
+ # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is
+ # already resolved.
+ type_mask = 0
+ if hasattr(socket, 'SOCK_NONBLOCK'):
+ type_mask |= socket.SOCK_NONBLOCK
+ if hasattr(socket, 'SOCK_CLOEXEC'):
+ type_mask |= socket.SOCK_CLOEXEC
+ try:
+ socket.getaddrinfo(host, port,
+ family=family,
+ type=(sock.type & ~type_mask),
+ proto=sock.proto,
+ flags=socket.AI_NUMERICHOST)
+ except socket.gaierror as err:
+ raise ValueError("address must be resolved (IP address), "
+ "got host %r: %s"
+ % (host, err))
+
+def _raise_stop_error(*args):
+ raise _StopError
+
+
+def _run_until_complete_cb(fut):
+ exc = fut._exception
+ if (isinstance(exc, BaseException)
+ and not isinstance(exc, Exception)):
+ # Issue #22429: run_forever() already finished, no need to
+ # stop it.
+ return
+ _raise_stop_error()
+
+
+class Server(events.AbstractServer):
+
+ def __init__(self, loop, sockets):
+ self._loop = loop
+ self.sockets = sockets
+ self._active_count = 0
+ self._waiters = []
+
+ def __repr__(self):
+ return '<%s sockets=%r>' % (self.__class__.__name__, self.sockets)
+
+ def _attach(self):
+ assert self.sockets is not None
+ self._active_count += 1
+
+ def _detach(self):
+ assert self._active_count > 0
+ self._active_count -= 1
+ if self._active_count == 0 and self.sockets is None:
+ self._wakeup()
+
+ def close(self):
+ sockets = self.sockets
+ if sockets is None:
+ return
+ self.sockets = None
+ for sock in sockets:
+ self._loop._stop_serving(sock)
+ if self._active_count == 0:
+ self._wakeup()
+
+ def _wakeup(self):
+ waiters = self._waiters
+ self._waiters = None
+ for waiter in waiters:
+ if not waiter.done():
+ waiter.set_result(waiter)
+
+ @coroutine
+ def wait_closed(self):
+ if self.sockets is None or self._waiters is None:
+ return
+ waiter = futures.Future(loop=self._loop)
+ self._waiters.append(waiter)
+ yield from waiter
+
+
+class BaseEventLoop(events.AbstractEventLoop):
+
+ def __init__(self):
+ self._timer_cancelled_count = 0
+ self._closed = False
+ self._ready = collections.deque()
+ self._scheduled = []
+ self._default_executor = None
+ self._internal_fds = 0
+ # Identifier of the thread running the event loop, or None if the
+ # event loop is not running
+ self._thread_id = None
+ self._clock_resolution = time.get_clock_info('monotonic').resolution
+ self._exception_handler = None
+ self.set_debug((not sys.flags.ignore_environment
+ and bool(os.environ.get('PYTHONASYNCIODEBUG'))))
+ # In debug mode, if the execution of a callback or a step of a task
+ # exceed this duration in seconds, the slow callback/task is logged.
+ self.slow_callback_duration = 0.1
+ self._current_handle = None
+ self._task_factory = None
+ self._coroutine_wrapper_set = False
+
+ def __repr__(self):
+ return ('<%s running=%s closed=%s debug=%s>'
+ % (self.__class__.__name__, self.is_running(),
+ self.is_closed(), self.get_debug()))
+
+ def create_task(self, coro):
+ """Schedule a coroutine object.
+
+ Return a task object.
+ """
+ self._check_closed()
+ if self._task_factory is None:
+ task = tasks.Task(coro, loop=self)
+ if task._source_traceback:
+ del task._source_traceback[-1]
+ else:
+ task = self._task_factory(self, coro)
+ return task
+
+ def set_task_factory(self, factory):
+ """Set a task factory that will be used by loop.create_task().
+
+ If factory is None the default task factory will be set.
+
+ If factory is a callable, it should have a signature matching
+ '(loop, coro)', where 'loop' will be a reference to the active
+ event loop, 'coro' will be a coroutine object. The callable
+ must return a Future.
+ """
+ if factory is not None and not callable(factory):
+ raise TypeError('task factory must be a callable or None')
+ self._task_factory = factory
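A factory with the '(loop, coro)' signature described above might look like the following; a sketch only (the logging_factory name is hypothetical, not part of this patch):

    import trollius

    def logging_factory(loop, coro):
        # Must return a Future; Task is the default choice.
        task = trollius.Task(coro, loop=loop)
        print('created %r' % task)
        return task

    loop = trollius.get_event_loop()
    loop.set_task_factory(logging_factory)

    @trollius.coroutine
    def work():
        return 42

    print(loop.run_until_complete(work()))  # prints the task repr, then 42
    loop.close()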
+
+ def get_task_factory(self):
+ """Return a task factory, or None if the default one is in use."""
+ return self._task_factory
+
+ def _make_socket_transport(self, sock, protocol, waiter=None, *,
+ extra=None, server=None):
+ """Create socket transport."""
+ raise NotImplementedError
+
+ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
+ *, server_side=False, server_hostname=None,
+ extra=None, server=None):
+ """Create SSL transport."""
+ raise NotImplementedError
+
+ def _make_datagram_transport(self, sock, protocol,
+ address=None, waiter=None, extra=None):
+ """Create datagram transport."""
+ raise NotImplementedError
+
+ def _make_read_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ """Create read pipe transport."""
+ raise NotImplementedError
+
+ def _make_write_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ """Create write pipe transport."""
+ raise NotImplementedError
+
+ @coroutine
+ def _make_subprocess_transport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ extra=None, **kwargs):
+ """Create subprocess transport."""
+ raise NotImplementedError
+
+ def _write_to_self(self):
+ """Write a byte to self-pipe, to wake up the event loop.
+
+ This may be called from a different thread.
+
+ The subclass is responsible for implementing the self-pipe.
+ """
+ raise NotImplementedError
+
+ def _process_events(self, event_list):
+ """Process selector events."""
+ raise NotImplementedError
+
+ def _check_closed(self):
+ if self._closed:
+ raise RuntimeError('Event loop is closed')
+
+ def run_forever(self):
+ """Run until stop() is called."""
+ self._check_closed()
+ if self.is_running():
+ raise RuntimeError('Event loop is running.')
+ self._set_coroutine_wrapper(self._debug)
+ self._thread_id = threading.get_ident()
+ try:
+ while True:
+ try:
+ self._run_once()
+ except _StopError:
+ break
+ finally:
+ self._thread_id = None
+ self._set_coroutine_wrapper(False)
+
+ def run_until_complete(self, future):
+ """Run until the Future is done.
+
+ If the argument is a coroutine, it is wrapped in a Task.
+
+ WARNING: It would be disastrous to call run_until_complete()
+ with the same coroutine twice -- it would wrap it in two
+ different Tasks and that can't be good.
+
+ Return the Future's result, or raise its exception.
+ """
+ self._check_closed()
+
+ new_task = not isinstance(future, futures.Future)
+ future = tasks.ensure_future(future, loop=self)
+ if new_task:
+ # An exception is raised if the future didn't complete, so there
+ # is no need to log the "destroy pending task" message
+ future._log_destroy_pending = False
+
+ future.add_done_callback(_run_until_complete_cb)
+ try:
+ self.run_forever()
+ except:
+ if new_task and future.done() and not future.cancelled():
+ # The coroutine raised a BaseException. Consume the exception
+ # to not log a warning, the caller doesn't have access to the
+ # local task.
+ future.exception()
+ raise
+ future.remove_done_callback(_run_until_complete_cb)
+ if not future.done():
+ raise RuntimeError('Event loop stopped before Future completed.')
+
+ return future.result()
+
+ def stop(self):
+ """Stop running the event loop.
+
+ Every callback scheduled before stop() is called will run. Callbacks
+ scheduled after stop() is called will not run. However, those callbacks
+ will run if run_forever is called again later.
+ """
+ self.call_soon(_raise_stop_error)
+
+ def close(self):
+ """Close the event loop.
+
+ This clears the queues and shuts down the executor,
+ but does not wait for the executor to finish.
+
+ The event loop must not be running.
+ """
+ if self.is_running():
+ raise RuntimeError("Cannot close a running event loop")
+ if self._closed:
+ return
+ if self._debug:
+ logger.debug("Close %r", self)
+ self._closed = True
+ self._ready.clear()
+ self._scheduled.clear()
+ executor = self._default_executor
+ if executor is not None:
+ self._default_executor = None
+ executor.shutdown(wait=False)
+
+ def is_closed(self):
+ """Returns True if the event loop was closed."""
+ return self._closed
+
+    # On Python 3.3 and older, objects with a destructor that are part of a
+    # reference cycle are never destroyed. This is no longer the case on
+    # Python 3.4, thanks to PEP 442.
+ if sys.version_info >= (3, 4):
+ def __del__(self):
+ if not self.is_closed():
+ warnings.warn("unclosed event loop %r" % self, ResourceWarning)
+ if not self.is_running():
+ self.close()
+
+ def is_running(self):
+ """Returns True if the event loop is running."""
+ return (self._thread_id is not None)
+
+ def time(self):
+ """Return the time according to the event loop's clock.
+
+ This is a float expressed in seconds since an epoch, but the
+ epoch, precision, accuracy and drift are unspecified and may
+ differ per event loop.
+ """
+ return time.monotonic()
+
+ def call_later(self, delay, callback, *args):
+ """Arrange for a callback to be called at a given time.
+
+ Return a Handle: an opaque object with a cancel() method that
+ can be used to cancel the call.
+
+ The delay can be an int or float, expressed in seconds. It is
+ always relative to the current time.
+
+ Each callback will be called exactly once. If two callbacks
+        are scheduled for exactly the same time, it is undefined which
+ will be called first.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ timer = self.call_at(self.time() + delay, callback, *args)
+ if timer._source_traceback:
+ del timer._source_traceback[-1]
+ return timer
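For illustration (not part of this patch), scheduling and cancelling timed callbacks with this API:

    import trollius

    loop = trollius.get_event_loop()
    loop.call_later(0.5, print, 'fires after ~0.5 seconds')
    handle = loop.call_later(10, print, 'never fires')
    handle.cancel()  # TimerHandle.cancel() prevents the call
    # call_at() takes an absolute deadline on the loop's own clock:
    loop.call_at(loop.time() + 1.0, loop.stop)
    loop.run_forever()
    loop.close()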
+
+ def call_at(self, when, callback, *args):
+ """Like call_later(), but uses an absolute time.
+
+ Absolute time corresponds to the event loop's time() method.
+ """
+ if (coroutines.iscoroutine(callback)
+ or coroutines.iscoroutinefunction(callback)):
+ raise TypeError("coroutines cannot be used with call_at()")
+ self._check_closed()
+ if self._debug:
+ self._check_thread()
+ timer = events.TimerHandle(when, callback, args, self)
+ if timer._source_traceback:
+ del timer._source_traceback[-1]
+ heapq.heappush(self._scheduled, timer)
+ timer._scheduled = True
+ return timer
+
+ def call_soon(self, callback, *args):
+ """Arrange for a callback to be called as soon as possible.
+
+ This operates as a FIFO queue: callbacks are called in the
+ order in which they are registered. Each callback will be
+ called exactly once.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ if self._debug:
+ self._check_thread()
+ handle = self._call_soon(callback, args)
+ if handle._source_traceback:
+ del handle._source_traceback[-1]
+ return handle
+
+ def _call_soon(self, callback, args):
+ if (coroutines.iscoroutine(callback)
+ or coroutines.iscoroutinefunction(callback)):
+ raise TypeError("coroutines cannot be used with call_soon()")
+ self._check_closed()
+ handle = events.Handle(callback, args, self)
+ if handle._source_traceback:
+ del handle._source_traceback[-1]
+ self._ready.append(handle)
+ return handle
+
+ def _check_thread(self):
+ """Check that the current thread is the thread running the event loop.
+
+ Non-thread-safe methods of this class make this assumption and will
+ likely behave incorrectly when the assumption is violated.
+
+ Should only be called when (self._debug == True). The caller is
+ responsible for checking this condition for performance reasons.
+ """
+ if self._thread_id is None:
+ return
+ thread_id = threading.get_ident()
+ if thread_id != self._thread_id:
+ raise RuntimeError(
+ "Non-thread-safe operation invoked on an event loop other "
+ "than the current one")
+
+ def call_soon_threadsafe(self, callback, *args):
+ """Like call_soon(), but thread-safe."""
+ handle = self._call_soon(callback, args)
+ if handle._source_traceback:
+ del handle._source_traceback[-1]
+ self._write_to_self()
+ return handle
+
+ def run_in_executor(self, executor, func, *args):
+ if (coroutines.iscoroutine(func)
+ or coroutines.iscoroutinefunction(func)):
+ raise TypeError("coroutines cannot be used with run_in_executor()")
+ self._check_closed()
+ if isinstance(func, events.Handle):
+ assert not args
+ assert not isinstance(func, events.TimerHandle)
+ if func._cancelled:
+ f = futures.Future(loop=self)
+ f.set_result(None)
+ return f
+ func, args = func._callback, func._args
+ if executor is None:
+ executor = self._default_executor
+ if executor is None:
+ executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS)
+ self._default_executor = executor
+ return futures.wrap_future(executor.submit(func, *args), loop=self)
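A sketch of offloading a blocking call through run_in_executor() (illustrative; blocking_io is a hypothetical stand-in):

    import time
    import trollius

    def blocking_io():
        time.sleep(0.1)  # stands in for any blocking call
        return 'done'

    @trollius.coroutine
    def main(loop):
        # executor=None picks the default ThreadPoolExecutor,
        # created on demand with _MAX_WORKERS threads.
        result = yield from loop.run_in_executor(None, blocking_io)
        print(result)

    loop = trollius.get_event_loop()
    loop.run_until_complete(main(loop))
    loop.close()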
+
+ def set_default_executor(self, executor):
+ self._default_executor = executor
+
+ def _getaddrinfo_debug(self, host, port, family, type, proto, flags):
+ msg = ["%s:%r" % (host, port)]
+ if family:
+ msg.append('family=%r' % family)
+ if type:
+ msg.append('type=%r' % type)
+ if proto:
+ msg.append('proto=%r' % proto)
+ if flags:
+ msg.append('flags=%r' % flags)
+ msg = ', '.join(msg)
+ logger.debug('Get address info %s', msg)
+
+ t0 = self.time()
+ addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags)
+ dt = self.time() - t0
+
+ msg = ('Getting address info %s took %.3f ms: %r'
+ % (msg, dt * 1e3, addrinfo))
+ if dt >= self.slow_callback_duration:
+ logger.info(msg)
+ else:
+ logger.debug(msg)
+ return addrinfo
+
+ def getaddrinfo(self, host, port, *,
+ family=0, type=0, proto=0, flags=0):
+ if self._debug:
+ return self.run_in_executor(None, self._getaddrinfo_debug,
+ host, port, family, type, proto, flags)
+ else:
+ return self.run_in_executor(None, socket.getaddrinfo,
+ host, port, family, type, proto, flags)
+
+ def getnameinfo(self, sockaddr, flags=0):
+ return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags)
+
+ @coroutine
+ def create_connection(self, protocol_factory, host=None, port=None, *,
+ ssl=None, family=0, proto=0, flags=0, sock=None,
+ local_addr=None, server_hostname=None):
+ """Connect to a TCP server.
+
+ Create a streaming transport connection to a given Internet host and
+ port: socket family AF_INET or socket.AF_INET6 depending on host (or
+ family if specified), socket type SOCK_STREAM. protocol_factory must be
+ a callable returning a protocol instance.
+
+ This method is a coroutine which will try to establish the connection
+ in the background. When successful, the coroutine returns a
+ (transport, protocol) pair.
+ """
+ if server_hostname is not None and not ssl:
+ raise ValueError('server_hostname is only meaningful with ssl')
+
+ if server_hostname is None and ssl:
+ # Use host as default for server_hostname. It is an error
+ # if host is empty or not set, e.g. when an
+ # already-connected socket was passed or when only a port
+ # is given. To avoid this error, you can pass
+ # server_hostname='' -- this will bypass the hostname
+ # check. (This also means that if host is a numeric
+ # IP/IPv6 address, we will attempt to verify that exact
+ # address; this will probably fail, but it is possible to
+ # create a certificate for a specific IP address, so we
+ # don't judge it here.)
+ if not host:
+ raise ValueError('You must set server_hostname '
+ 'when using ssl without a host')
+ server_hostname = host
+
+ if host is not None or port is not None:
+ if sock is not None:
+ raise ValueError(
+ 'host/port and sock can not be specified at the same time')
+
+ f1 = self.getaddrinfo(
+ host, port, family=family,
+ type=socket.SOCK_STREAM, proto=proto, flags=flags)
+ fs = [f1]
+ if local_addr is not None:
+ f2 = self.getaddrinfo(
+ *local_addr, family=family,
+ type=socket.SOCK_STREAM, proto=proto, flags=flags)
+ fs.append(f2)
+ else:
+ f2 = None
+
+ yield from tasks.wait(fs, loop=self)
+
+ infos = f1.result()
+ if not infos:
+ raise OSError('getaddrinfo() returned empty list')
+ if f2 is not None:
+ laddr_infos = f2.result()
+ if not laddr_infos:
+ raise OSError('getaddrinfo() returned empty list')
+
+ exceptions = []
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ if f2 is not None:
+ for _, _, _, _, laddr in laddr_infos:
+ try:
+ sock.bind(laddr)
+ break
+ except OSError as exc:
+ exc = OSError(
+ exc.errno, 'error while '
+ 'attempting to bind on address '
+ '{!r}: {}'.format(
+ laddr, exc.strerror.lower()))
+ exceptions.append(exc)
+ else:
+ sock.close()
+ sock = None
+ continue
+ if self._debug:
+ logger.debug("connect %r to %r", sock, address)
+ yield from self.sock_connect(sock, address)
+ except OSError as exc:
+ if sock is not None:
+ sock.close()
+ exceptions.append(exc)
+ except:
+ if sock is not None:
+ sock.close()
+ raise
+ else:
+ break
+ else:
+ if len(exceptions) == 1:
+ raise exceptions[0]
+ else:
+ # If they all have the same str(), raise one.
+ model = str(exceptions[0])
+ if all(str(exc) == model for exc in exceptions):
+ raise exceptions[0]
+ # Raise a combined exception so the user can see all
+ # the various error messages.
+ raise OSError('Multiple exceptions: {}'.format(
+ ', '.join(str(exc) for exc in exceptions)))
+
+ elif sock is None:
+ raise ValueError(
+                'host and port were not specified and no sock specified')
+
+ sock.setblocking(False)
+
+ transport, protocol = yield from self._create_connection_transport(
+ sock, protocol_factory, ssl, server_hostname)
+ if self._debug:
+ # Get the socket from the transport because SSL transport closes
+ # the old socket and creates a new SSL socket
+ sock = transport.get_extra_info('socket')
+ logger.debug("%r connected to %s:%r: (%r, %r)",
+ sock, host, port, transport, protocol)
+ return transport, protocol
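A minimal client built on create_connection() (illustrative only; the host and port are hypothetical):

    import trollius

    class PingClient(trollius.Protocol):
        def connection_made(self, transport):
            transport.write(b'ping')

        def data_received(self, data):
            print('received %r' % data)

    @trollius.coroutine
    def connect(loop):
        transport, protocol = yield from loop.create_connection(
            PingClient, 'example.com', 80)
        transport.close()

    loop = trollius.get_event_loop()
    loop.run_until_complete(connect(loop))
    loop.close()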
+
+ @coroutine
+ def _create_connection_transport(self, sock, protocol_factory, ssl,
+ server_hostname):
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ if ssl:
+ sslcontext = None if isinstance(ssl, bool) else ssl
+ transport = self._make_ssl_transport(
+ sock, protocol, sslcontext, waiter,
+ server_side=False, server_hostname=server_hostname)
+ else:
+ transport = self._make_socket_transport(sock, protocol, waiter)
+
+ try:
+ yield from waiter
+ except:
+ transport.close()
+ raise
+
+ return transport, protocol
+
+ @coroutine
+ def create_datagram_endpoint(self, protocol_factory,
+ local_addr=None, remote_addr=None, *,
+ family=0, proto=0, flags=0):
+ """Create datagram connection."""
+ if not (local_addr or remote_addr):
+ if family == 0:
+ raise ValueError('unexpected address family')
+ addr_pairs_info = (((family, proto), (None, None)),)
+ else:
+ # join address by (family, protocol)
+ addr_infos = collections.OrderedDict()
+ for idx, addr in ((0, local_addr), (1, remote_addr)):
+ if addr is not None:
+ assert isinstance(addr, tuple) and len(addr) == 2, (
+ '2-tuple is expected')
+
+ infos = yield from self.getaddrinfo(
+ *addr, family=family, type=socket.SOCK_DGRAM,
+ proto=proto, flags=flags)
+ if not infos:
+ raise OSError('getaddrinfo() returned empty list')
+
+ for fam, _, pro, _, address in infos:
+ key = (fam, pro)
+ if key not in addr_infos:
+ addr_infos[key] = [None, None]
+ addr_infos[key][idx] = address
+
+ # each addr has to have info for each (family, proto) pair
+ addr_pairs_info = [
+ (key, addr_pair) for key, addr_pair in addr_infos.items()
+ if not ((local_addr and addr_pair[0] is None) or
+ (remote_addr and addr_pair[1] is None))]
+
+ if not addr_pairs_info:
+ raise ValueError('can not get address information')
+
+ exceptions = []
+
+ for ((family, proto),
+ (local_address, remote_address)) in addr_pairs_info:
+ sock = None
+ r_addr = None
+ try:
+ sock = socket.socket(
+ family=family, type=socket.SOCK_DGRAM, proto=proto)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.setblocking(False)
+
+ if local_addr:
+ sock.bind(local_address)
+ if remote_addr:
+ yield from self.sock_connect(sock, remote_address)
+ r_addr = remote_address
+ except OSError as exc:
+ if sock is not None:
+ sock.close()
+ exceptions.append(exc)
+ except:
+ if sock is not None:
+ sock.close()
+ raise
+ else:
+ break
+ else:
+ raise exceptions[0]
+
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ transport = self._make_datagram_transport(sock, protocol, r_addr,
+ waiter)
+ if self._debug:
+ if local_addr:
+ logger.info("Datagram endpoint local_addr=%r remote_addr=%r "
+ "created: (%r, %r)",
+ local_addr, remote_addr, transport, protocol)
+ else:
+ logger.debug("Datagram endpoint remote_addr=%r created: "
+ "(%r, %r)",
+ remote_addr, transport, protocol)
+
+ try:
+ yield from waiter
+ except:
+ transport.close()
+ raise
+
+ return transport, protocol
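Sketch of a UDP endpoint using this method (not part of this patch; the address is hypothetical):

    import trollius

    class UDPEcho(trollius.DatagramProtocol):
        def datagram_received(self, data, addr):
            print('from %s: %r' % (addr, data))

    loop = trollius.get_event_loop()
    # local_addr/remote_addr must be (host, port) 2-tuples,
    # per the assertion above.
    transport, protocol = loop.run_until_complete(
        loop.create_datagram_endpoint(UDPEcho,
                                      local_addr=('127.0.0.1', 9999)))
    transport.close()
    loop.close()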
+
+ @coroutine
+ def create_server(self, protocol_factory, host=None, port=None,
+ *,
+ family=socket.AF_UNSPEC,
+ flags=socket.AI_PASSIVE,
+ sock=None,
+ backlog=100,
+ ssl=None,
+ reuse_address=None):
+ """Create a TCP server bound to host and port.
+
+ Return a Server object which can be used to stop the service.
+
+ This method is a coroutine.
+ """
+ if isinstance(ssl, bool):
+ raise TypeError('ssl argument must be an SSLContext or None')
+ if host is not None or port is not None:
+ if sock is not None:
+ raise ValueError(
+ 'host/port and sock can not be specified at the same time')
+
+ AF_INET6 = getattr(socket, 'AF_INET6', 0)
+ if reuse_address is None:
+ reuse_address = os.name == 'posix' and sys.platform != 'cygwin'
+ sockets = []
+ if host == '':
+ host = None
+
+ infos = yield from self.getaddrinfo(
+ host, port, family=family,
+ type=socket.SOCK_STREAM, proto=0, flags=flags)
+ if not infos:
+ raise OSError('getaddrinfo() returned empty list')
+
+ completed = False
+ try:
+ for res in infos:
+ af, socktype, proto, canonname, sa = res
+ try:
+ sock = socket.socket(af, socktype, proto)
+ except socket.error:
+ # Assume it's a bad family/type/protocol combination.
+ if self._debug:
+ logger.warning('create_server() failed to create '
+ 'socket.socket(%r, %r, %r)',
+ af, socktype, proto, exc_info=True)
+ continue
+ sockets.append(sock)
+ if reuse_address:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+ True)
+ # Disable IPv4/IPv6 dual stack support (enabled by
+ # default on Linux) which makes a single socket
+ # listen on both address families.
+ if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'):
+ sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_V6ONLY,
+ True)
+ try:
+ sock.bind(sa)
+ except OSError as err:
+ raise OSError(err.errno, 'error while attempting '
+ 'to bind on address %r: %s'
+ % (sa, err.strerror.lower()))
+ completed = True
+ finally:
+ if not completed:
+ for sock in sockets:
+ sock.close()
+ else:
+ if sock is None:
+ raise ValueError('Neither host/port nor sock were specified')
+ sockets = [sock]
+
+ server = Server(self, sockets)
+ for sock in sockets:
+ sock.listen(backlog)
+ sock.setblocking(False)
+ self._start_serving(protocol_factory, sock, ssl, server)
+ if self._debug:
+ logger.info("%r is serving", server)
+ return server
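A usage sketch tying create_server() to the Server object defined earlier in this file (illustrative; the address is hypothetical):

    import trollius

    class Echo(trollius.Protocol):
        def connection_made(self, transport):
            self.transport = transport

        def data_received(self, data):
            self.transport.write(data)  # echo back

    loop = trollius.get_event_loop()
    server = loop.run_until_complete(
        loop.create_server(Echo, '127.0.0.1', 8888))
    server.close()  # stop accepting connections
    loop.run_until_complete(server.wait_closed())
    loop.close()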
+
+ @coroutine
+ def connect_read_pipe(self, protocol_factory, pipe):
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ transport = self._make_read_pipe_transport(pipe, protocol, waiter)
+
+ try:
+ yield from waiter
+ except:
+ transport.close()
+ raise
+
+ if self._debug:
+ logger.debug('Read pipe %r connected: (%r, %r)',
+ pipe.fileno(), transport, protocol)
+ return transport, protocol
+
+ @coroutine
+ def connect_write_pipe(self, protocol_factory, pipe):
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ transport = self._make_write_pipe_transport(pipe, protocol, waiter)
+
+ try:
+ yield from waiter
+ except:
+ transport.close()
+ raise
+
+ if self._debug:
+ logger.debug('Write pipe %r connected: (%r, %r)',
+ pipe.fileno(), transport, protocol)
+ return transport, protocol
+
+ def _log_subprocess(self, msg, stdin, stdout, stderr):
+ info = [msg]
+ if stdin is not None:
+ info.append('stdin=%s' % _format_pipe(stdin))
+ if stdout is not None and stderr == subprocess.STDOUT:
+ info.append('stdout=stderr=%s' % _format_pipe(stdout))
+ else:
+ if stdout is not None:
+ info.append('stdout=%s' % _format_pipe(stdout))
+ if stderr is not None:
+ info.append('stderr=%s' % _format_pipe(stderr))
+ logger.debug(' '.join(info))
+
+ @coroutine
+ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ universal_newlines=False, shell=True, bufsize=0,
+ **kwargs):
+ if not isinstance(cmd, (bytes, str)):
+ raise ValueError("cmd must be a string")
+ if universal_newlines:
+ raise ValueError("universal_newlines must be False")
+ if not shell:
+ raise ValueError("shell must be True")
+ if bufsize != 0:
+ raise ValueError("bufsize must be 0")
+ protocol = protocol_factory()
+ if self._debug:
+ # don't log parameters: they may contain sensitive information
+ # (password) and may be too long
+ debug_log = 'run shell command %r' % cmd
+ self._log_subprocess(debug_log, stdin, stdout, stderr)
+ transport = yield from self._make_subprocess_transport(
+ protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs)
+ if self._debug:
+ logger.info('%s: %r' % (debug_log, transport))
+ return transport, protocol
+
+ @coroutine
+ def subprocess_exec(self, protocol_factory, program, *args,
+ stdin=subprocess.PIPE, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, universal_newlines=False,
+ shell=False, bufsize=0, **kwargs):
+ if universal_newlines:
+ raise ValueError("universal_newlines must be False")
+ if shell:
+ raise ValueError("shell must be False")
+ if bufsize != 0:
+ raise ValueError("bufsize must be 0")
+ popen_args = (program,) + args
+ for arg in popen_args:
+ if not isinstance(arg, (str, bytes)):
+ raise TypeError("program arguments must be "
+ "a bytes or text string, not %s"
+ % type(arg).__name__)
+ protocol = protocol_factory()
+ if self._debug:
+ # don't log parameters: they may contain sensitive information
+ # (password) and may be too long
+ debug_log = 'execute program %r' % program
+ self._log_subprocess(debug_log, stdin, stdout, stderr)
+ transport = yield from self._make_subprocess_transport(
+ protocol, popen_args, False, stdin, stdout, stderr,
+ bufsize, **kwargs)
+ if self._debug:
+ logger.info('%s: %r' % (debug_log, transport))
+ return transport, protocol
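A low-level subprocess_exec() sketch (illustrative; Collector is a hypothetical protocol, not part of this patch):

    import sys
    import trollius

    class Collector(trollius.SubprocessProtocol):
        def __init__(self):
            self.output = bytearray()
            self.exited = trollius.Future(loop=loop)

        def pipe_data_received(self, fd, data):
            self.output.extend(data)

        def process_exited(self):
            self.exited.set_result(True)

    @trollius.coroutine
    def run(loop):
        transport, protocol = yield from loop.subprocess_exec(
            Collector, sys.executable, '-c', 'print("hi")')
        yield from protocol.exited  # wait for process_exited()
        transport.close()
        return bytes(protocol.output)

    loop = trollius.get_event_loop()
    print(loop.run_until_complete(run(loop)))
    loop.close()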
+
+ def set_exception_handler(self, handler):
+ """Set handler as the new event loop exception handler.
+
+ If handler is None, the default exception handler will
+ be set.
+
+ If handler is a callable object, it should have a
+ signature matching '(loop, context)', where 'loop'
+ will be a reference to the active event loop, 'context'
+ will be a dict object (see `call_exception_handler()`
+ documentation for details about context).
+ """
+ if handler is not None and not callable(handler):
+ raise TypeError('A callable object or None is expected, '
+ 'got {!r}'.format(handler))
+ self._exception_handler = handler
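A handler with the '(loop, context)' signature described above; a sketch only, not part of this patch:

    import trollius

    def handler(loop, context):
        # context carries the keys documented in call_exception_handler()
        print('caught: %s' % context['message'])
        loop.default_exception_handler(context)  # fall back to the default

    loop = trollius.get_event_loop()
    loop.set_exception_handler(handler)
    loop.call_soon(lambda: 1 / 0)  # ZeroDivisionError reaches handler()
    loop.call_soon(loop.stop)
    loop.run_forever()
    loop.close()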
+
+ def default_exception_handler(self, context):
+ """Default exception handler.
+
+ This is called when an exception occurs and no exception
+ handler is set, and can be called by a custom exception
+ handler that wants to defer to the default behavior.
+
+ The context parameter has the same meaning as in
+ `call_exception_handler()`.
+ """
+ message = context.get('message')
+ if not message:
+ message = 'Unhandled exception in event loop'
+
+ exception = context.get('exception')
+ if exception is not None:
+ exc_info = (type(exception), exception, exception.__traceback__)
+ else:
+ exc_info = False
+
+ if ('source_traceback' not in context
+ and self._current_handle is not None
+ and self._current_handle._source_traceback):
+ context['handle_traceback'] = self._current_handle._source_traceback
+
+ log_lines = [message]
+ for key in sorted(context):
+ if key in {'message', 'exception'}:
+ continue
+ value = context[key]
+ if key == 'source_traceback':
+ tb = ''.join(traceback.format_list(value))
+ value = 'Object created at (most recent call last):\n'
+ value += tb.rstrip()
+ elif key == 'handle_traceback':
+ tb = ''.join(traceback.format_list(value))
+ value = 'Handle created at (most recent call last):\n'
+ value += tb.rstrip()
+ else:
+ value = repr(value)
+ log_lines.append('{}: {}'.format(key, value))
+
+ logger.error('\n'.join(log_lines), exc_info=exc_info)
+
+ def call_exception_handler(self, context):
+ """Call the current event loop's exception handler.
+
+ The context argument is a dict containing the following keys:
+
+ - 'message': Error message;
+ - 'exception' (optional): Exception object;
+ - 'future' (optional): Future instance;
+ - 'handle' (optional): Handle instance;
+ - 'protocol' (optional): Protocol instance;
+ - 'transport' (optional): Transport instance;
+ - 'socket' (optional): Socket instance.
+
+        New keys may be introduced in the future.
+
+        Note: do not override this method in an event loop subclass.
+ For custom exception handling, use the
+ `set_exception_handler()` method.
+ """
+ if self._exception_handler is None:
+ try:
+ self.default_exception_handler(context)
+ except Exception:
+ # Second protection layer for unexpected errors
+ # in the default implementation, as well as for subclassed
+                # event loops that override "default_exception_handler".
+ logger.error('Exception in default exception handler',
+ exc_info=True)
+ else:
+ try:
+ self._exception_handler(self, context)
+ except Exception as exc:
+ # Exception in the user set custom exception handler.
+ try:
+ # Let's try default handler.
+ self.default_exception_handler({
+ 'message': 'Unhandled error in exception handler',
+ 'exception': exc,
+ 'context': context,
+ })
+ except Exception:
+ # Guard 'default_exception_handler' in case it is
+                    # overridden.
+ logger.error('Exception in default exception handler '
+ 'while handling an unexpected error '
+ 'in custom exception handler',
+ exc_info=True)
+
+ def _add_callback(self, handle):
+ """Add a Handle to _scheduled (TimerHandle) or _ready."""
+ assert isinstance(handle, events.Handle), 'A Handle is required here'
+ if handle._cancelled:
+ return
+ assert not isinstance(handle, events.TimerHandle)
+ self._ready.append(handle)
+
+ def _add_callback_signalsafe(self, handle):
+ """Like _add_callback() but called from a signal handler."""
+ self._add_callback(handle)
+ self._write_to_self()
+
+ def _timer_handle_cancelled(self, handle):
+ """Notification that a TimerHandle has been cancelled."""
+ if handle._scheduled:
+ self._timer_cancelled_count += 1
+
+ def _run_once(self):
+ """Run one full iteration of the event loop.
+
+ This calls all currently ready callbacks, polls for I/O,
+ schedules the resulting callbacks, and finally schedules
+ 'call_later' callbacks.
+ """
+
+ sched_count = len(self._scheduled)
+ if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
+ self._timer_cancelled_count / sched_count >
+ _MIN_CANCELLED_TIMER_HANDLES_FRACTION):
+ # Remove delayed calls that were cancelled if their number
+ # is too high
+ new_scheduled = []
+ for handle in self._scheduled:
+ if handle._cancelled:
+ handle._scheduled = False
+ else:
+ new_scheduled.append(handle)
+
+ heapq.heapify(new_scheduled)
+ self._scheduled = new_scheduled
+ self._timer_cancelled_count = 0
+ else:
+ # Remove delayed calls that were cancelled from head of queue.
+ while self._scheduled and self._scheduled[0]._cancelled:
+ self._timer_cancelled_count -= 1
+ handle = heapq.heappop(self._scheduled)
+ handle._scheduled = False
+
+ timeout = None
+ if self._ready:
+ timeout = 0
+ elif self._scheduled:
+ # Compute the desired timeout.
+ when = self._scheduled[0]._when
+ timeout = max(0, when - self.time())
+
+ if self._debug and timeout != 0:
+ t0 = self.time()
+ event_list = self._selector.select(timeout)
+ dt = self.time() - t0
+ if dt >= 1.0:
+ level = logging.INFO
+ else:
+ level = logging.DEBUG
+ nevent = len(event_list)
+ if timeout is None:
+ logger.log(level, 'poll took %.3f ms: %s events',
+ dt * 1e3, nevent)
+ elif nevent:
+ logger.log(level,
+ 'poll %.3f ms took %.3f ms: %s events',
+ timeout * 1e3, dt * 1e3, nevent)
+ elif dt >= 1.0:
+ logger.log(level,
+ 'poll %.3f ms took %.3f ms: timeout',
+ timeout * 1e3, dt * 1e3)
+ else:
+ event_list = self._selector.select(timeout)
+ self._process_events(event_list)
+
+ # Handle 'later' callbacks that are ready.
+ end_time = self.time() + self._clock_resolution
+ while self._scheduled:
+ handle = self._scheduled[0]
+ if handle._when >= end_time:
+ break
+ handle = heapq.heappop(self._scheduled)
+ handle._scheduled = False
+ self._ready.append(handle)
+
+ # This is the only place where callbacks are actually *called*.
+ # All other places just add them to ready.
+ # Note: We run all currently scheduled callbacks, but not any
+ # callbacks scheduled by callbacks run this time around --
+ # they will be run the next time (after another I/O poll).
+ # Use an idiom that is thread-safe without using locks.
+ ntodo = len(self._ready)
+ for i in range(ntodo):
+ handle = self._ready.popleft()
+ if handle._cancelled:
+ continue
+ if self._debug:
+ try:
+ self._current_handle = handle
+ t0 = self.time()
+ handle._run()
+ dt = self.time() - t0
+ if dt >= self.slow_callback_duration:
+ logger.warning('Executing %s took %.3f seconds',
+ _format_handle(handle), dt)
+ finally:
+ self._current_handle = None
+ else:
+ handle._run()
+ handle = None # Needed to break cycles when an exception occurs.
+
+ def _set_coroutine_wrapper(self, enabled):
+ try:
+ set_wrapper = sys.set_coroutine_wrapper
+ get_wrapper = sys.get_coroutine_wrapper
+ except AttributeError:
+ return
+
+ enabled = bool(enabled)
+ if self._coroutine_wrapper_set is enabled:
+ return
+
+ wrapper = coroutines.debug_wrapper
+ current_wrapper = get_wrapper()
+
+ if enabled:
+ if current_wrapper not in (None, wrapper):
+ warnings.warn(
+ "loop.set_debug(True): cannot set debug coroutine "
+ "wrapper; another wrapper is already set %r" %
+ current_wrapper, RuntimeWarning)
+ else:
+ set_wrapper(wrapper)
+ self._coroutine_wrapper_set = True
+ else:
+ if current_wrapper not in (None, wrapper):
+ warnings.warn(
+ "loop.set_debug(False): cannot unset debug coroutine "
+ "wrapper; another wrapper was set %r" %
+ current_wrapper, RuntimeWarning)
+ else:
+ set_wrapper(None)
+ self._coroutine_wrapper_set = False
+
+ def get_debug(self):
+ return self._debug
+
+ def set_debug(self, enabled):
+ self._debug = enabled
+
+ if self.is_running():
+ self._set_coroutine_wrapper(enabled)
diff --git a/trollius/base_subprocess.py b/trollius/base_subprocess.py
new file mode 100644
index 0000000..c1477b8
--- /dev/null
+++ b/trollius/base_subprocess.py
@@ -0,0 +1,275 @@
+import collections
+import subprocess
+import sys
+import warnings
+
+from . import futures
+from . import protocols
+from . import transports
+from .coroutines import coroutine
+from .log import logger
+
+
+class BaseSubprocessTransport(transports.SubprocessTransport):
+
+ def __init__(self, loop, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ waiter=None, extra=None, **kwargs):
+ super().__init__(extra)
+ self._closed = False
+ self._protocol = protocol
+ self._loop = loop
+ self._proc = None
+ self._pid = None
+ self._returncode = None
+ self._exit_waiters = []
+ self._pending_calls = collections.deque()
+ self._pipes = {}
+ self._finished = False
+
+ if stdin == subprocess.PIPE:
+ self._pipes[0] = None
+ if stdout == subprocess.PIPE:
+ self._pipes[1] = None
+ if stderr == subprocess.PIPE:
+ self._pipes[2] = None
+
+ # Create the child process: set the _proc attribute
+ self._start(args=args, shell=shell, stdin=stdin, stdout=stdout,
+ stderr=stderr, bufsize=bufsize, **kwargs)
+ self._pid = self._proc.pid
+ self._extra['subprocess'] = self._proc
+
+ if self._loop.get_debug():
+ if isinstance(args, (bytes, str)):
+ program = args
+ else:
+ program = args[0]
+ logger.debug('process %r created: pid %s',
+ program, self._pid)
+
+ self._loop.create_task(self._connect_pipes(waiter))
+
+ def __repr__(self):
+ info = [self.__class__.__name__]
+ if self._closed:
+ info.append('closed')
+ if self._pid is not None:
+ info.append('pid=%s' % self._pid)
+ if self._returncode is not None:
+ info.append('returncode=%s' % self._returncode)
+ elif self._pid is not None:
+ info.append('running')
+ else:
+ info.append('not started')
+
+ stdin = self._pipes.get(0)
+ if stdin is not None:
+ info.append('stdin=%s' % stdin.pipe)
+
+ stdout = self._pipes.get(1)
+ stderr = self._pipes.get(2)
+ if stdout is not None and stderr is stdout:
+ info.append('stdout=stderr=%s' % stdout.pipe)
+ else:
+ if stdout is not None:
+ info.append('stdout=%s' % stdout.pipe)
+ if stderr is not None:
+ info.append('stderr=%s' % stderr.pipe)
+
+ return '<%s>' % ' '.join(info)
+
+ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs):
+ raise NotImplementedError
+
+ def close(self):
+ if self._closed:
+ return
+ self._closed = True
+
+ for proto in self._pipes.values():
+ if proto is None:
+ continue
+ proto.pipe.close()
+
+ if (self._proc is not None
+ # the child process finished?
+ and self._returncode is None
+ # the child process finished but the transport was not notified yet?
+ and self._proc.poll() is None
+ ):
+ if self._loop.get_debug():
+ logger.warning('Close running child process: kill %r', self)
+
+ try:
+ self._proc.kill()
+ except ProcessLookupError:
+ pass
+
+ # Don't clear the _proc reference yet: _post_init() may still run
+
+    # On Python 3.3 and older, objects with a destructor that are part of a
+    # reference cycle are never destroyed. This is no longer the case on
+    # Python 3.4, thanks to PEP 442.
+ if sys.version_info >= (3, 4):
+ def __del__(self):
+ if not self._closed:
+ warnings.warn("unclosed transport %r" % self, ResourceWarning)
+ self.close()
+
+ def get_pid(self):
+ return self._pid
+
+ def get_returncode(self):
+ return self._returncode
+
+ def get_pipe_transport(self, fd):
+ if fd in self._pipes:
+ return self._pipes[fd].pipe
+ else:
+ return None
+
+ def _check_proc(self):
+ if self._proc is None:
+ raise ProcessLookupError()
+
+ def send_signal(self, signal):
+ self._check_proc()
+ self._proc.send_signal(signal)
+
+ def terminate(self):
+ self._check_proc()
+ self._proc.terminate()
+
+ def kill(self):
+ self._check_proc()
+ self._proc.kill()
+
+ @coroutine
+ def _connect_pipes(self, waiter):
+ try:
+ proc = self._proc
+ loop = self._loop
+
+ if proc.stdin is not None:
+ _, pipe = yield from loop.connect_write_pipe(
+ lambda: WriteSubprocessPipeProto(self, 0),
+ proc.stdin)
+ self._pipes[0] = pipe
+
+ if proc.stdout is not None:
+ _, pipe = yield from loop.connect_read_pipe(
+ lambda: ReadSubprocessPipeProto(self, 1),
+ proc.stdout)
+ self._pipes[1] = pipe
+
+ if proc.stderr is not None:
+ _, pipe = yield from loop.connect_read_pipe(
+ lambda: ReadSubprocessPipeProto(self, 2),
+ proc.stderr)
+ self._pipes[2] = pipe
+
+ assert self._pending_calls is not None
+
+ loop.call_soon(self._protocol.connection_made, self)
+ for callback, data in self._pending_calls:
+ loop.call_soon(callback, *data)
+ self._pending_calls = None
+ except Exception as exc:
+ if waiter is not None and not waiter.cancelled():
+ waiter.set_exception(exc)
+ else:
+ if waiter is not None and not waiter.cancelled():
+ waiter.set_result(None)
+
+ def _call(self, cb, *data):
+ if self._pending_calls is not None:
+ self._pending_calls.append((cb, data))
+ else:
+ self._loop.call_soon(cb, *data)
+
+ def _pipe_connection_lost(self, fd, exc):
+ self._call(self._protocol.pipe_connection_lost, fd, exc)
+ self._try_finish()
+
+ def _pipe_data_received(self, fd, data):
+ self._call(self._protocol.pipe_data_received, fd, data)
+
+ def _process_exited(self, returncode):
+ assert returncode is not None, returncode
+ assert self._returncode is None, self._returncode
+ if self._loop.get_debug():
+ logger.info('%r exited with return code %r',
+ self, returncode)
+ self._returncode = returncode
+ self._call(self._protocol.process_exited)
+ self._try_finish()
+
+ # wake up futures waiting for wait()
+ for waiter in self._exit_waiters:
+ if not waiter.cancelled():
+ waiter.set_result(returncode)
+ self._exit_waiters = None
+
+ @coroutine
+ def _wait(self):
+ """Wait until the process exit and return the process return code.
+
+ This method is a coroutine."""
+ if self._returncode is not None:
+ return self._returncode
+
+ waiter = futures.Future(loop=self._loop)
+ self._exit_waiters.append(waiter)
+ return (yield from waiter)
+
+ def _try_finish(self):
+ assert not self._finished
+ if self._returncode is None:
+ return
+ if all(p is not None and p.disconnected
+ for p in self._pipes.values()):
+ self._finished = True
+ self._call(self._call_connection_lost, None)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._loop = None
+ self._proc = None
+ self._protocol = None
+
+
+class WriteSubprocessPipeProto(protocols.BaseProtocol):
+
+ def __init__(self, proc, fd):
+ self.proc = proc
+ self.fd = fd
+ self.pipe = None
+ self.disconnected = False
+
+ def connection_made(self, transport):
+ self.pipe = transport
+
+ def __repr__(self):
+ return ('<%s fd=%s pipe=%r>'
+ % (self.__class__.__name__, self.fd, self.pipe))
+
+ def connection_lost(self, exc):
+ self.disconnected = True
+ self.proc._pipe_connection_lost(self.fd, exc)
+ self.proc = None
+
+ def pause_writing(self):
+ self.proc._protocol.pause_writing()
+
+ def resume_writing(self):
+ self.proc._protocol.resume_writing()
+
+
+class ReadSubprocessPipeProto(WriteSubprocessPipeProto,
+ protocols.Protocol):
+
+ def data_received(self, data):
+ self.proc._pipe_data_received(self.fd, data)
diff --git a/trollius/constants.py b/trollius/constants.py
new file mode 100644
index 0000000..f9e1232
--- /dev/null
+++ b/trollius/constants.py
@@ -0,0 +1,7 @@
+"""Constants."""
+
+# After the connection is lost, log warnings after this many write()s.
+LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5
+
+# Seconds to wait before retrying accept().
+ACCEPT_RETRY_DELAY = 1
diff --git a/trollius/coroutines.py b/trollius/coroutines.py
new file mode 100644
index 0000000..15475f2
--- /dev/null
+++ b/trollius/coroutines.py
@@ -0,0 +1,301 @@
+__all__ = ['coroutine',
+ 'iscoroutinefunction', 'iscoroutine']
+
+import functools
+import inspect
+import opcode
+import os
+import sys
+import traceback
+import types
+
+from . import events
+from . import futures
+from .log import logger
+
+
+_PY35 = sys.version_info >= (3, 5)
+
+
+# Opcode of "yield from" instruction
+_YIELD_FROM = opcode.opmap['YIELD_FROM']
+
+# If you set _DEBUG to true, @coroutine will wrap the resulting
+# generator objects in a CoroWrapper instance (defined below). That
+# instance will log a message when the generator is never iterated
+# over, which may happen when you forget to use "yield from" with a
+# coroutine call. Note that the value of the _DEBUG flag is taken
+# when the decorator is used, so to be of any use it must be set
+# before you define your coroutines. A downside of using this feature
+# is that tracebacks show entries for the CoroWrapper.__next__ method
+# when _DEBUG is true.
+_DEBUG = (not sys.flags.ignore_environment
+ and bool(os.environ.get('PYTHONASYNCIODEBUG')))
+
+
+try:
+ _types_coroutine = types.coroutine
+except AttributeError:
+ _types_coroutine = None
+
+try:
+ _inspect_iscoroutinefunction = inspect.iscoroutinefunction
+except AttributeError:
+ _inspect_iscoroutinefunction = lambda func: False
+
+try:
+ from collections.abc import Coroutine as _CoroutineABC, \
+ Awaitable as _AwaitableABC
+except ImportError:
+ _CoroutineABC = _AwaitableABC = None
+
+
+# Check for CPython issue #21209
+def has_yield_from_bug():
+ class MyGen:
+ def __init__(self):
+ self.send_args = None
+ def __iter__(self):
+ return self
+ def __next__(self):
+ return 42
+ def send(self, *what):
+ self.send_args = what
+ return None
+ def yield_from_gen(gen):
+ yield from gen
+ value = (1, 2, 3)
+ gen = MyGen()
+ coro = yield_from_gen(gen)
+ next(coro)
+ coro.send(value)
+ return gen.send_args != (value,)
+_YIELD_FROM_BUG = has_yield_from_bug()
+del has_yield_from_bug
+
+
+def debug_wrapper(gen):
+ # This function is called from 'sys.set_coroutine_wrapper'.
+ # We only wrap here coroutines defined via 'async def' syntax.
+ # Generator-based coroutines are wrapped in @coroutine
+ # decorator.
+ return CoroWrapper(gen, None)
+
+
+class CoroWrapper:
+ # Wrapper for coroutine object in _DEBUG mode.
+
+ def __init__(self, gen, func=None):
+ assert inspect.isgenerator(gen) or inspect.iscoroutine(gen), gen
+ self.gen = gen
+ self.func = func # Used to unwrap @coroutine decorator
+ self._source_traceback = traceback.extract_stack(sys._getframe(1))
+ self.__name__ = getattr(gen, '__name__', None)
+ self.__qualname__ = getattr(gen, '__qualname__', None)
+
+ def __repr__(self):
+ coro_repr = _format_coroutine(self)
+ if self._source_traceback:
+ frame = self._source_traceback[-1]
+ coro_repr += ', created at %s:%s' % (frame[0], frame[1])
+ return '<%s %s>' % (self.__class__.__name__, coro_repr)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self.gen.send(None)
+
+ if _YIELD_FROM_BUG:
+        # Workaround for CPython issue #21209: using "yield from" and a custom
+ # generator, generator.send(tuple) unpacks the tuple instead of passing
+ # the tuple unchanged. Check if the caller is a generator using "yield
+ # from" to decide if the parameter should be unpacked or not.
+ def send(self, *value):
+ frame = sys._getframe()
+ caller = frame.f_back
+ assert caller.f_lasti >= 0
+ if caller.f_code.co_code[caller.f_lasti] != _YIELD_FROM:
+ value = value[0]
+ return self.gen.send(value)
+ else:
+ def send(self, value):
+ return self.gen.send(value)
+
+ def throw(self, exc):
+ return self.gen.throw(exc)
+
+ def close(self):
+ return self.gen.close()
+
+ @property
+ def gi_frame(self):
+ return self.gen.gi_frame
+
+ @property
+ def gi_running(self):
+ return self.gen.gi_running
+
+ @property
+ def gi_code(self):
+ return self.gen.gi_code
+
+ if _PY35:
+
+ __await__ = __iter__ # make compatible with 'await' expression
+
+ @property
+ def gi_yieldfrom(self):
+ return self.gen.gi_yieldfrom
+
+ @property
+ def cr_await(self):
+ return self.gen.cr_await
+
+ @property
+ def cr_running(self):
+ return self.gen.cr_running
+
+ @property
+ def cr_code(self):
+ return self.gen.cr_code
+
+ @property
+ def cr_frame(self):
+ return self.gen.cr_frame
+
+ def __del__(self):
+        # Be careful accessing self.gen.gi_frame -- self.gen might not exist.
+ gen = getattr(self, 'gen', None)
+ frame = getattr(gen, 'gi_frame', None)
+ if frame is None:
+ frame = getattr(gen, 'cr_frame', None)
+ if frame is not None and frame.f_lasti == -1:
+ msg = '%r was never yielded from' % self
+ tb = getattr(self, '_source_traceback', ())
+ if tb:
+ tb = ''.join(traceback.format_list(tb))
+ msg += ('\nCoroutine object created at '
+ '(most recent call last):\n')
+ msg += tb.rstrip()
+ logger.error(msg)
+
+
+def coroutine(func):
+ """Decorator to mark coroutines.
+
+ If the coroutine is not yielded from before it is destroyed,
+ an error message is logged.
+ """
+ if _inspect_iscoroutinefunction(func):
+ # In Python 3.5 that's all we need to do for coroutines
+        # defined with "async def".
+ # Wrapping in CoroWrapper will happen via
+ # 'sys.set_coroutine_wrapper' function.
+ return func
+
+ if inspect.isgeneratorfunction(func):
+ coro = func
+ else:
+ @functools.wraps(func)
+ def coro(*args, **kw):
+ res = func(*args, **kw)
+ if isinstance(res, futures.Future) or inspect.isgenerator(res):
+ res = yield from res
+ elif _AwaitableABC is not None:
+ # If 'func' returns an Awaitable (new in 3.5) we
+ # want to run it.
+ try:
+ await_meth = res.__await__
+ except AttributeError:
+ pass
+ else:
+ if isinstance(res, _AwaitableABC):
+ res = yield from await_meth()
+ return res
+
+ if not _DEBUG:
+ if _types_coroutine is None:
+ wrapper = coro
+ else:
+ wrapper = _types_coroutine(coro)
+ else:
+ @functools.wraps(func)
+ def wrapper(*args, **kwds):
+ w = CoroWrapper(coro(*args, **kwds), func=func)
+ if w._source_traceback:
+ del w._source_traceback[-1]
+ # Python < 3.5 does not implement __qualname__
+ # on generator objects, so we set it manually.
+ # We use getattr as some callables (such as
+            # functools.partial) may lack __qualname__.
+ w.__name__ = getattr(func, '__name__', None)
+ w.__qualname__ = getattr(func, '__qualname__', None)
+ return w
+
+ wrapper._is_coroutine = True # For iscoroutinefunction().
+ return wrapper
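How the decorator is used in practice; a minimal sketch (not part of this patch):

    import trollius

    @trollius.coroutine
    def add(a, b):
        yield from trollius.sleep(0)  # any suspension point works
        return a + b

    loop = trollius.get_event_loop()
    print(loop.run_until_complete(add(1, 2)))  # 3
    loop.close()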
+
+
+def iscoroutinefunction(func):
+ """Return True if func is a decorated coroutine function."""
+ return (getattr(func, '_is_coroutine', False) or
+ _inspect_iscoroutinefunction(func))
+
+
+_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper)
+if _CoroutineABC is not None:
+ _COROUTINE_TYPES += (_CoroutineABC,)
+
+
+def iscoroutine(obj):
+ """Return True if obj is a coroutine object."""
+ return isinstance(obj, _COROUTINE_TYPES)
+
+
+def _format_coroutine(coro):
+ assert iscoroutine(coro)
+
+ coro_name = None
+ if isinstance(coro, CoroWrapper):
+ func = coro.func
+ coro_name = coro.__qualname__
+ if coro_name is not None:
+ coro_name = '{}()'.format(coro_name)
+ else:
+ func = coro
+
+ if coro_name is None:
+ coro_name = events._format_callback(func, ())
+
+ try:
+ coro_code = coro.gi_code
+ except AttributeError:
+ coro_code = coro.cr_code
+
+ try:
+ coro_frame = coro.gi_frame
+ except AttributeError:
+ coro_frame = coro.cr_frame
+
+ filename = coro_code.co_filename
+ if (isinstance(coro, CoroWrapper)
+ and not inspect.isgeneratorfunction(coro.func)
+ and coro.func is not None):
+ filename, lineno = events._get_function_source(coro.func)
+ if coro_frame is None:
+ coro_repr = ('%s done, defined at %s:%s'
+ % (coro_name, filename, lineno))
+ else:
+ coro_repr = ('%s running, defined at %s:%s'
+ % (coro_name, filename, lineno))
+ elif coro_frame is not None:
+ lineno = coro_frame.f_lineno
+ coro_repr = ('%s running at %s:%s'
+ % (coro_name, filename, lineno))
+ else:
+ lineno = coro_code.co_firstlineno
+ coro_repr = ('%s done, defined at %s:%s'
+ % (coro_name, filename, lineno))
+
+ return coro_repr
diff --git a/trollius/events.py b/trollius/events.py
new file mode 100644
index 0000000..496075b
--- /dev/null
+++ b/trollius/events.py
@@ -0,0 +1,611 @@
+"""Event loop and event loop policy."""
+
+__all__ = ['AbstractEventLoopPolicy',
+ 'AbstractEventLoop', 'AbstractServer',
+ 'Handle', 'TimerHandle',
+ 'get_event_loop_policy', 'set_event_loop_policy',
+ 'get_event_loop', 'set_event_loop', 'new_event_loop',
+ 'get_child_watcher', 'set_child_watcher',
+ ]
+
+import functools
+import inspect
+import reprlib
+import socket
+import subprocess
+import sys
+import threading
+import traceback
+
+
+_PY34 = sys.version_info >= (3, 4)
+
+
+def _get_function_source(func):
+ if _PY34:
+ func = inspect.unwrap(func)
+ elif hasattr(func, '__wrapped__'):
+ func = func.__wrapped__
+ if inspect.isfunction(func):
+ code = func.__code__
+ return (code.co_filename, code.co_firstlineno)
+ if isinstance(func, functools.partial):
+ return _get_function_source(func.func)
+ if _PY34 and isinstance(func, functools.partialmethod):
+ return _get_function_source(func.func)
+ return None
+
+
+def _format_args(args):
+ """Format function arguments.
+
+ Special case for a single parameter: ('hello',) is formatted as ('hello').
+ """
+ # use reprlib to limit the length of the output
+ args_repr = reprlib.repr(args)
+ if len(args) == 1 and args_repr.endswith(',)'):
+ args_repr = args_repr[:-2] + ')'
+ return args_repr
+
+
+def _format_callback(func, args, suffix=''):
+ if isinstance(func, functools.partial):
+ if args is not None:
+ suffix = _format_args(args) + suffix
+ return _format_callback(func.func, func.args, suffix)
+
+ if hasattr(func, '__qualname__'):
+ func_repr = getattr(func, '__qualname__')
+ elif hasattr(func, '__name__'):
+ func_repr = getattr(func, '__name__')
+ else:
+ func_repr = repr(func)
+
+ if args is not None:
+ func_repr += _format_args(args)
+ if suffix:
+ func_repr += suffix
+ return func_repr
+
+
+def _format_callback_source(func, args):
+ func_repr = _format_callback(func, args)
+ source = _get_function_source(func)
+ if source:
+ func_repr += ' at %s:%s' % source
+ return func_repr
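+
+# Illustrative sketch (not part of the original module) of what the
+# formatting helpers above produce:
+#
+#     import functools
+#     cb = functools.partial(max, 0)
+#     _format_callback_source(cb, (5,))        # "max(0)(5)"
+#     _format_callback_source(print, ('hi',))  # "print('hi')"; builtins
+#     # have no source file, so no " at <file>:<line>" suffix is added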
+
+
+class Handle:
+ """Object returned by callback registration methods."""
+
+ __slots__ = ('_callback', '_args', '_cancelled', '_loop',
+ '_source_traceback', '_repr', '__weakref__')
+
+ def __init__(self, callback, args, loop):
+ assert not isinstance(callback, Handle), 'A Handle is not a callback'
+ self._loop = loop
+ self._callback = callback
+ self._args = args
+ self._cancelled = False
+ self._repr = None
+ if self._loop.get_debug():
+ self._source_traceback = traceback.extract_stack(sys._getframe(1))
+ else:
+ self._source_traceback = None
+
+ def _repr_info(self):
+ info = [self.__class__.__name__]
+ if self._cancelled:
+ info.append('cancelled')
+ if self._callback is not None:
+ info.append(_format_callback_source(self._callback, self._args))
+ if self._source_traceback:
+ frame = self._source_traceback[-1]
+ info.append('created at %s:%s' % (frame[0], frame[1]))
+ return info
+
+ def __repr__(self):
+ if self._repr is not None:
+ return self._repr
+ info = self._repr_info()
+ return '<%s>' % ' '.join(info)
+
+ def cancel(self):
+ if not self._cancelled:
+ self._cancelled = True
+ if self._loop.get_debug():
+ # Keep a representation in debug mode to keep the callback and
+ # parameters. For example, to log the warning:
+ # "Executing <Handle...> took 2.5 seconds"
+ self._repr = repr(self)
+ self._callback = None
+ self._args = None
+
+ def _run(self):
+ try:
+ self._callback(*self._args)
+ except Exception as exc:
+ cb = _format_callback_source(self._callback, self._args)
+ msg = 'Exception in callback {}'.format(cb)
+ context = {
+ 'message': msg,
+ 'exception': exc,
+ 'handle': self,
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ self = None # Needed to break cycles when an exception occurs.
+
+
+class TimerHandle(Handle):
+ """Object returned by timed callback registration methods."""
+
+ __slots__ = ['_scheduled', '_when']
+
+ def __init__(self, when, callback, args, loop):
+ assert when is not None
+ super().__init__(callback, args, loop)
+ if self._source_traceback:
+ del self._source_traceback[-1]
+ self._when = when
+ self._scheduled = False
+
+ def _repr_info(self):
+ info = super()._repr_info()
+ pos = 2 if self._cancelled else 1
+ info.insert(pos, 'when=%s' % self._when)
+ return info
+
+ def __hash__(self):
+ return hash(self._when)
+
+ def __lt__(self, other):
+ return self._when < other._when
+
+ def __le__(self, other):
+ if self._when < other._when:
+ return True
+ return self.__eq__(other)
+
+ def __gt__(self, other):
+ return self._when > other._when
+
+ def __ge__(self, other):
+ if self._when > other._when:
+ return True
+ return self.__eq__(other)
+
+ def __eq__(self, other):
+ if isinstance(other, TimerHandle):
+ return (self._when == other._when and
+ self._callback == other._callback and
+ self._args == other._args and
+ self._cancelled == other._cancelled)
+ return NotImplemented
+
+ def __ne__(self, other):
+ equal = self.__eq__(other)
+ return NotImplemented if equal is NotImplemented else not equal
+
+ def cancel(self):
+ if not self._cancelled:
+ self._loop._timer_handle_cancelled(self)
+ super().cancel()
+
+
+class AbstractServer:
+ """Abstract server returned by create_server()."""
+
+ def close(self):
+ """Stop serving. This leaves existing connections open."""
+ raise NotImplementedError
+
+ def wait_closed(self):
+ """Coroutine to wait until service is closed."""
+ raise NotImplementedError
+
+
+class AbstractEventLoop:
+ """Abstract event loop."""
+
+ # Running and stopping the event loop.
+
+ def run_forever(self):
+ """Run the event loop until stop() is called."""
+ raise NotImplementedError
+
+ def run_until_complete(self, future):
+ """Run the event loop until a Future is done.
+
+ Return the Future's result, or raise its exception.
+ """
+ raise NotImplementedError
+
+ def stop(self):
+ """Stop the event loop as soon as reasonable.
+
+ Exactly how soon that is may depend on the implementation, but
+ no more I/O callbacks should be scheduled.
+ """
+ raise NotImplementedError
+
+ def is_running(self):
+ """Return whether the event loop is currently running."""
+ raise NotImplementedError
+
+ def is_closed(self):
+ """Returns True if the event loop was closed."""
+ raise NotImplementedError
+
+ def close(self):
+ """Close the loop.
+
+ The loop should not be running.
+
+ This is idempotent and irreversible.
+
+ No other methods should be called after this one.
+ """
+ raise NotImplementedError
+
+ # Methods scheduling callbacks. All these return Handles.
+
+ def _timer_handle_cancelled(self, handle):
+ """Notification that a TimerHandle has been cancelled."""
+ raise NotImplementedError
+
+ def call_soon(self, callback, *args):
+ return self.call_later(0, callback, *args)
+
+ def call_later(self, delay, callback, *args):
+ raise NotImplementedError
+
+ def call_at(self, when, callback, *args):
+ raise NotImplementedError
+
+ def time(self):
+ raise NotImplementedError
+
+ # Method scheduling a coroutine object: create a task.
+
+ def create_task(self, coro):
+ raise NotImplementedError
+
+ # Methods for interacting with threads.
+
+ def call_soon_threadsafe(self, callback, *args):
+ raise NotImplementedError
+
+ def run_in_executor(self, executor, func, *args):
+ raise NotImplementedError
+
+ def set_default_executor(self, executor):
+ raise NotImplementedError
+
+ # Network I/O methods returning Futures.
+
+ def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0):
+ raise NotImplementedError
+
+ def getnameinfo(self, sockaddr, flags=0):
+ raise NotImplementedError
+
+ def create_connection(self, protocol_factory, host=None, port=None, *,
+ ssl=None, family=0, proto=0, flags=0, sock=None,
+ local_addr=None, server_hostname=None):
+ raise NotImplementedError
+
+ def create_server(self, protocol_factory, host=None, port=None, *,
+ family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE,
+ sock=None, backlog=100, ssl=None, reuse_address=None):
+ """A coroutine which creates a TCP server bound to host and port.
+
+ The return value is a Server object which can be used to stop
+ the service.
+
+ If host is an empty string or None, all interfaces are assumed
+ and the server will listen on multiple sockets (most likely one
+ for IPv4 and another one for IPv6).
+
+ family can be set to either AF_INET or AF_INET6 to force the
+ socket to use IPv4 or IPv6. If not set it will be determined
+ from host (defaults to AF_UNSPEC).
+
+ flags is a bitmask for getaddrinfo().
+
+ sock can optionally be specified in order to use a preexisting
+ socket object.
+
+ backlog is the maximum number of queued connections passed to
+ listen() (defaults to 100).
+
+ ssl can be set to an SSLContext to enable SSL over the
+ accepted connections.
+
+ reuse_address tells the kernel to reuse a local socket in
+ TIME_WAIT state, without waiting for its natural timeout to
+ expire. If not specified, it will automatically be set to True on
+ UNIX.
+ """
+ raise NotImplementedError
+
+ def create_unix_connection(self, protocol_factory, path, *,
+ ssl=None, sock=None,
+ server_hostname=None):
+ raise NotImplementedError
+
+ def create_unix_server(self, protocol_factory, path, *,
+ sock=None, backlog=100, ssl=None):
+ """A coroutine which creates a UNIX Domain Socket server.
+
+ The return value is a Server object, which can be used to stop
+ the service.
+
+ path is a str, representing a file system path to bind the
+ server socket to.
+
+ sock can optionally be specified in order to use a preexisting
+ socket object.
+
+ backlog is the maximum number of queued connections passed to
+ listen() (defaults to 100).
+
+ ssl can be set to an SSLContext to enable SSL over the
+ accepted connections.
+ """
+ raise NotImplementedError
+
+ def create_datagram_endpoint(self, protocol_factory,
+ local_addr=None, remote_addr=None, *,
+ family=0, proto=0, flags=0):
+ raise NotImplementedError
+
+ # Pipes and subprocesses.
+
+ def connect_read_pipe(self, protocol_factory, pipe):
+ """Register read pipe in event loop. Set the pipe to non-blocking mode.
+
+ protocol_factory should instantiate an object with the Protocol
+ interface. pipe is a file-like object. Return a (transport,
+ protocol) pair, where the transport supports the ReadTransport
+ interface."""
+ # The reason to accept a file-like object instead of just a file
+ # descriptor is that we need to own the pipe and close it when the
+ # transport is finished. Passing f.fileno(), closing the fd in the
+ # pipe transport and then closing f (or vice versa) can lead to
+ # complicated errors.
+ raise NotImplementedError
+
+ def connect_write_pipe(self, protocol_factory, pipe):
+ """Register write pipe in event loop.
+
+ protocol_factory should instantiate an object with the BaseProtocol
+ interface. pipe is a file-like object already switched to
+ non-blocking mode. Return a (transport, protocol) pair, where the
+ transport supports the WriteTransport interface."""
+ # The reason to accept a file-like object instead of just a file
+ # descriptor is that we need to own the pipe and close it when the
+ # transport is finished. Passing f.fileno(), closing the fd in the
+ # pipe transport and then closing f (or vice versa) can lead to
+ # complicated errors.
+ raise NotImplementedError
+
+ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ **kwargs):
+ raise NotImplementedError
+
+ def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ **kwargs):
+ raise NotImplementedError
+
+ # Ready-based callback registration methods.
+ # The add_*() methods return None.
+ # The remove_*() methods return True if something was removed,
+ # False if there was nothing to delete.
+
+ def add_reader(self, fd, callback, *args):
+ raise NotImplementedError
+
+ def remove_reader(self, fd):
+ raise NotImplementedError
+
+ def add_writer(self, fd, callback, *args):
+ raise NotImplementedError
+
+ def remove_writer(self, fd):
+ raise NotImplementedError
+
+ # Completion based I/O methods returning Futures.
+
+ def sock_recv(self, sock, nbytes):
+ raise NotImplementedError
+
+ def sock_sendall(self, sock, data):
+ raise NotImplementedError
+
+ def sock_connect(self, sock, address):
+ raise NotImplementedError
+
+ def sock_accept(self, sock):
+ raise NotImplementedError
+
+ # Signal handling.
+
+ def add_signal_handler(self, sig, callback, *args):
+ raise NotImplementedError
+
+ def remove_signal_handler(self, sig):
+ raise NotImplementedError
+
+ # Task factory.
+
+ def set_task_factory(self, factory):
+ raise NotImplementedError
+
+ def get_task_factory(self):
+ raise NotImplementedError
+
+ # Error handlers.
+
+ def set_exception_handler(self, handler):
+ raise NotImplementedError
+
+ def default_exception_handler(self, context):
+ raise NotImplementedError
+
+ def call_exception_handler(self, context):
+ raise NotImplementedError
+
+ # Debug flag management.
+
+ def get_debug(self):
+ raise NotImplementedError
+
+ def set_debug(self, enabled):
+ raise NotImplementedError
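+
+# Illustrative usage sketch (not part of the original module). A
+# concrete loop implementing this interface is typically driven as
+# below; EchoProtocol is an assumed name, and Protocol comes from the
+# protocols module.
+#
+#     class EchoProtocol(Protocol):
+#         def connection_made(self, transport):
+#             self.transport = transport
+#         def data_received(self, data):
+#             self.transport.write(data)
+#
+#     loop = get_event_loop()
+#     server = loop.run_until_complete(
+#         loop.create_server(EchoProtocol, '127.0.0.1', 8888))
+#     loop.run_forever()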
+
+
+class AbstractEventLoopPolicy:
+ """Abstract policy for accessing the event loop."""
+
+ def get_event_loop(self):
+ """Get the event loop for the current context.
+
+ Returns an event loop object implementing the BaseEventLoop interface,
+ or raises an exception in case no event loop has been set for the
+ current context and the current policy does not specify to create one.
+
+ It should never return None."""
+ raise NotImplementedError
+
+ def set_event_loop(self, loop):
+ """Set the event loop for the current context to loop."""
+ raise NotImplementedError
+
+ def new_event_loop(self):
+ """Create and return a new event loop object according to this
+ policy's rules. If there's need to set this loop as the event loop for
+ the current context, set_event_loop must be called explicitly."""
+ raise NotImplementedError
+
+ # Child processes handling (Unix only).
+
+ def get_child_watcher(self):
+ "Get the watcher for child processes."
+ raise NotImplementedError
+
+ def set_child_watcher(self, watcher):
+ """Set the watcher for child processes."""
+ raise NotImplementedError
+
+
+class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy):
+ """Default policy implementation for accessing the event loop.
+
+ In this policy, each thread has its own event loop. However, we
+ only automatically create an event loop by default for the main
+ thread; other threads by default have no event loop.
+
+ Other policies may have different rules (e.g. a single global
+ event loop, or automatically creating an event loop per thread, or
+ using some other notion of context to which an event loop is
+ associated).
+ """
+
+ _loop_factory = None
+
+ class _Local(threading.local):
+ _loop = None
+ _set_called = False
+
+ def __init__(self):
+ self._local = self._Local()
+
+ def get_event_loop(self):
+ """Get the event loop.
+
+ This may be None or an instance of EventLoop.
+ """
+ if (self._local._loop is None and
+ not self._local._set_called and
+ isinstance(threading.current_thread(), threading._MainThread)):
+ self.set_event_loop(self.new_event_loop())
+ if self._local._loop is None:
+ raise RuntimeError('There is no current event loop in thread %r.'
+ % threading.current_thread().name)
+ return self._local._loop
+
+ def set_event_loop(self, loop):
+ """Set the event loop."""
+ self._local._set_called = True
+ assert loop is None or isinstance(loop, AbstractEventLoop)
+ self._local._loop = loop
+
+ def new_event_loop(self):
+ """Create a new event loop.
+
+ You must call set_event_loop() to make this the current event
+ loop.
+ """
+ return self._loop_factory()
+
+
+# Event loop policy. The policy itself is always global, even if the
+# policy's rules say that there is an event loop per thread (or other
+# notion of context). The default policy is installed by the first
+# call to get_event_loop_policy().
+_event_loop_policy = None
+
+# Lock for protecting the on-the-fly creation of the event loop policy.
+_lock = threading.Lock()
+
+
+def _init_event_loop_policy():
+ global _event_loop_policy
+ with _lock:
+ if _event_loop_policy is None: # pragma: no branch
+ from . import DefaultEventLoopPolicy
+ _event_loop_policy = DefaultEventLoopPolicy()
+
+
+def get_event_loop_policy():
+ """Get the current event loop policy."""
+ if _event_loop_policy is None:
+ _init_event_loop_policy()
+ return _event_loop_policy
+
+
+def set_event_loop_policy(policy):
+ """Set the current event loop policy.
+
+ If policy is None, the default policy is restored."""
+ global _event_loop_policy
+ assert policy is None or isinstance(policy, AbstractEventLoopPolicy)
+ _event_loop_policy = policy
+
+
+def get_event_loop():
+ """Equivalent to calling get_event_loop_policy().get_event_loop()."""
+ return get_event_loop_policy().get_event_loop()
+
+
+def set_event_loop(loop):
+ """Equivalent to calling get_event_loop_policy().set_event_loop(loop)."""
+ get_event_loop_policy().set_event_loop(loop)
+
+
+def new_event_loop():
+ """Equivalent to calling get_event_loop_policy().new_event_loop()."""
+ return get_event_loop_policy().new_event_loop()
+
+
+def get_child_watcher():
+ """Equivalent to calling get_event_loop_policy().get_child_watcher()."""
+ return get_event_loop_policy().get_child_watcher()
+
+
+def set_child_watcher(watcher):
+ """Equivalent to calling
+ get_event_loop_policy().set_child_watcher(watcher)."""
+ return get_event_loop_policy().set_child_watcher(watcher)
diff --git a/trollius/futures.py b/trollius/futures.py
new file mode 100644
index 0000000..d06828a
--- /dev/null
+++ b/trollius/futures.py
@@ -0,0 +1,413 @@
+"""A Future class similar to the one in PEP 3148."""
+
+__all__ = ['CancelledError', 'TimeoutError',
+ 'InvalidStateError',
+ 'Future', 'wrap_future',
+ ]
+
+import concurrent.futures._base
+import logging
+import reprlib
+import sys
+import traceback
+
+from . import events
+
+# States for Future.
+_PENDING = 'PENDING'
+_CANCELLED = 'CANCELLED'
+_FINISHED = 'FINISHED'
+
+_PY34 = sys.version_info >= (3, 4)
+_PY35 = sys.version_info >= (3, 5)
+
+Error = concurrent.futures._base.Error
+CancelledError = concurrent.futures.CancelledError
+TimeoutError = concurrent.futures.TimeoutError
+
+STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging
+
+
+class InvalidStateError(Error):
+ """The operation is not allowed in this state."""
+
+
+class _TracebackLogger:
+ """Helper to log a traceback upon destruction if not cleared.
+
+ This solves a nasty problem with Futures and Tasks that have an
+ exception set: if nobody asks for the exception, the exception is
+ never logged. This violates the Zen of Python: 'Errors should
+ never pass silently. Unless explicitly silenced.'
+
+ However, we don't want to log the exception as soon as
+ set_exception() is called: if the calling code is written
+ properly, it will get the exception and handle it properly. But
+ we *do* want to log it if result() or exception() was never called
+ -- otherwise developers waste a lot of time wondering why their
+ buggy code fails silently.
+
+ An earlier attempt added a __del__() method to the Future class
+ itself, but this backfired because the presence of __del__()
+ prevents garbage collection from breaking cycles. A way out of
+ this catch-22 is to avoid having a __del__() method on the Future
+ class itself, but instead to have a reference to a helper object
+ with a __del__() method that logs the traceback, where we ensure
+ that the helper object doesn't participate in cycles, and only the
+ Future has a reference to it.
+
+ The helper object is added when set_exception() is called. When
+ the Future is collected, and the helper is present, the helper
+ object is also collected, and its __del__() method will log the
+ traceback. When the Future's result() or exception() method is
+ called (and a helper object is present), it removes the helper
+ object, after calling its clear() method to prevent it from
+ logging.
+
+ One downside is that we do a fair amount of work to extract the
+ traceback from the exception, even when it is never logged. It
+ would seem cheaper to just store the exception object, but that
+ references the traceback, which references stack frames, which may
+ reference the Future, which references the _TracebackLogger, and
+ then the _TracebackLogger would be included in a cycle, which is
+ what we're trying to avoid! As an optimization, we don't
+ immediately format the exception; we only do the work when
+ activate() is called, and that call is delayed until after all the
+ Future's callbacks have run. Usually a Future has at least
+ one callback (typically set by 'yield from'), and usually that
+ callback extracts the exception, thereby removing the need to
+ format it at all.
+
+ PS. I don't claim credit for this solution. I first heard of it
+ in a discussion about closing files when they are collected.
+ """
+
+ __slots__ = ('loop', 'source_traceback', 'exc', 'tb')
+
+ def __init__(self, future, exc):
+ self.loop = future._loop
+ self.source_traceback = future._source_traceback
+ self.exc = exc
+ self.tb = None
+
+ def activate(self):
+ exc = self.exc
+ if exc is not None:
+ self.exc = None
+ self.tb = traceback.format_exception(exc.__class__, exc,
+ exc.__traceback__)
+
+ def clear(self):
+ self.exc = None
+ self.tb = None
+
+ def __del__(self):
+ if self.tb:
+ msg = 'Future/Task exception was never retrieved\n'
+ if self.source_traceback:
+ src = ''.join(traceback.format_list(self.source_traceback))
+ msg += 'Future/Task created at (most recent call last):\n'
+ msg += '%s\n' % src.rstrip()
+ msg += ''.join(self.tb).rstrip()
+ self.loop.call_exception_handler({'message': msg})
+
+
+class Future:
+ """This class is *almost* compatible with concurrent.futures.Future.
+
+ Differences:
+
+ - result() and exception() do not take a timeout argument and
+ raise an exception when the future isn't done yet.
+
+ - Callbacks registered with add_done_callback() are always called
+ via the event loop's call_soon_threadsafe().
+
+ - This class is not compatible with the wait() and as_completed()
+ methods in the concurrent.futures package.
+
+ (In Python 3.4 or later we may be able to unify the implementations.)
+ """
+
+ # Class variables serving as defaults for instance variables.
+ _state = _PENDING
+ _result = None
+ _exception = None
+ _loop = None
+ _source_traceback = None
+
+ _blocking = False # proper use of future (yield vs yield from)
+
+ _log_traceback = False # Used for Python 3.4 and later
+ _tb_logger = None # Used for Python 3.3 only
+
+ def __init__(self, *, loop=None):
+ """Initialize the future.
+
+ The optional loop argument allows explicitly setting the event
+ loop object used by the future. If it's not provided, the future
+ uses the default event loop.
+ """
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._callbacks = []
+ if self._loop.get_debug():
+ self._source_traceback = traceback.extract_stack(sys._getframe(1))
+
+ def _format_callbacks(self):
+ cb = self._callbacks
+ size = len(cb)
+ if not size:
+ cb = ''
+
+ def format_cb(callback):
+ return events._format_callback_source(callback, ())
+
+ if size == 1:
+ cb = format_cb(cb[0])
+ elif size == 2:
+ cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1]))
+ elif size > 2:
+ cb = '{}, <{} more>, {}'.format(format_cb(cb[0]),
+ size-2,
+ format_cb(cb[-1]))
+ return 'cb=[%s]' % cb
+
+ def _repr_info(self):
+ info = [self._state.lower()]
+ if self._state == _FINISHED:
+ if self._exception is not None:
+ info.append('exception={!r}'.format(self._exception))
+ else:
+ # use reprlib to limit the length of the output, especially
+ # for very long strings
+ result = reprlib.repr(self._result)
+ info.append('result={}'.format(result))
+ if self._callbacks:
+ info.append(self._format_callbacks())
+ if self._source_traceback:
+ frame = self._source_traceback[-1]
+ info.append('created at %s:%s' % (frame[0], frame[1]))
+ return info
+
+ def __repr__(self):
+ info = self._repr_info()
+ return '<%s %s>' % (self.__class__.__name__, ' '.join(info))
+
+ # On Python 3.3 and older, objects with a destructor that are part
+ # of a reference cycle are never destroyed. This is no longer the
+ # case on Python 3.4, thanks to PEP 442.
+ if _PY34:
+ def __del__(self):
+ if not self._log_traceback:
+ # set_exception() was not called, or result() or exception()
+ # has consumed the exception
+ return
+ exc = self._exception
+ context = {
+ 'message': ('%s exception was never retrieved'
+ % self.__class__.__name__),
+ 'exception': exc,
+ 'future': self,
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+
+ def cancel(self):
+ """Cancel the future and schedule callbacks.
+
+ If the future is already done or cancelled, return False. Otherwise,
+ change the future's state to cancelled, schedule the callbacks and
+ return True.
+ """
+ if self._state != _PENDING:
+ return False
+ self._state = _CANCELLED
+ self._schedule_callbacks()
+ return True
+
+ def _schedule_callbacks(self):
+ """Internal: Ask the event loop to call all callbacks.
+
+ The callbacks are scheduled to be called as soon as possible. Also
+ clears the callback list.
+ """
+ callbacks = self._callbacks[:]
+ if not callbacks:
+ return
+
+ self._callbacks[:] = []
+ for callback in callbacks:
+ self._loop.call_soon(callback, self)
+
+ def cancelled(self):
+ """Return True if the future was cancelled."""
+ return self._state == _CANCELLED
+
+ # Don't implement running(); see http://bugs.python.org/issue18699
+
+ def done(self):
+ """Return True if the future is done.
+
+ Done means either that a result or exception is available, or
+ that the future was cancelled.
+ """
+ return self._state != _PENDING
+
+ def result(self):
+ """Return the result this future represents.
+
+ If the future has been cancelled, raises CancelledError. If the
+ future's result isn't yet available, raises InvalidStateError. If
+ the future is done and has an exception set, this exception is raised.
+ """
+ if self._state == _CANCELLED:
+ raise CancelledError
+ if self._state != _FINISHED:
+ raise InvalidStateError('Result is not ready.')
+ self._log_traceback = False
+ if self._tb_logger is not None:
+ self._tb_logger.clear()
+ self._tb_logger = None
+ if self._exception is not None:
+ raise self._exception
+ return self._result
+
+ def exception(self):
+ """Return the exception that was set on this future.
+
+ The exception (or None if no exception was set) is returned only if
+ the future is done. If the future has been cancelled, raises
+ CancelledError. If the future isn't done yet, raises
+ InvalidStateError.
+ """
+ if self._state == _CANCELLED:
+ raise CancelledError
+ if self._state != _FINISHED:
+ raise InvalidStateError('Exception is not set.')
+ self._log_traceback = False
+ if self._tb_logger is not None:
+ self._tb_logger.clear()
+ self._tb_logger = None
+ return self._exception
+
+ def add_done_callback(self, fn):
+ """Add a callback to be run when the future becomes done.
+
+ The callback is called with a single argument - the future object. If
+ the future is already done when this is called, the callback is
+ scheduled with call_soon.
+ """
+ if self._state != _PENDING:
+ self._loop.call_soon(fn, self)
+ else:
+ self._callbacks.append(fn)
+
+ # New method not in PEP 3148.
+
+ def remove_done_callback(self, fn):
+ """Remove all instances of a callback from the "call when done" list.
+
+ Returns the number of callbacks removed.
+ """
+ filtered_callbacks = [f for f in self._callbacks if f != fn]
+ removed_count = len(self._callbacks) - len(filtered_callbacks)
+ if removed_count:
+ self._callbacks[:] = filtered_callbacks
+ return removed_count
+
+ # So-called internal methods (note: no set_running_or_notify_cancel()).
+
+ def _set_result_unless_cancelled(self, result):
+ """Helper setting the result only if the future was not cancelled."""
+ if self.cancelled():
+ return
+ self.set_result(result)
+
+ def set_result(self, result):
+ """Mark the future done and set its result.
+
+ If the future is already done when this method is called, raises
+ InvalidStateError.
+ """
+ if self._state != _PENDING:
+ raise InvalidStateError('{}: {!r}'.format(self._state, self))
+ self._result = result
+ self._state = _FINISHED
+ self._schedule_callbacks()
+
+ def set_exception(self, exception):
+ """Mark the future done and set an exception.
+
+ If the future is already done when this method is called, raises
+ InvalidStateError.
+ """
+ if self._state != _PENDING:
+ raise InvalidStateError('{}: {!r}'.format(self._state, self))
+ if isinstance(exception, type):
+ exception = exception()
+ self._exception = exception
+ self._state = _FINISHED
+ self._schedule_callbacks()
+ if _PY34:
+ self._log_traceback = True
+ else:
+ self._tb_logger = _TracebackLogger(self, exception)
+ # Arrange for the logger to be activated after all callbacks
+ # have had a chance to call result() or exception().
+ self._loop.call_soon(self._tb_logger.activate)
+
+ # Truly internal methods.
+
+ def _copy_state(self, other):
+ """Internal helper to copy state from another Future.
+
+ The other Future may be a concurrent.futures.Future.
+ """
+ assert other.done()
+ if self.cancelled():
+ return
+ assert not self.done()
+ if other.cancelled():
+ self.cancel()
+ else:
+ exception = other.exception()
+ if exception is not None:
+ self.set_exception(exception)
+ else:
+ result = other.result()
+ self.set_result(result)
+
+ def __iter__(self):
+ if not self.done():
+ self._blocking = True
+ yield self # This tells Task to wait for completion.
+ assert self.done(), "yield from wasn't used with future"
+ return self.result() # May raise too.
+
+ if _PY35:
+ __await__ = __iter__ # make compatible with 'await' expression
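+
+# Illustrative sketch (not part of the original module) of how a
+# Future is consumed inside a coroutine; `fut` is an assumed name and
+# the coroutine decorator lives in coroutines.py (not imported here,
+# to avoid a circular import).
+#
+#     @coroutine
+#     def waiter(fut):
+#         value = yield from fut   # suspends until set_result() or
+#                                  # set_exception() is called
+#         return value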
+
+
+def wrap_future(fut, *, loop=None):
+ """Wrap concurrent.futures.Future object."""
+ if isinstance(fut, Future):
+ return fut
+ assert isinstance(fut, concurrent.futures.Future), \
+ 'concurrent.futures.Future is expected, got {!r}'.format(fut)
+ if loop is None:
+ loop = events.get_event_loop()
+ new_future = Future(loop=loop)
+
+ def _check_cancel_other(f):
+ if f.cancelled():
+ fut.cancel()
+
+ new_future.add_done_callback(_check_cancel_other)
+ fut.add_done_callback(
+ lambda future: loop.call_soon_threadsafe(
+ new_future._copy_state, future))
+ return new_future
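+
+# Illustrative sketch (not part of the original module) of bridging a
+# thread-pool future into the event loop; `blocking_io` is an assumed
+# helper name.
+#
+#     import concurrent.futures
+#     executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+#
+#     def blocking_io():
+#         return 42
+#
+#     loop = events.get_event_loop()
+#     fut = wrap_future(executor.submit(blocking_io), loop=loop)
+#     print(loop.run_until_complete(fut))   # 42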
diff --git a/trollius/locks.py b/trollius/locks.py
new file mode 100644
index 0000000..b2e516b
--- /dev/null
+++ b/trollius/locks.py
@@ -0,0 +1,470 @@
+"""Synchronization primitives."""
+
+__all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore']
+
+import collections
+import sys
+
+from . import events
+from . import futures
+from .coroutines import coroutine
+
+
+_PY35 = sys.version_info >= (3, 5)
+
+
+class _ContextManager:
+ """Context manager.
+
+ This enables the following idiom for acquiring and releasing a
+ lock around a block:
+
+ with (yield from lock):
+ <block>
+
+ while failing loudly when accidentally using:
+
+ with lock:
+ <block>
+ """
+
+ def __init__(self, lock):
+ self._lock = lock
+
+ def __enter__(self):
+ # We have no use for the "as ..." clause in the with
+ # statement for locks.
+ return None
+
+ def __exit__(self, *args):
+ try:
+ self._lock.release()
+ finally:
+ self._lock = None # Crudely prevent reuse.
+
+
+class _ContextManagerMixin:
+ def __enter__(self):
+ raise RuntimeError(
+ '"yield from" should be used as context manager expression')
+
+ def __exit__(self, *args):
+ # This must exist because __enter__ exists, even though that
+ # always raises; that's how the with-statement works.
+ pass
+
+ @coroutine
+ def __iter__(self):
+ # This is not a coroutine. It is meant to enable the idiom:
+ #
+ # with (yield from lock):
+ # <block>
+ #
+ # as an alternative to:
+ #
+ # yield from lock.acquire()
+ # try:
+ # <block>
+ # finally:
+ # lock.release()
+ yield from self.acquire()
+ return _ContextManager(self)
+
+ if _PY35:
+
+ def __await__(self):
+ # To make "with await lock" work.
+ yield from self.acquire()
+ return _ContextManager(self)
+
+ @coroutine
+ def __aenter__(self):
+ yield from self.acquire()
+ # We have no use for the "as ..." clause in the with
+ # statement for locks.
+ return None
+
+ @coroutine
+ def __aexit__(self, exc_type, exc, tb):
+ self.release()
+
+
+class Lock(_ContextManagerMixin):
+ """Primitive lock objects.
+
+ A primitive lock is a synchronization primitive that is not owned
+ by a particular coroutine when locked. A primitive lock is in one
+ of two states, 'locked' or 'unlocked'.
+
+ It is created in the unlocked state. It has two basic methods,
+ acquire() and release(). When the state is unlocked, acquire()
+ changes the state to locked and returns immediately. When the
+ state is locked, acquire() blocks until a call to release() in
+ another coroutine changes it to unlocked, then the acquire() call
+ resets it to locked and returns. The release() method should only
+ be called in the locked state; it changes the state to unlocked
+ and returns immediately. If an attempt is made to release an
+ unlocked lock, a RuntimeError will be raised.
+
+ When more than one coroutine is blocked in acquire() waiting for
+ the state to turn to unlocked, only one coroutine proceeds when a
+ release() call resets the state to unlocked; the first coroutine
+ that blocked in acquire() is the one that proceeds.
+
+ acquire() is a coroutine and should be called with 'yield from'.
+
+ Locks also support the context management protocol; '(yield from lock)'
+ should be used as the context manager expression.
+
+ Usage:
+
+ lock = Lock()
+ ...
+ yield from lock
+ try:
+ ...
+ finally:
+ lock.release()
+
+ Context manager usage:
+
+ lock = Lock()
+ ...
+ with (yield from lock):
+ ...
+
+ Lock objects can be tested for locking state:
+
+ if not lock.locked():
+ yield from lock
+ else:
+ # lock is acquired
+ ...
+
+ """
+
+ def __init__(self, *, loop=None):
+ self._waiters = collections.deque()
+ self._locked = False
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = events.get_event_loop()
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = 'locked' if self._locked else 'unlocked'
+ if self._waiters:
+ extra = '{},waiters:{}'.format(extra, len(self._waiters))
+ return '<{} [{}]>'.format(res[1:-1], extra)
+
+ def locked(self):
+ """Return True if lock is acquired."""
+ return self._locked
+
+ @coroutine
+ def acquire(self):
+ """Acquire a lock.
+
+ This method blocks until the lock is unlocked, then sets it to
+ locked and returns True.
+ """
+ if not self._waiters and not self._locked:
+ self._locked = True
+ return True
+
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ self._locked = True
+ return True
+ finally:
+ self._waiters.remove(fut)
+
+ def release(self):
+ """Release a lock.
+
+ When the lock is locked, reset it to unlocked, and return.
+ If any other coroutines are blocked waiting for the lock to become
+ unlocked, allow exactly one of them to proceed.
+
+ When invoked on an unlocked lock, a RuntimeError is raised.
+
+ There is no return value.
+ """
+ if self._locked:
+ self._locked = False
+ # Wake up the first waiter who isn't cancelled.
+ for fut in self._waiters:
+ if not fut.done():
+ fut.set_result(True)
+ break
+ else:
+ raise RuntimeError('Lock is not acquired.')
+
+
+class Event:
+ """Asynchronous equivalent to threading.Event.
+
+ Class implementing event objects. An event manages a flag that can be set
+ to true with the set() method and reset to false with the clear() method.
+ The wait() method blocks until the flag is true. The flag is initially
+ false.
+ """
+
+ def __init__(self, *, loop=None):
+ self._waiters = collections.deque()
+ self._value = False
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = events.get_event_loop()
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = 'set' if self._value else 'unset'
+ if self._waiters:
+ extra = '{},waiters:{}'.format(extra, len(self._waiters))
+ return '<{} [{}]>'.format(res[1:-1], extra)
+
+ def is_set(self):
+ """Return True if and only if the internal flag is true."""
+ return self._value
+
+ def set(self):
+ """Set the internal flag to true. All coroutines waiting for it to
+ become true are awakened. Coroutines that call wait() once the flag
+ is true will not block at all.
+ """
+ if not self._value:
+ self._value = True
+
+ for fut in self._waiters:
+ if not fut.done():
+ fut.set_result(True)
+
+ def clear(self):
+ """Reset the internal flag to false. Subsequently, coroutines calling
+ wait() will block until set() is called to set the internal flag
+ to true again."""
+ self._value = False
+
+ @coroutine
+ def wait(self):
+ """Block until the internal flag is true.
+
+ If the internal flag is true on entry, return True
+ immediately. Otherwise, block until another coroutine calls
+ set() to set the flag to true, then return True.
+ """
+ if self._value:
+ return True
+
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ return True
+ finally:
+ self._waiters.remove(fut)
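+
+# Illustrative usage sketch (not part of the original module);
+# `waiter` and `setter` are assumed names.
+#
+#     event = Event()
+#
+#     @coroutine
+#     def waiter():
+#         yield from event.wait()   # blocks until set() is called
+#         return 'done'
+#
+#     @coroutine
+#     def setter():
+#         event.set()               # wakes up all waiters at once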
+
+
+class Condition(_ContextManagerMixin):
+ """Asynchronous equivalent to threading.Condition.
+
+ This class implements condition variable objects. A condition variable
+ allows one or more coroutines to wait until they are notified by another
+ coroutine.
+
+ If the lock argument is not given, a new Lock object is created
+ and used as the underlying lock.
+ """
+
+ def __init__(self, lock=None, *, loop=None):
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = events.get_event_loop()
+
+ if lock is None:
+ lock = Lock(loop=self._loop)
+ elif lock._loop is not self._loop:
+ raise ValueError("loop argument must agree with lock")
+
+ self._lock = lock
+ # Export the lock's locked(), acquire() and release() methods.
+ self.locked = lock.locked
+ self.acquire = lock.acquire
+ self.release = lock.release
+
+ self._waiters = collections.deque()
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = 'locked' if self.locked() else 'unlocked'
+ if self._waiters:
+ extra = '{},waiters:{}'.format(extra, len(self._waiters))
+ return '<{} [{}]>'.format(res[1:-1], extra)
+
+ @coroutine
+ def wait(self):
+ """Wait until notified.
+
+ If the calling coroutine has not acquired the lock when this
+ method is called, a RuntimeError is raised.
+
+ This method releases the underlying lock, and then blocks
+ until it is awakened by a notify() or notify_all() call for
+ the same condition variable in another coroutine. Once
+ awakened, it re-acquires the lock and returns True.
+ """
+ if not self.locked():
+ raise RuntimeError('cannot wait on un-acquired lock')
+
+ self.release()
+ try:
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ return True
+ finally:
+ self._waiters.remove(fut)
+
+ finally:
+ yield from self.acquire()
+
+ @coroutine
+ def wait_for(self, predicate):
+ """Wait until a predicate becomes true.
+
+ The predicate should be a callable whose result will be
+ interpreted as a boolean value. The final predicate value is
+ the return value.
+ """
+ result = predicate()
+ while not result:
+ yield from self.wait()
+ result = predicate()
+ return result
+
+ def notify(self, n=1):
+ """By default, wake up one coroutine waiting on this condition, if any.
+ If the calling coroutine has not acquired the lock when this method
+ is called, a RuntimeError is raised.
+
+ This method wakes up at most n of the coroutines waiting for the
+ condition variable; it is a no-op if no coroutines are waiting.
+
+ Note: an awakened coroutine does not actually return from its
+ wait() call until it can reacquire the lock. Since notify() does
+ not release the lock, its caller should.
+ """
+ if not self.locked():
+ raise RuntimeError('cannot notify on un-acquired lock')
+
+ idx = 0
+ for fut in self._waiters:
+ if idx >= n:
+ break
+
+ if not fut.done():
+ idx += 1
+ fut.set_result(False)
+
+ def notify_all(self):
+ """Wake up all threads waiting on this condition. This method acts
+ like notify(), but wakes up all waiting threads instead of one. If the
+ calling thread has not acquired the lock when this method is called,
+ a RuntimeError is raised.
+ """
+ self.notify(len(self._waiters))
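+
+# Illustrative producer/consumer sketch (not part of the original
+# module); `items` is an assumed shared list.
+#
+#     cond = Condition()
+#     items = []
+#
+#     @coroutine
+#     def consumer():
+#         with (yield from cond):
+#             yield from cond.wait_for(lambda: items)
+#             return items.pop(0)
+#
+#     @coroutine
+#     def producer(value):
+#         with (yield from cond):
+#             items.append(value)
+#             cond.notify()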
+
+
+class Semaphore(_ContextManagerMixin):
+ """A Semaphore implementation.
+
+ A semaphore manages an internal counter which is decremented by each
+ acquire() call and incremented by each release() call. The counter
+ can never go below zero; when acquire() finds that it is zero, it blocks,
+ waiting until some other coroutine calls release().
+
+ Semaphores also support the context management protocol.
+
+ The optional argument gives the initial value for the internal
+ counter; it defaults to 1. If the value given is less than 0,
+ ValueError is raised.
+ """
+
+ def __init__(self, value=1, *, loop=None):
+ if value < 0:
+ raise ValueError("Semaphore initial value must be >= 0")
+ self._value = value
+ self._waiters = collections.deque()
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = events.get_event_loop()
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
+ self._value)
+ if self._waiters:
+ extra = '{},waiters:{}'.format(extra, len(self._waiters))
+ return '<{} [{}]>'.format(res[1:-1], extra)
+
+ def locked(self):
+ """Returns True if semaphore can not be acquired immediately."""
+ return self._value == 0
+
+ @coroutine
+ def acquire(self):
+ """Acquire a semaphore.
+
+ If the internal counter is larger than zero on entry,
+ decrement it by one and return True immediately. If it is
+ zero on entry, block, waiting until some other coroutine has
+ called release() to make it larger than 0, and then return
+ True.
+ """
+ if not self._waiters and self._value > 0:
+ self._value -= 1
+ return True
+
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ self._value -= 1
+ return True
+ finally:
+ self._waiters.remove(fut)
+
+ def release(self):
+ """Release a semaphore, incrementing the internal counter by one.
+ When it was zero on entry and another coroutine is waiting for it to
+ become larger than zero again, wake up that coroutine.
+ """
+ self._value += 1
+ for waiter in self._waiters:
+ if not waiter.done():
+ waiter.set_result(True)
+ break
+
+
+class BoundedSemaphore(Semaphore):
+ """A bounded semaphore implementation.
+
+ This raises ValueError in release() if it would increase the value
+ above the initial value.
+ """
+
+ def __init__(self, value=1, *, loop=None):
+ self._bound_value = value
+ super().__init__(value, loop=loop)
+
+ def release(self):
+ if self._value >= self._bound_value:
+ raise ValueError('BoundedSemaphore released too many times')
+ super().release()
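+
+# Illustrative sketch (not part of the original module) of bounding
+# concurrency with a Semaphore; `fetch` is an assumed coroutine name.
+#
+#     sem = Semaphore(2)
+#
+#     @coroutine
+#     def fetch(url):
+#         with (yield from sem):   # at most two fetches run at once
+#             pass                 # perform the I/O here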
diff --git a/trollius/log.py b/trollius/log.py
new file mode 100644
index 0000000..23a7074
--- /dev/null
+++ b/trollius/log.py
@@ -0,0 +1,7 @@
+"""Logging configuration."""
+
+import logging
+
+
+# Name the logger after the package.
+logger = logging.getLogger(__package__)
diff --git a/trollius/proactor_events.py b/trollius/proactor_events.py
new file mode 100644
index 0000000..9c2b8f1
--- /dev/null
+++ b/trollius/proactor_events.py
@@ -0,0 +1,547 @@
+"""Event loop using a proactor and related classes.
+
+A proactor is a "notify-on-completion" multiplexer. Currently a
+proactor is only implemented on Windows with IOCP.
+"""
+
+__all__ = ['BaseProactorEventLoop']
+
+import socket
+import sys
+import warnings
+
+from . import base_events
+from . import constants
+from . import futures
+from . import sslproto
+from . import transports
+from .log import logger
+
+
+class _ProactorBasePipeTransport(transports._FlowControlMixin,
+ transports.BaseTransport):
+ """Base class for pipe and socket transports."""
+
+ def __init__(self, loop, sock, protocol, waiter=None,
+ extra=None, server=None):
+ super().__init__(extra, loop)
+ self._set_extra(sock)
+ self._sock = sock
+ self._protocol = protocol
+ self._server = server
+ self._buffer = None # None or bytearray.
+ self._read_fut = None
+ self._write_fut = None
+ self._pending_write = 0
+ self._conn_lost = 0
+ self._closing = False # Set when close() called.
+ self._eof_written = False
+ if self._server is not None:
+ self._server._attach()
+ self._loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ # only wake up the waiter when connection_made() has been called
+ self._loop.call_soon(waiter._set_result_unless_cancelled, None)
+
+ def __repr__(self):
+ info = [self.__class__.__name__]
+ if self._sock is None:
+ info.append('closed')
+ elif self._closing:
+ info.append('closing')
+ if self._sock is not None:
+ info.append('fd=%s' % self._sock.fileno())
+ if self._read_fut is not None:
+ info.append('read=%s' % self._read_fut)
+ if self._write_fut is not None:
+ info.append("write=%r" % self._write_fut)
+ if self._buffer:
+ bufsize = len(self._buffer)
+ info.append('write_bufsize=%s' % bufsize)
+ if self._eof_written:
+ info.append('EOF written')
+ return '<%s>' % ' '.join(info)
+
+ def _set_extra(self, sock):
+ self._extra['pipe'] = sock
+
+ def close(self):
+ if self._closing:
+ return
+ self._closing = True
+ self._conn_lost += 1
+ if not self._buffer and self._write_fut is None:
+ self._loop.call_soon(self._call_connection_lost, None)
+ if self._read_fut is not None:
+ self._read_fut.cancel()
+ self._read_fut = None
+
+ # On Python 3.3 and older, objects with a destructor that are part
+ # of a reference cycle are never destroyed. This is no longer the
+ # case on Python 3.4, thanks to PEP 442.
+ if sys.version_info >= (3, 4):
+ def __del__(self):
+ if self._sock is not None:
+ warnings.warn("unclosed transport %r" % self, ResourceWarning)
+ self.close()
+
+ def _fatal_error(self, exc, message='Fatal error on pipe transport'):
+ if isinstance(exc, (BrokenPipeError, ConnectionResetError)):
+ if self._loop.get_debug():
+ logger.debug("%r: %s", self, message, exc_info=True)
+ else:
+ self._loop.call_exception_handler({
+ 'message': message,
+ 'exception': exc,
+ 'transport': self,
+ 'protocol': self._protocol,
+ })
+ self._force_close(exc)
+
+ def _force_close(self, exc):
+ if self._closing:
+ return
+ self._closing = True
+ self._conn_lost += 1
+ if self._write_fut:
+ self._write_fut.cancel()
+ self._write_fut = None
+ if self._read_fut:
+ self._read_fut.cancel()
+ self._read_fut = None
+ self._pending_write = 0
+ self._buffer = None
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ # XXX If there is a pending overlapped read on the other
+ # end then it may fail with ERROR_NETNAME_DELETED if we
+ # just close our end. First calling shutdown() seems to
+ # cure it, but maybe using DisconnectEx() would be better.
+ if hasattr(self._sock, 'shutdown'):
+ self._sock.shutdown(socket.SHUT_RDWR)
+ self._sock.close()
+ self._sock = None
+ server = self._server
+ if server is not None:
+ server._detach()
+ self._server = None
+
+ def get_write_buffer_size(self):
+ size = self._pending_write
+ if self._buffer is not None:
+ size += len(self._buffer)
+ return size
+
+
+class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
+ transports.ReadTransport):
+ """Transport for read pipes."""
+
+ def __init__(self, loop, sock, protocol, waiter=None,
+ extra=None, server=None):
+ super().__init__(loop, sock, protocol, waiter, extra, server)
+ self._paused = False
+ self._loop.call_soon(self._loop_reading)
+
+ def pause_reading(self):
+ if self._closing:
+ raise RuntimeError('Cannot pause_reading() when closing')
+ if self._paused:
+ raise RuntimeError('Already paused')
+ self._paused = True
+ if self._loop.get_debug():
+ logger.debug("%r pauses reading", self)
+
+ def resume_reading(self):
+ if not self._paused:
+ raise RuntimeError('Not paused')
+ self._paused = False
+ if self._closing:
+ return
+ self._loop.call_soon(self._loop_reading, self._read_fut)
+ if self._loop.get_debug():
+ logger.debug("%r resumes reading", self)
+
+ def _loop_reading(self, fut=None):
+ if self._paused:
+ return
+ data = None
+
+ try:
+ if fut is not None:
+ assert self._read_fut is fut or (self._read_fut is None and
+ self._closing)
+ self._read_fut = None
+ data = fut.result() # deliver data later in "finally" clause
+
+ if self._closing:
+ # since close() has been called we ignore any read data
+ data = None
+ return
+
+ if data == b'':
+ # we got end-of-file so no need to reschedule a new read
+ return
+
+ # reschedule a new read
+ self._read_fut = self._loop._proactor.recv(self._sock, 4096)
+ except ConnectionAbortedError as exc:
+ if not self._closing:
+ self._fatal_error(exc, 'Fatal read error on pipe transport')
+ elif self._loop.get_debug():
+ logger.debug("Read error on pipe transport while closing",
+ exc_info=True)
+ except ConnectionResetError as exc:
+ self._force_close(exc)
+ except OSError as exc:
+ self._fatal_error(exc, 'Fatal read error on pipe transport')
+ except futures.CancelledError:
+ if not self._closing:
+ raise
+ else:
+ self._read_fut.add_done_callback(self._loop_reading)
+ finally:
+ if data:
+ self._protocol.data_received(data)
+ elif data is not None:
+ if self._loop.get_debug():
+ logger.debug("%r received EOF", self)
+ keep_open = self._protocol.eof_received()
+ if not keep_open:
+ self.close()
+
+
+class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
+ transports.WriteTransport):
+ """Transport for write pipes."""
+
+ def write(self, data):
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ raise TypeError('data argument must be byte-ish (%r)'
+ % type(data))
+ if self._eof_written:
+ raise RuntimeError('write_eof() already called')
+
+ if not data:
+ return
+
+ if self._conn_lost:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+
+ # Observable states:
+ # 1. IDLE: _write_fut and _buffer both None
+ # 2. WRITING: _write_fut set; _buffer None
+ # 3. BACKED UP: _write_fut set; _buffer a bytearray
+ # We always copy the data, so the caller can't modify it
+ # while we're still waiting for the I/O to happen.
+ if self._write_fut is None: # IDLE -> WRITING
+ assert self._buffer is None
+ # Pass a copy, except if it's already immutable.
+ self._loop_writing(data=bytes(data))
+ elif not self._buffer: # WRITING -> BACKED UP
+ # Make a mutable copy which we can extend.
+ self._buffer = bytearray(data)
+ self._maybe_pause_protocol()
+ else: # BACKED UP
+ # Append to buffer (also copies).
+ self._buffer.extend(data)
+ self._maybe_pause_protocol()
+
+ def _loop_writing(self, f=None, data=None):
+ try:
+ assert f is self._write_fut
+ self._write_fut = None
+ self._pending_write = 0
+ if f:
+ f.result()
+ if data is None:
+ data = self._buffer
+ self._buffer = None
+ if not data:
+ if self._closing:
+ self._loop.call_soon(self._call_connection_lost, None)
+ if self._eof_written:
+ self._sock.shutdown(socket.SHUT_WR)
+ # Now that we've reduced the buffer size, tell the
+ # protocol to resume writing if it was paused. Note that
+ # we do this last since the callback is called immediately
+ # and it may add more data to the buffer (even causing the
+ # protocol to be paused again).
+ self._maybe_resume_protocol()
+ else:
+ self._write_fut = self._loop._proactor.send(self._sock, data)
+ if not self._write_fut.done():
+ assert self._pending_write == 0
+ self._pending_write = len(data)
+ self._write_fut.add_done_callback(self._loop_writing)
+ self._maybe_pause_protocol()
+ else:
+ self._write_fut.add_done_callback(self._loop_writing)
+ except ConnectionResetError as exc:
+ self._force_close(exc)
+ except OSError as exc:
+ self._fatal_error(exc, 'Fatal write error on pipe transport')
+
+ def can_write_eof(self):
+ return True
+
+ def write_eof(self):
+ self.close()
+
+ def abort(self):
+ self._force_close(None)
+
+
+class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport):
+ def __init__(self, *args, **kw):
+ super().__init__(*args, **kw)
+ self._read_fut = self._loop._proactor.recv(self._sock, 16)
+ self._read_fut.add_done_callback(self._pipe_closed)
+
+ def _pipe_closed(self, fut):
+ if fut.cancelled():
+ # the transport has been closed
+ return
+ assert fut.result() == b''
+ if self._closing:
+ assert self._read_fut is None
+ return
+ assert fut is self._read_fut, (fut, self._read_fut)
+ self._read_fut = None
+ if self._write_fut is not None:
+ self._force_close(BrokenPipeError())
+ else:
+ self.close()
+
+
+class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport,
+ _ProactorBaseWritePipeTransport,
+ transports.Transport):
+ """Transport for duplex pipes."""
+
+ def can_write_eof(self):
+ return False
+
+ def write_eof(self):
+ raise NotImplementedError
+
+
+class _ProactorSocketTransport(_ProactorReadPipeTransport,
+ _ProactorBaseWritePipeTransport,
+ transports.Transport):
+ """Transport for connected sockets."""
+
+ def _set_extra(self, sock):
+ self._extra['socket'] = sock
+ try:
+ self._extra['sockname'] = sock.getsockname()
+ except (socket.error, AttributeError):
+ if self._loop.get_debug():
+ logger.warning("getsockname() failed on %r",
+ sock, exc_info=True)
+ if 'peername' not in self._extra:
+ try:
+ self._extra['peername'] = sock.getpeername()
+ except (socket.error, AttributeError):
+ if self._loop.get_debug():
+ logger.warning("getpeername() failed on %r",
+ sock, exc_info=True)
+
+ def can_write_eof(self):
+ return True
+
+ def write_eof(self):
+ if self._closing or self._eof_written:
+ return
+ self._eof_written = True
+ if self._write_fut is None:
+ self._sock.shutdown(socket.SHUT_WR)
+
+
+class BaseProactorEventLoop(base_events.BaseEventLoop):
+
+ def __init__(self, proactor):
+ super().__init__()
+ logger.debug('Using proactor: %s', proactor.__class__.__name__)
+ self._proactor = proactor
+ self._selector = proactor # convenient alias
+ self._self_reading_future = None
+ self._accept_futures = {} # socket file descriptor => Future
+ proactor.set_loop(self)
+ self._make_self_pipe()
+
+ def _make_socket_transport(self, sock, protocol, waiter=None,
+ extra=None, server=None):
+ return _ProactorSocketTransport(self, sock, protocol, waiter,
+ extra, server)
+
+ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
+ *, server_side=False, server_hostname=None,
+ extra=None, server=None):
+ if not sslproto._is_sslproto_available():
+ raise NotImplementedError("Proactor event loop requires Python 3.5"
+ " or newer (ssl.MemoryBIO) to support "
+ "SSL")
+
+ ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
+ server_side, server_hostname)
+ _ProactorSocketTransport(self, rawsock, ssl_protocol,
+ extra=extra, server=server)
+ return ssl_protocol._app_transport
+
+ def _make_duplex_pipe_transport(self, sock, protocol, waiter=None,
+ extra=None):
+ return _ProactorDuplexPipeTransport(self,
+ sock, protocol, waiter, extra)
+
+ def _make_read_pipe_transport(self, sock, protocol, waiter=None,
+ extra=None):
+ return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra)
+
+ def _make_write_pipe_transport(self, sock, protocol, waiter=None,
+ extra=None):
+ # We want connection_lost() to be called when other end closes
+ return _ProactorWritePipeTransport(self,
+ sock, protocol, waiter, extra)
+
+ def close(self):
+ if self.is_running():
+ raise RuntimeError("Cannot close a running event loop")
+ if self.is_closed():
+ return
+
+ # Call these methods before closing the event loop (before calling
+ # BaseEventLoop.close), because they can schedule callbacks with
+ # call_soon(), which is forbidden when the event loop is closed.
+ self._stop_accept_futures()
+ self._close_self_pipe()
+ self._proactor.close()
+ self._proactor = None
+ self._selector = None
+
+ # Close the event loop
+ super().close()
+
+ def sock_recv(self, sock, n):
+ return self._proactor.recv(sock, n)
+
+ def sock_sendall(self, sock, data):
+ return self._proactor.send(sock, data)
+
+ def sock_connect(self, sock, address):
+ try:
+ if self._debug:
+ base_events._check_resolved_address(sock, address)
+ except ValueError as err:
+ fut = futures.Future(loop=self)
+ fut.set_exception(err)
+ return fut
+ else:
+ return self._proactor.connect(sock, address)
+
+ def sock_accept(self, sock):
+ return self._proactor.accept(sock)
+
+ def _socketpair(self):
+ raise NotImplementedError
+
+ def _close_self_pipe(self):
+ if self._self_reading_future is not None:
+ self._self_reading_future.cancel()
+ self._self_reading_future = None
+ self._ssock.close()
+ self._ssock = None
+ self._csock.close()
+ self._csock = None
+ self._internal_fds -= 1
+
+ def _make_self_pipe(self):
+ # A self-socket, really. :-)
+ self._ssock, self._csock = self._socketpair()
+ self._ssock.setblocking(False)
+ self._csock.setblocking(False)
+ self._internal_fds += 1
+ self.call_soon(self._loop_self_reading)
+
+ def _loop_self_reading(self, f=None):
+ try:
+ if f is not None:
+ f.result() # may raise
+ f = self._proactor.recv(self._ssock, 4096)
+ except futures.CancelledError:
+ # _close_self_pipe() has been called, stop waiting for data
+ return
+ except Exception as exc:
+ self.call_exception_handler({
+ 'message': 'Error on reading from the event loop self pipe',
+ 'exception': exc,
+ 'loop': self,
+ })
+ else:
+ self._self_reading_future = f
+ f.add_done_callback(self._loop_self_reading)
+
+ def _write_to_self(self):
+ self._csock.send(b'\0')
+
+ def _start_serving(self, protocol_factory, sock,
+ sslcontext=None, server=None):
+
+ def loop(f=None):
+ try:
+ if f is not None:
+ conn, addr = f.result()
+ if self._debug:
+ logger.debug("%r got a new connection from %r: %r",
+ server, addr, conn)
+ protocol = protocol_factory()
+ if sslcontext is not None:
+ self._make_ssl_transport(
+ conn, protocol, sslcontext, server_side=True,
+ extra={'peername': addr}, server=server)
+ else:
+ self._make_socket_transport(
+ conn, protocol,
+ extra={'peername': addr}, server=server)
+ if self.is_closed():
+ return
+ f = self._proactor.accept(sock)
+ except OSError as exc:
+ if sock.fileno() != -1:
+ self.call_exception_handler({
+ 'message': 'Accept failed on a socket',
+ 'exception': exc,
+ 'socket': sock,
+ })
+ sock.close()
+ elif self._debug:
+ logger.debug("Accept failed on socket %r",
+ sock, exc_info=True)
+ except futures.CancelledError:
+ sock.close()
+ else:
+ self._accept_futures[sock.fileno()] = f
+ f.add_done_callback(loop)
+
+ self.call_soon(loop)
+
+ def _process_events(self, event_list):
+ # Events are processed in the IocpProactor._poll() method
+ pass
+
+ def _stop_accept_futures(self):
+ for future in self._accept_futures.values():
+ future.cancel()
+ self._accept_futures.clear()
+
+ def _stop_serving(self, sock):
+ self._stop_accept_futures()
+ self._proactor._stop_serving(sock)
+ sock.close()
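+
+# Illustrative sketch (not part of the original module), assuming the
+# Windows-only ProactorEventLoop subclass defined in windows_events;
+# main() is a placeholder coroutine:
+#
+#     loop = ProactorEventLoop()
+#     set_event_loop(loop)
+#     try:
+#         loop.run_until_complete(main())
+#     finally:
+#         loop.close()   # cancels accept futures, closes self-pipe & proactor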
diff --git a/trollius/protocols.py b/trollius/protocols.py
new file mode 100644
index 0000000..80fcac9
--- /dev/null
+++ b/trollius/protocols.py
@@ -0,0 +1,134 @@
+"""Abstract Protocol class."""
+
+__all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol',
+ 'SubprocessProtocol']
+
+
+class BaseProtocol:
+ """Common base class for protocol interfaces.
+
+ Usually a user implements protocols derived from BaseProtocol,
+ such as Protocol or SubprocessProtocol.
+
+ The only case where BaseProtocol should be implemented directly is
+ for a write-only transport, such as a write pipe.
+ """
+
+ def connection_made(self, transport):
+ """Called when a connection is made.
+
+ The argument is the transport representing the pipe connection.
+ To receive data, wait for data_received() calls.
+ When the connection is closed, connection_lost() is called.
+ """
+
+ def connection_lost(self, exc):
+ """Called when the connection is lost or closed.
+
+ The argument is an exception object or None (the latter
+ meaning a regular EOF was received or the connection was
+ aborted or closed).
+ """
+
+ def pause_writing(self):
+ """Called when the transport's buffer goes over the high-water mark.
+
+ Pause and resume calls are paired -- pause_writing() is called
+ once when the buffer goes strictly over the high-water mark
+ (even if subsequent writes increase the buffer size even
+ more), and eventually resume_writing() is called once when the
+ buffer size reaches the low-water mark.
+
+ Note that if the buffer size equals the high-water mark,
+ pause_writing() is not called -- it must go strictly over.
+ Conversely, resume_writing() is called when the buffer size is
+ equal to or lower than the low-water mark. These end conditions
+ are important to ensure that things go as expected when either
+ mark is zero.
+
+ NOTE: This is the only Protocol callback that is not called
+ through EventLoop.call_soon() -- if it were, it would have no
+ effect when it's most needed (when the app keeps writing
+ without yielding until pause_writing() is called).
+ """
+
+ def resume_writing(self):
+ """Called when the transport's buffer drains below the low-water mark.
+
+ See pause_writing() for details.
+ """
+
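+# Illustrative sketch (not part of the original module): honouring the
+# pause_writing()/resume_writing() flow-control callbacks in a protocol;
+# the producer logic gated by `can_write` is hypothetical.
+#
+#     class FlowControlledProtocol(BaseProtocol):
+#         def connection_made(self, transport):
+#             self.transport = transport
+#             self.can_write = True
+#
+#         def pause_writing(self):
+#             self.can_write = False   # buffer went over the high-water mark
+#
+#         def resume_writing(self):
+#             self.can_write = True    # buffer drained to the low-water mark
+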
+
+class Protocol(BaseProtocol):
+ """Interface for stream protocol.
+
+ The user should implement this interface. They can inherit from
+ this class but don't need to. The implementations here do
+ nothing (they don't raise exceptions).
+
+ When the user wants to request a transport, they pass a protocol
+ factory to a utility function (e.g., EventLoop.create_connection()).
+
+ When the connection is made successfully, connection_made() is
+ called with a suitable transport object. Then data_received()
+ will be called 0 or more times with data (bytes) received from the
+ transport; finally, connection_lost() will be called exactly once
+ with either an exception object or None as an argument.
+
+ State machine of calls:
+
+ start -> CM [-> DR*] [-> ER?] -> CL -> end
+
+ * CM: connection_made()
+ * DR: data_received()
+ * ER: eof_received()
+ * CL: connection_lost()
+ """
+
+ def data_received(self, data):
+ """Called when some data is received.
+
+ The argument is a bytes object.
+ """
+
+ def eof_received(self):
+ """Called when the other end calls write_eof() or equivalent.
+
+ If this returns a false value (including None), the transport
+ will close itself. If it returns a true value, closing the
+ transport is up to the protocol.
+ """
+
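+# Illustrative sketch (not part of the original module): a minimal echo
+# protocol following the CM -> DR* -> CL state machine above; the host
+# and port are placeholders.
+#
+#     class EchoProtocol(Protocol):
+#         def connection_made(self, transport):
+#             self.transport = transport        # CM
+#
+#         def data_received(self, data):
+#             self.transport.write(data)        # DR
+#
+#         def connection_lost(self, exc):
+#             self.transport = None             # CL
+#
+#     # loop.create_connection(EchoProtocol, '127.0.0.1', 8888)
+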
+
+class DatagramProtocol(BaseProtocol):
+ """Interface for datagram protocol."""
+
+ def datagram_received(self, data, addr):
+ """Called when some datagram is received."""
+
+ def error_received(self, exc):
+ """Called when a send or receive operation raises an OSError.
+
+ (Other than BlockingIOError or InterruptedError.)
+ """
+
+
+class SubprocessProtocol(BaseProtocol):
+ """Interface for protocol for subprocess calls."""
+
+ def pipe_data_received(self, fd, data):
+ """Called when the subprocess writes data into stdout/stderr pipe.
+
+ fd is the int file descriptor.
+ data is a bytes object.
+ """
+
+ def pipe_connection_lost(self, fd, exc):
+ """Called when a file descriptor associated with the child process is
+ closed.
+
+ fd is the int file descriptor that was closed.
+ """
+
+ def process_exited(self):
+ """Called when subprocess has exited."""
diff --git a/trollius/queues.py b/trollius/queues.py
new file mode 100644
index 0000000..ed11662
--- /dev/null
+++ b/trollius/queues.py
@@ -0,0 +1,293 @@
+"""Queues"""
+
+__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty',
+ 'JoinableQueue']
+
+import collections
+import heapq
+
+from . import events
+from . import futures
+from . import locks
+from .tasks import coroutine
+
+
+class QueueEmpty(Exception):
+ """Exception raised when Queue.get_nowait() is called on a Queue object
+ which is empty.
+ """
+ pass
+
+
+class QueueFull(Exception):
+ """Exception raised when the Queue.put_nowait() method is called on a Queue
+ object which is full.
+ """
+ pass
+
+
+class Queue:
+ """A queue, useful for coordinating producer and consumer coroutines.
+
+ If maxsize is less than or equal to zero, the queue size is infinite. If it
+ is an integer greater than 0, then "yield from put()" will block when the
+ queue reaches maxsize, until an item is removed by get().
+
+ Unlike the standard library Queue, you can reliably know this Queue's size
+ with qsize(), since your single-threaded asyncio application won't be
+ interrupted between calling qsize() and doing an operation on the Queue.
+ """
+
+ def __init__(self, maxsize=0, *, loop=None):
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._maxsize = maxsize
+
+ # Futures.
+ self._getters = collections.deque()
+ # Pairs of (item, Future).
+ self._putters = collections.deque()
+ self._unfinished_tasks = 0
+ self._finished = locks.Event(loop=self._loop)
+ self._finished.set()
+ self._init(maxsize)
+
+ # These three are overridable in subclasses.
+
+ def _init(self, maxsize):
+ self._queue = collections.deque()
+
+ def _get(self):
+ return self._queue.popleft()
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ # End of the overridable methods.
+
+ def __put_internal(self, item):
+ self._put(item)
+ self._unfinished_tasks += 1
+ self._finished.clear()
+
+ def __repr__(self):
+ return '<{} at {:#x} {}>'.format(
+ type(self).__name__, id(self), self._format())
+
+ def __str__(self):
+ return '<{} {}>'.format(type(self).__name__, self._format())
+
+ def _format(self):
+ result = 'maxsize={!r}'.format(self._maxsize)
+ if getattr(self, '_queue', None):
+ result += ' _queue={!r}'.format(list(self._queue))
+ if self._getters:
+ result += ' _getters[{}]'.format(len(self._getters))
+ if self._putters:
+ result += ' _putters[{}]'.format(len(self._putters))
+ if self._unfinished_tasks:
+ result += ' tasks={}'.format(self._unfinished_tasks)
+ return result
+
+ def _consume_done_getters(self):
+ # Delete waiters at the head of the get() queue who've timed out.
+ while self._getters and self._getters[0].done():
+ self._getters.popleft()
+
+ def _consume_done_putters(self):
+ # Delete waiters at the head of the put() queue who've timed out.
+ while self._putters and self._putters[0][1].done():
+ self._putters.popleft()
+
+ def qsize(self):
+ """Number of items in the queue."""
+ return len(self._queue)
+
+ @property
+ def maxsize(self):
+ """Number of items allowed in the queue."""
+ return self._maxsize
+
+ def empty(self):
+ """Return True if the queue is empty, False otherwise."""
+ return not self._queue
+
+ def full(self):
+ """Return True if there are maxsize items in the queue.
+
+ Note: if the Queue was initialized with maxsize=0 (the default),
+ then full() is never True.
+ """
+ if self._maxsize <= 0:
+ return False
+ else:
+ return self.qsize() >= self._maxsize
+
+ @coroutine
+ def put(self, item):
+ """Put an item into the queue.
+
+ If the queue is full, wait until a free slot is available
+ before adding the item.
+
+ This method is a coroutine.
+ """
+ self._consume_done_getters()
+ if self._getters:
+ assert not self._queue, (
+ 'queue non-empty, why are getters waiting?')
+
+ getter = self._getters.popleft()
+ self.__put_internal(item)
+
+ # getter cannot be cancelled, we just removed done getters
+ getter.set_result(self._get())
+
+ elif self._maxsize > 0 and self._maxsize <= self.qsize():
+ waiter = futures.Future(loop=self._loop)
+
+ self._putters.append((item, waiter))
+ yield from waiter
+
+ else:
+ self.__put_internal(item)
+
+ def put_nowait(self, item):
+ """Put an item into the queue without blocking.
+
+ If no free slot is immediately available, raise QueueFull.
+ """
+ self._consume_done_getters()
+ if self._getters:
+ assert not self._queue, (
+ 'queue non-empty, why are getters waiting?')
+
+ getter = self._getters.popleft()
+ self.__put_internal(item)
+
+ # getter cannot be cancelled, we just removed done getters
+ getter.set_result(self._get())
+
+ elif self._maxsize > 0 and self._maxsize <= self.qsize():
+ raise QueueFull
+ else:
+ self.__put_internal(item)
+
+ @coroutine
+ def get(self):
+ """Remove and return an item from the queue.
+
+ If the queue is empty, wait until an item is available.
+
+ This method is a coroutine.
+ """
+ self._consume_done_putters()
+ if self._putters:
+ assert self.full(), 'queue not full, why are putters waiting?'
+ item, putter = self._putters.popleft()
+ self.__put_internal(item)
+
+ # When a getter runs and frees up a slot so this putter can
+ # run, we need to defer the put for a tick to ensure that
+ # getters and putters alternate perfectly. See
+ # ChannelTest.test_wait.
+ self._loop.call_soon(putter._set_result_unless_cancelled, None)
+
+ return self._get()
+
+ elif self.qsize():
+ return self._get()
+ else:
+ waiter = futures.Future(loop=self._loop)
+
+ self._getters.append(waiter)
+ return (yield from waiter)
+
+ def get_nowait(self):
+ """Remove and return an item from the queue.
+
+ Return an item if one is immediately available, else raise QueueEmpty.
+ """
+ self._consume_done_putters()
+ if self._putters:
+ assert self.full(), 'queue not full, why are putters waiting?'
+ item, putter = self._putters.popleft()
+ self.__put_internal(item)
+ # Wake putter on next tick.
+
+ # getter cannot be cancelled, we just removed done putters
+ putter.set_result(None)
+
+ return self._get()
+
+ elif self.qsize():
+ return self._get()
+ else:
+ raise QueueEmpty
+
+ def task_done(self):
+ """Indicate that a formerly enqueued task is complete.
+
+ Used by queue consumers. For each get() used to fetch a task,
+ a subsequent call to task_done() tells the queue that the processing
+ on the task is complete.
+
+ If a join() is currently blocking, it will resume when all items have
+ been processed (meaning that a task_done() call was received for every
+ item that had been put() into the queue).
+
+ Raises ValueError if called more times than there were items placed in
+ the queue.
+ """
+ if self._unfinished_tasks <= 0:
+ raise ValueError('task_done() called too many times')
+ self._unfinished_tasks -= 1
+ if self._unfinished_tasks == 0:
+ self._finished.set()
+
+ @coroutine
+ def join(self):
+ """Block until all items in the queue have been gotten and processed.
+
+ The count of unfinished tasks goes up whenever an item is added to the
+ queue. The count goes down whenever a consumer calls task_done() to
+ indicate that the item was retrieved and all work on it is complete.
+ When the count of unfinished tasks drops to zero, join() unblocks.
+ """
+ if self._unfinished_tasks > 0:
+ yield from self._finished.wait()
+
+
+class PriorityQueue(Queue):
+ """A subclass of Queue; retrieves entries in priority order (lowest first).
+
+ Entries are typically tuples of the form: (priority number, data).
+ """
+
+ def _init(self, maxsize):
+ self._queue = []
+
+ def _put(self, item, heappush=heapq.heappush):
+ heappush(self._queue, item)
+
+ def _get(self, heappop=heapq.heappop):
+ return heappop(self._queue)
+
+
+class LifoQueue(Queue):
+ """A subclass of Queue that retrieves most recently added entries first."""
+
+ def _init(self, maxsize):
+ self._queue = []
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ def _get(self):
+ return self._queue.pop()
+
+
+JoinableQueue = Queue
+"""Deprecated alias for Queue."""
diff --git a/trollius/selector_events.py b/trollius/selector_events.py
new file mode 100644
index 0000000..7c5b9b5
--- /dev/null
+++ b/trollius/selector_events.py
@@ -0,0 +1,1068 @@
+"""Event loop using a selector and related classes.
+
+A selector is a "notify-when-ready" multiplexer. For a subclass which
+also includes support for signal handling, see the unix_events sub-module.
+"""
+
+__all__ = ['BaseSelectorEventLoop']
+
+import collections
+import errno
+import functools
+import socket
+import sys
+import warnings
+try:
+ import ssl
+except ImportError: # pragma: no cover
+ ssl = None
+
+from . import base_events
+from . import constants
+from . import events
+from . import futures
+from . import selectors
+from . import transports
+from . import sslproto
+from .coroutines import coroutine
+from .log import logger
+
+
+def _test_selector_event(selector, fd, event):
+ # Test if the selector is monitoring 'event' events
+ # for the file descriptor 'fd'.
+ try:
+ key = selector.get_key(fd)
+ except KeyError:
+ return False
+ else:
+ return bool(key.events & event)
+
+
+class BaseSelectorEventLoop(base_events.BaseEventLoop):
+ """Selector event loop.
+
+ See events.EventLoop for API specification.
+ """
+
+ def __init__(self, selector=None):
+ super().__init__()
+
+ if selector is None:
+ selector = selectors.DefaultSelector()
+ logger.debug('Using selector: %s', selector.__class__.__name__)
+ self._selector = selector
+ self._make_self_pipe()
+
+ def _make_socket_transport(self, sock, protocol, waiter=None, *,
+ extra=None, server=None):
+ return _SelectorSocketTransport(self, sock, protocol, waiter,
+ extra, server)
+
+ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
+ *, server_side=False, server_hostname=None,
+ extra=None, server=None):
+ if not sslproto._is_sslproto_available():
+ return self._make_legacy_ssl_transport(
+ rawsock, protocol, sslcontext, waiter,
+ server_side=server_side, server_hostname=server_hostname,
+ extra=extra, server=server)
+
+ ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
+ server_side, server_hostname)
+ _SelectorSocketTransport(self, rawsock, ssl_protocol,
+ extra=extra, server=server)
+ return ssl_protocol._app_transport
+
+ def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext,
+ waiter, *,
+ server_side=False, server_hostname=None,
+ extra=None, server=None):
+ # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used
+ # on Python 3.4 and older, when ssl.MemoryBIO is not available.
+ return _SelectorSslTransport(
+ self, rawsock, protocol, sslcontext, waiter,
+ server_side, server_hostname, extra, server)
+
+ def _make_datagram_transport(self, sock, protocol,
+ address=None, waiter=None, extra=None):
+ return _SelectorDatagramTransport(self, sock, protocol,
+ address, waiter, extra)
+
+ def close(self):
+ if self.is_running():
+ raise RuntimeError("Cannot close a running event loop")
+ if self.is_closed():
+ return
+ self._close_self_pipe()
+ super().close()
+ if self._selector is not None:
+ self._selector.close()
+ self._selector = None
+
+ def _socketpair(self):
+ raise NotImplementedError
+
+ def _close_self_pipe(self):
+ self.remove_reader(self._ssock.fileno())
+ self._ssock.close()
+ self._ssock = None
+ self._csock.close()
+ self._csock = None
+ self._internal_fds -= 1
+
+ def _make_self_pipe(self):
+ # A self-socket, really. :-)
+ self._ssock, self._csock = self._socketpair()
+ self._ssock.setblocking(False)
+ self._csock.setblocking(False)
+ self._internal_fds += 1
+ self.add_reader(self._ssock.fileno(), self._read_from_self)
+
+ def _process_self_data(self, data):
+ pass
+
+ def _read_from_self(self):
+ while True:
+ try:
+ data = self._ssock.recv(4096)
+ if not data:
+ break
+ self._process_self_data(data)
+ except InterruptedError:
+ continue
+ except BlockingIOError:
+ break
+
+ def _write_to_self(self):
+ # This may be called from a different thread, possibly after
+ # _close_self_pipe() has been called or even while it is
+ # running. Guard for self._csock being None or closed. When
+ # a socket is closed, send() raises OSError (with errno set to
+ # EBADF, but let's not rely on the exact error code).
+ csock = self._csock
+ if csock is not None:
+ try:
+ csock.send(b'\0')
+ except OSError:
+ if self._debug:
+ logger.debug("Fail to write a null byte into the "
+ "self-pipe socket",
+ exc_info=True)
+
+ def _start_serving(self, protocol_factory, sock,
+ sslcontext=None, server=None):
+ self.add_reader(sock.fileno(), self._accept_connection,
+ protocol_factory, sock, sslcontext, server)
+
+ def _accept_connection(self, protocol_factory, sock,
+ sslcontext=None, server=None):
+ try:
+ conn, addr = sock.accept()
+ if self._debug:
+ logger.debug("%r got a new connection from %r: %r",
+ server, addr, conn)
+ conn.setblocking(False)
+ except (BlockingIOError, InterruptedError, ConnectionAbortedError):
+ pass # False alarm.
+ except OSError as exc:
+ # There's nowhere to send the error, so just log it.
+ if exc.errno in (errno.EMFILE, errno.ENFILE,
+ errno.ENOBUFS, errno.ENOMEM):
+ # Some platforms (e.g. Linux) keep reporting the FD as
+ # ready, so we remove the read handler temporarily.
+ # We'll try again in a while.
+ self.call_exception_handler({
+ 'message': 'socket.accept() out of system resource',
+ 'exception': exc,
+ 'socket': sock,
+ })
+ self.remove_reader(sock.fileno())
+ self.call_later(constants.ACCEPT_RETRY_DELAY,
+ self._start_serving,
+ protocol_factory, sock, sslcontext, server)
+ else:
+ raise # The event loop will catch, log and ignore it.
+ else:
+ extra = {'peername': addr}
+ accept = self._accept_connection2(protocol_factory, conn, extra,
+ sslcontext, server)
+ self.create_task(accept)
+
+ @coroutine
+ def _accept_connection2(self, protocol_factory, conn, extra,
+ sslcontext=None, server=None):
+ protocol = None
+ transport = None
+ try:
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ if sslcontext:
+ transport = self._make_ssl_transport(
+ conn, protocol, sslcontext, waiter=waiter,
+ server_side=True, extra=extra, server=server)
+ else:
+ transport = self._make_socket_transport(
+ conn, protocol, waiter=waiter, extra=extra,
+ server=server)
+
+ try:
+ yield from waiter
+ except:
+ transport.close()
+ raise
+
+ # It's now up to the protocol to handle the connection.
+ except Exception as exc:
+ if self._debug:
+ context = {
+ 'message': ('Error on transport creation '
+ 'for incoming connection'),
+ 'exception': exc,
+ }
+ if protocol is not None:
+ context['protocol'] = protocol
+ if transport is not None:
+ context['transport'] = transport
+ self.call_exception_handler(context)
+
+ def add_reader(self, fd, callback, *args):
+ """Add a reader callback."""
+ self._check_closed()
+ handle = events.Handle(callback, args, self)
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ self._selector.register(fd, selectors.EVENT_READ,
+ (handle, None))
+ else:
+ mask, (reader, writer) = key.events, key.data
+ self._selector.modify(fd, mask | selectors.EVENT_READ,
+ (handle, writer))
+ if reader is not None:
+ reader.cancel()
+
+ def remove_reader(self, fd):
+ """Remove a reader callback."""
+ if self.is_closed():
+ return False
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ return False
+ else:
+ mask, (reader, writer) = key.events, key.data
+ mask &= ~selectors.EVENT_READ
+ if not mask:
+ self._selector.unregister(fd)
+ else:
+ self._selector.modify(fd, mask, (None, writer))
+
+ if reader is not None:
+ reader.cancel()
+ return True
+ else:
+ return False
+
+ def add_writer(self, fd, callback, *args):
+ """Add a writer callback.."""
+ self._check_closed()
+ handle = events.Handle(callback, args, self)
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ self._selector.register(fd, selectors.EVENT_WRITE,
+ (None, handle))
+ else:
+ mask, (reader, writer) = key.events, key.data
+ self._selector.modify(fd, mask | selectors.EVENT_WRITE,
+ (reader, handle))
+ if writer is not None:
+ writer.cancel()
+
+ def remove_writer(self, fd):
+ """Remove a writer callback."""
+ if self.is_closed():
+ return False
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ return False
+ else:
+ mask, (reader, writer) = key.events, key.data
+ # Remove both writer and connector.
+ mask &= ~selectors.EVENT_WRITE
+ if not mask:
+ self._selector.unregister(fd)
+ else:
+ self._selector.modify(fd, mask, (reader, None))
+
+ if writer is not None:
+ writer.cancel()
+ return True
+ else:
+ return False
+
+ def sock_recv(self, sock, n):
+ """Receive data from the socket.
+
+ The return value is a bytes object representing the data received.
+ The maximum amount of data to be received at once is specified by n.
+
+ This method is a coroutine.
+ """
+ if self._debug and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
+ fut = futures.Future(loop=self)
+ self._sock_recv(fut, False, sock, n)
+ return fut
+
+ def _sock_recv(self, fut, registered, sock, n):
+ # _sock_recv() can add itself as an I/O callback if the operation can't
+ # be done immediately. Don't use it directly, call sock_recv().
+ fd = sock.fileno()
+ if registered:
+ # Remove the callback early. It should be rare that the
+ # selector says the fd is ready but the call still returns
+ # EAGAIN, and I am willing to take a hit in that case in
+ # order to simplify the common case.
+ self.remove_reader(fd)
+ if fut.cancelled():
+ return
+ try:
+ data = sock.recv(n)
+ except (BlockingIOError, InterruptedError):
+ self.add_reader(fd, self._sock_recv, fut, True, sock, n)
+ except Exception as exc:
+ fut.set_exception(exc)
+ else:
+ fut.set_result(data)
+
+ def sock_sendall(self, sock, data):
+ """Send data to the socket.
+
+ The socket must be connected to a remote socket. This method continues
+ to send data until either all of it has been sent or an error occurs.
+ None is returned on success. On error, an exception is raised, and
+ there is no way to determine how much data, if any, was successfully
+ processed by the receiving end of the connection.
+
+ This method is a coroutine.
+ """
+ if self._debug and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
+ fut = futures.Future(loop=self)
+ if data:
+ self._sock_sendall(fut, False, sock, data)
+ else:
+ fut.set_result(None)
+ return fut
+
+ def _sock_sendall(self, fut, registered, sock, data):
+ fd = sock.fileno()
+
+ if registered:
+ self.remove_writer(fd)
+ if fut.cancelled():
+ return
+
+ try:
+ n = sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except Exception as exc:
+ fut.set_exception(exc)
+ return
+
+ if n == len(data):
+ fut.set_result(None)
+ else:
+ if n:
+ data = data[n:]
+ self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
+
+ def sock_connect(self, sock, address):
+ """Connect to a remote socket at address.
+
+ The address must be already resolved to avoid the trap of hanging the
+ entire event loop when the address requires doing a DNS lookup. For
+ example, it must be an IP address, not a hostname, for AF_INET and
+ AF_INET6 address families. Use getaddrinfo() to resolve the hostname
+ asynchronously.
+
+ This method is a coroutine.
+ """
+ if self._debug and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
+ fut = futures.Future(loop=self)
+ try:
+ if self._debug:
+ base_events._check_resolved_address(sock, address)
+ except ValueError as err:
+ fut.set_exception(err)
+ else:
+ self._sock_connect(fut, sock, address)
+ return fut
+
+ def _sock_connect(self, fut, sock, address):
+ fd = sock.fileno()
+ try:
+ sock.connect(address)
+ except (BlockingIOError, InterruptedError):
+ # Issue #23618: When the C function connect() fails with EINTR, the
+ # connection runs in the background. We have to wait until the
+ # socket becomes writable to be notified when the connection
+ # succeeds or fails.
+ fut.add_done_callback(functools.partial(self._sock_connect_done,
+ fd))
+ self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
+ except Exception as exc:
+ fut.set_exception(exc)
+ else:
+ fut.set_result(None)
+
+ def _sock_connect_done(self, fd, fut):
+ self.remove_writer(fd)
+
+ def _sock_connect_cb(self, fut, sock, address):
+ if fut.cancelled():
+ return
+
+ try:
+ err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ if err != 0:
+ # Jump to any except clause below.
+ raise OSError(err, 'Connect call failed %s' % (address,))
+ except (BlockingIOError, InterruptedError):
+ # socket is still registered, the callback will be retried later
+ pass
+ except Exception as exc:
+ fut.set_exception(exc)
+ else:
+ fut.set_result(None)
+
+ def sock_accept(self, sock):
+ """Accept a connection.
+
+ The socket must be bound to an address and listening for connections.
+ The return value is a pair (conn, address) where conn is a new socket
+ object usable to send and receive data on the connection, and address
+ is the address bound to the socket on the other end of the connection.
+
+ This method is a coroutine.
+ """
+ if self._debug and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
+ fut = futures.Future(loop=self)
+ self._sock_accept(fut, False, sock)
+ return fut
+
+ def _sock_accept(self, fut, registered, sock):
+ fd = sock.fileno()
+ if registered:
+ self.remove_reader(fd)
+ if fut.cancelled():
+ return
+ try:
+ conn, address = sock.accept()
+ conn.setblocking(False)
+ except (BlockingIOError, InterruptedError):
+ self.add_reader(fd, self._sock_accept, fut, True, sock)
+ except Exception as exc:
+ fut.set_exception(exc)
+ else:
+ fut.set_result((conn, address))
+
+ def _process_events(self, event_list):
+ for key, mask in event_list:
+ fileobj, (reader, writer) = key.fileobj, key.data
+ if mask & selectors.EVENT_READ and reader is not None:
+ if reader._cancelled:
+ self.remove_reader(fileobj)
+ else:
+ self._add_callback(reader)
+ if mask & selectors.EVENT_WRITE and writer is not None:
+ if writer._cancelled:
+ self.remove_writer(fileobj)
+ else:
+ self._add_callback(writer)
+
+ def _stop_serving(self, sock):
+ self.remove_reader(sock.fileno())
+ sock.close()
+
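+# Illustrative sketch (not part of the original module); `loop` and `sock`
+# are placeholders for a selector event loop and a non-blocking socket:
+#
+#     def on_readable():
+#         print(sock.recv(4096))
+#
+#     loop.add_reader(sock.fileno(), on_readable)   # callback-style I/O
+#     loop.remove_reader(sock.fileno())
+#
+#     # Coroutine-style alternative, from inside a coroutine:
+#     # data = yield from loop.sock_recv(sock, 4096)
+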
+
+class _SelectorTransport(transports._FlowControlMixin,
+ transports.Transport):
+
+ max_size = 256 * 1024 # Buffer size passed to recv().
+
+ _buffer_factory = bytearray # Constructs initial value for self._buffer.
+
+ # Attribute used in the destructor: it must be set even if the constructor
+ # is not called (see _SelectorSslTransport which may start by raising an
+ # exception)
+ _sock = None
+
+ def __init__(self, loop, sock, protocol, extra=None, server=None):
+ super().__init__(extra, loop)
+ self._extra['socket'] = sock
+ self._extra['sockname'] = sock.getsockname()
+ if 'peername' not in self._extra:
+ try:
+ self._extra['peername'] = sock.getpeername()
+ except socket.error:
+ self._extra['peername'] = None
+ self._sock = sock
+ self._sock_fd = sock.fileno()
+ self._protocol = protocol
+ self._protocol_connected = True
+ self._server = server
+ self._buffer = self._buffer_factory()
+ self._conn_lost = 0 # Set when call to connection_lost scheduled.
+ self._closing = False # Set when close() called.
+ if self._server is not None:
+ self._server._attach()
+
+ def __repr__(self):
+ info = [self.__class__.__name__]
+ if self._sock is None:
+ info.append('closed')
+ elif self._closing:
+ info.append('closing')
+ info.append('fd=%s' % self._sock_fd)
+ # test if the transport was closed
+ if self._loop is not None and not self._loop.is_closed():
+ polling = _test_selector_event(self._loop._selector,
+ self._sock_fd, selectors.EVENT_READ)
+ if polling:
+ info.append('read=polling')
+ else:
+ info.append('read=idle')
+
+ polling = _test_selector_event(self._loop._selector,
+ self._sock_fd,
+ selectors.EVENT_WRITE)
+ if polling:
+ state = 'polling'
+ else:
+ state = 'idle'
+
+ bufsize = self.get_write_buffer_size()
+ info.append('write=<%s, bufsize=%s>' % (state, bufsize))
+ return '<%s>' % ' '.join(info)
+
+ def abort(self):
+ self._force_close(None)
+
+ def close(self):
+ if self._closing:
+ return
+ self._closing = True
+ self._loop.remove_reader(self._sock_fd)
+ if not self._buffer:
+ self._conn_lost += 1
+ self._loop.call_soon(self._call_connection_lost, None)
+
+ # On Python 3.3 and older, objects with a destructor that are part of a
+ # reference cycle are never destroyed. That's no longer the case on
+ # Python 3.4, thanks to PEP 442.
+ if sys.version_info >= (3, 4):
+ def __del__(self):
+ if self._sock is not None:
+ warnings.warn("unclosed transport %r" % self, ResourceWarning)
+ self._sock.close()
+
+ def _fatal_error(self, exc, message='Fatal error on transport'):
+ # Should be called from exception handler only.
+ if isinstance(exc, (BrokenPipeError,
+ ConnectionResetError, ConnectionAbortedError)):
+ if self._loop.get_debug():
+ logger.debug("%r: %s", self, message, exc_info=True)
+ else:
+ self._loop.call_exception_handler({
+ 'message': message,
+ 'exception': exc,
+ 'transport': self,
+ 'protocol': self._protocol,
+ })
+ self._force_close(exc)
+
+ def _force_close(self, exc):
+ if self._conn_lost:
+ return
+ if self._buffer:
+ self._buffer.clear()
+ self._loop.remove_writer(self._sock_fd)
+ if not self._closing:
+ self._closing = True
+ self._loop.remove_reader(self._sock_fd)
+ self._conn_lost += 1
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ if self._protocol_connected:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._sock.close()
+ self._sock = None
+ self._protocol = None
+ self._loop = None
+ server = self._server
+ if server is not None:
+ server._detach()
+ self._server = None
+
+ def get_write_buffer_size(self):
+ return len(self._buffer)
+
+
+class _SelectorSocketTransport(_SelectorTransport):
+
+ def __init__(self, loop, sock, protocol, waiter=None,
+ extra=None, server=None):
+ super().__init__(loop, sock, protocol, extra, server)
+ self._eof = False
+ self._paused = False
+
+ self._loop.call_soon(self._protocol.connection_made, self)
+ # only start reading when connection_made() has been called
+ self._loop.call_soon(self._loop.add_reader,
+ self._sock_fd, self._read_ready)
+ if waiter is not None:
+ # only wake up the waiter when connection_made() has been called
+ self._loop.call_soon(waiter._set_result_unless_cancelled, None)
+
+ def pause_reading(self):
+ if self._closing:
+ raise RuntimeError('Cannot pause_reading() when closing')
+ if self._paused:
+ raise RuntimeError('Already paused')
+ self._paused = True
+ self._loop.remove_reader(self._sock_fd)
+ if self._loop.get_debug():
+ logger.debug("%r pauses reading", self)
+
+ def resume_reading(self):
+ if not self._paused:
+ raise RuntimeError('Not paused')
+ self._paused = False
+ if self._closing:
+ return
+ self._loop.add_reader(self._sock_fd, self._read_ready)
+ if self._loop.get_debug():
+ logger.debug("%r resumes reading", self)
+
+ def _read_ready(self):
+ try:
+ data = self._sock.recv(self.max_size)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception as exc:
+ self._fatal_error(exc, 'Fatal read error on socket transport')
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ if self._loop.get_debug():
+ logger.debug("%r received EOF", self)
+ keep_open = self._protocol.eof_received()
+ if keep_open:
+ # We're keeping the connection open so the
+ # protocol can write more, but we still can't
+ # receive more, so remove the reader callback.
+ self._loop.remove_reader(self._sock_fd)
+ else:
+ self.close()
+
+ def write(self, data):
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ raise TypeError('data argument must be byte-ish (%r)' %
+ type(data))
+ if self._eof:
+ raise RuntimeError('Cannot call write() after write_eof()')
+ if not data:
+ return
+
+ if self._conn_lost:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+
+ if not self._buffer:
+ # Optimization: try to send now.
+ try:
+ n = self._sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception as exc:
+ self._fatal_error(exc, 'Fatal write error on socket transport')
+ return
+ else:
+ data = data[n:]
+ if not data:
+ return
+ # Not all was written; register write handler.
+ self._loop.add_writer(self._sock_fd, self._write_ready)
+
+ # Add it to the buffer.
+ self._buffer.extend(data)
+ self._maybe_pause_protocol()
+
+ def _write_ready(self):
+ assert self._buffer, 'Data should not be empty'
+
+ try:
+ n = self._sock.send(self._buffer)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception as exc:
+ self._loop.remove_writer(self._sock_fd)
+ self._buffer.clear()
+ self._fatal_error(exc, 'Fatal write error on socket transport')
+ else:
+ if n:
+ del self._buffer[:n]
+ self._maybe_resume_protocol() # May append to buffer.
+ if not self._buffer:
+ self._loop.remove_writer(self._sock_fd)
+ if self._closing:
+ self._call_connection_lost(None)
+ elif self._eof:
+ self._sock.shutdown(socket.SHUT_WR)
+
+ def write_eof(self):
+ if self._eof:
+ return
+ self._eof = True
+ if not self._buffer:
+ self._sock.shutdown(socket.SHUT_WR)
+
+ def can_write_eof(self):
+ return True
+
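+# Illustrative sketch (not part of the original module); `transport` is a
+# placeholder for the instance passed to Protocol.connection_made():
+#
+#     transport.write(b'GET / HTTP/1.0\r\n\r\n')   # buffered if send blocks
+#     if transport.can_write_eof():
+#         transport.write_eof()   # half-close once the buffer is flushed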
+
+class _SelectorSslTransport(_SelectorTransport):
+
+ _buffer_factory = bytearray
+
+ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
+ server_side=False, server_hostname=None,
+ extra=None, server=None):
+ if ssl is None:
+ raise RuntimeError('stdlib ssl module not available')
+
+ if not sslcontext:
+ sslcontext = sslproto._create_transport_context(server_side, server_hostname)
+
+ wrap_kwargs = {
+ 'server_side': server_side,
+ 'do_handshake_on_connect': False,
+ }
+ if server_hostname and not server_side:
+ wrap_kwargs['server_hostname'] = server_hostname
+ sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
+
+ super().__init__(loop, sslsock, protocol, extra, server)
+ # the protocol connection is only made after the SSL handshake
+ self._protocol_connected = False
+
+ self._server_hostname = server_hostname
+ self._waiter = waiter
+ self._sslcontext = sslcontext
+ self._paused = False
+
+ # SSL-specific extra info. (peercert is set later)
+ self._extra.update(sslcontext=sslcontext)
+
+ if self._loop.get_debug():
+ logger.debug("%r starts SSL handshake", self)
+ start_time = self._loop.time()
+ else:
+ start_time = None
+ self._on_handshake(start_time)
+
+ def _wakeup_waiter(self, exc=None):
+ if self._waiter is None:
+ return
+ if not self._waiter.cancelled():
+ if exc is not None:
+ self._waiter.set_exception(exc)
+ else:
+ self._waiter.set_result(None)
+ self._waiter = None
+
+ def _on_handshake(self, start_time):
+ try:
+ self._sock.do_handshake()
+ except ssl.SSLWantReadError:
+ self._loop.add_reader(self._sock_fd,
+ self._on_handshake, start_time)
+ return
+ except ssl.SSLWantWriteError:
+ self._loop.add_writer(self._sock_fd,
+ self._on_handshake, start_time)
+ return
+ except BaseException as exc:
+ if self._loop.get_debug():
+ logger.warning("%r: SSL handshake failed",
+ self, exc_info=True)
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.remove_writer(self._sock_fd)
+ self._sock.close()
+ self._wakeup_waiter(exc)
+ if isinstance(exc, Exception):
+ return
+ else:
+ raise
+
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.remove_writer(self._sock_fd)
+
+ peercert = self._sock.getpeercert()
+ if not hasattr(self._sslcontext, 'check_hostname'):
+ # Verify the hostname if requested; Python 3.4+ uses check_hostname
+ # and checks the hostname in do_handshake().
+ if (self._server_hostname and
+ self._sslcontext.verify_mode != ssl.CERT_NONE):
+ try:
+ ssl.match_hostname(peercert, self._server_hostname)
+ except Exception as exc:
+ if self._loop.get_debug():
+ logger.warning("%r: SSL handshake failed "
+ "on matching the hostname",
+ self, exc_info=True)
+ self._sock.close()
+ self._wakeup_waiter(exc)
+ return
+
+ # Add extra info that becomes available after handshake.
+ self._extra.update(peercert=peercert,
+ cipher=self._sock.cipher(),
+ compression=self._sock.compression(),
+ )
+
+ self._read_wants_write = False
+ self._write_wants_read = False
+ self._loop.add_reader(self._sock_fd, self._read_ready)
+ self._protocol_connected = True
+ self._loop.call_soon(self._protocol.connection_made, self)
+ # only wake up the waiter when connection_made() has been called
+ self._loop.call_soon(self._wakeup_waiter)
+
+ if self._loop.get_debug():
+ dt = self._loop.time() - start_time
+ logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)
+
+ def pause_reading(self):
+ # XXX This is a bit icky, given the comment at the top of
+ # _read_ready(). Is it possible to evoke a deadlock? I don't
+ # know, although it doesn't look like it; write() will still
+ # accept more data for the buffer and eventually the app will
+ # call resume_reading() again, and things will flow again.
+
+ if self._closing:
+ raise RuntimeError('Cannot pause_reading() when closing')
+ if self._paused:
+ raise RuntimeError('Already paused')
+ self._paused = True
+ self._loop.remove_reader(self._sock_fd)
+ if self._loop.get_debug():
+ logger.debug("%r pauses reading", self)
+
+ def resume_reading(self):
+ if not self._paused:
+ raise RuntimeError('Not paused')
+ self._paused = False
+ if self._closing:
+ return
+ self._loop.add_reader(self._sock_fd, self._read_ready)
+ if self._loop.get_debug():
+ logger.debug("%r resumes reading", self)
+
+ def _read_ready(self):
+ if self._write_wants_read:
+ self._write_wants_read = False
+ self._write_ready()
+
+ if self._buffer:
+ self._loop.add_writer(self._sock_fd, self._write_ready)
+
+ try:
+ data = self._sock.recv(self.max_size)
+ except (BlockingIOError, InterruptedError, ssl.SSLWantReadError):
+ pass
+ except ssl.SSLWantWriteError:
+ self._read_wants_write = True
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.add_writer(self._sock_fd, self._write_ready)
+ except Exception as exc:
+ self._fatal_error(exc, 'Fatal read error on SSL transport')
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ try:
+ if self._loop.get_debug():
+ logger.debug("%r received EOF", self)
+ keep_open = self._protocol.eof_received()
+ if keep_open:
+ logger.warning('returning true from eof_received() '
+ 'has no effect when using ssl')
+ finally:
+ self.close()
+
+ def _write_ready(self):
+ if self._read_wants_write:
+ self._read_wants_write = False
+ self._read_ready()
+
+ if not (self._paused or self._closing):
+ self._loop.add_reader(self._sock_fd, self._read_ready)
+
+ if self._buffer:
+ try:
+ n = self._sock.send(self._buffer)
+ except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError):
+ n = 0
+ except ssl.SSLWantReadError:
+ n = 0
+ self._loop.remove_writer(self._sock_fd)
+ self._write_wants_read = True
+ except Exception as exc:
+ self._loop.remove_writer(self._sock_fd)
+ self._buffer.clear()
+ self._fatal_error(exc, 'Fatal write error on SSL transport')
+ return
+
+ if n:
+ del self._buffer[:n]
+
+ self._maybe_resume_protocol() # May append to buffer.
+
+ if not self._buffer:
+ self._loop.remove_writer(self._sock_fd)
+ if self._closing:
+ self._call_connection_lost(None)
+
+ def write(self, data):
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ raise TypeError('data argument must be byte-ish (%r)' %
+ type(data))
+ if not data:
+ return
+
+ if self._conn_lost:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+
+ if not self._buffer:
+ self._loop.add_writer(self._sock_fd, self._write_ready)
+
+ # Add it to the buffer.
+ self._buffer.extend(data)
+ self._maybe_pause_protocol()
+
+ def can_write_eof(self):
+ return False
+
+
+class _SelectorDatagramTransport(_SelectorTransport):
+
+ _buffer_factory = collections.deque
+
+ def __init__(self, loop, sock, protocol, address=None,
+ waiter=None, extra=None):
+ super().__init__(loop, sock, protocol, extra)
+ self._address = address
+ self._loop.call_soon(self._protocol.connection_made, self)
+ # only start reading when connection_made() has been called
+ self._loop.call_soon(self._loop.add_reader,
+ self._sock_fd, self._read_ready)
+ if waiter is not None:
+ # only wake up the waiter when connection_made() has been called
+ self._loop.call_soon(waiter._set_result_unless_cancelled, None)
+
+ def get_write_buffer_size(self):
+ return sum(len(data) for data, _ in self._buffer)
+
+ def _read_ready(self):
+ try:
+ data, addr = self._sock.recvfrom(self.max_size)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except OSError as exc:
+ self._protocol.error_received(exc)
+ except Exception as exc:
+ self._fatal_error(exc, 'Fatal read error on datagram transport')
+ else:
+ self._protocol.datagram_received(data, addr)
+
+ def sendto(self, data, addr=None):
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ raise TypeError('data argument must be byte-ish (%r)' %
+ type(data))
+ if not data:
+ return
+
+ if self._address and addr not in (None, self._address):
+ raise ValueError('Invalid address: must be None or %s' %
+ (self._address,))
+
+ if self._conn_lost and self._address:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ if self._address:
+ self._sock.send(data)
+ else:
+ self._sock.sendto(data, addr)
+ return
+ except (BlockingIOError, InterruptedError):
+ self._loop.add_writer(self._sock_fd, self._sendto_ready)
+ except OSError as exc:
+ self._protocol.error_received(exc)
+ return
+ except Exception as exc:
+ self._fatal_error(exc,
+ 'Fatal write error on datagram transport')
+ return
+
+ # Ensure that what we buffer is immutable.
+ self._buffer.append((bytes(data), addr))
+ self._maybe_pause_protocol()
+
+ def _sendto_ready(self):
+ while self._buffer:
+ data, addr = self._buffer.popleft()
+ try:
+ if self._address:
+ self._sock.send(data)
+ else:
+ self._sock.sendto(data, addr)
+ except (BlockingIOError, InterruptedError):
+ self._buffer.appendleft((data, addr)) # Try again later.
+ break
+ except OSError as exc:
+ self._protocol.error_received(exc)
+ return
+ except Exception as exc:
+ self._fatal_error(exc,
+ 'Fatal write error on datagram transport')
+ return
+
+ self._maybe_resume_protocol() # May append to buffer.
+ if not self._buffer:
+ self._loop.remove_writer(self._sock_fd)
+ if self._closing:
+ self._call_connection_lost(None)
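+
+# Illustrative sketch (not part of the original module); `transport` is a
+# placeholder for an instance created via create_datagram_endpoint():
+#
+#     transport.sendto(b'ping', ('127.0.0.1', 9999))   # unconnected socket
+#     transport.sendto(b'ping')   # connected socket (address fixed at setup)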
diff --git a/trollius/selectors.py b/trollius/selectors.py
new file mode 100644
index 0000000..6d569c3
--- /dev/null
+++ b/trollius/selectors.py
@@ -0,0 +1,594 @@
+"""Selectors module.
+
+This module allows high-level and efficient I/O multiplexing, built upon the
+`select` module primitives.
+"""
+
+
+from abc import ABCMeta, abstractmethod
+from collections import namedtuple, Mapping
+import math
+import select
+import sys
+
+
+# generic events that must be mapped to implementation-specific ones
+EVENT_READ = (1 << 0)
+EVENT_WRITE = (1 << 1)
+
+
+def _fileobj_to_fd(fileobj):
+ """Return a file descriptor from a file object.
+
+ Parameters:
+ fileobj -- file object or file descriptor
+
+ Returns:
+ corresponding file descriptor
+
+ Raises:
+ ValueError if the object is invalid
+ """
+ if isinstance(fileobj, int):
+ fd = fileobj
+ else:
+ try:
+ fd = int(fileobj.fileno())
+ except (AttributeError, TypeError, ValueError):
+ raise ValueError("Invalid file object: "
+ "{!r}".format(fileobj)) from None
+ if fd < 0:
+ raise ValueError("Invalid file descriptor: {}".format(fd))
+ return fd
+
+
+SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
+"""Object used to associate a file object to its backing file descriptor,
+selected event mask and attached data."""
+
+
+class _SelectorMapping(Mapping):
+ """Mapping of file objects to selector keys."""
+
+ def __init__(self, selector):
+ self._selector = selector
+
+ def __len__(self):
+ return len(self._selector._fd_to_key)
+
+ def __getitem__(self, fileobj):
+ try:
+ fd = self._selector._fileobj_lookup(fileobj)
+ return self._selector._fd_to_key[fd]
+ except KeyError:
+ raise KeyError("{!r} is not registered".format(fileobj)) from None
+
+ def __iter__(self):
+ return iter(self._selector._fd_to_key)
+
+
+class BaseSelector(metaclass=ABCMeta):
+ """Selector abstract base class.
+
+ A selector supports registering file objects to be monitored for specific
+ I/O events.
+
+ A file object is a file descriptor or any object with a `fileno()` method.
+ An arbitrary object can be attached to the file object, which can be used
+ for example to store context information, a callback, etc.
+
+ A selector can use various implementations (select(), poll(), epoll()...)
+ depending on the platform. The default `Selector` class uses the most
+ efficient implementation on the current platform.
+ """
+
+ @abstractmethod
+ def register(self, fileobj, events, data=None):
+ """Register a file object.
+
+ Parameters:
+ fileobj -- file object or file descriptor
+ events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
+ data -- attached data
+
+ Returns:
+ SelectorKey instance
+
+ Raises:
+ ValueError if events is invalid
+ KeyError if fileobj is already registered
+ OSError if fileobj is closed or otherwise is unacceptable to
+ the underlying system call (if a system call is made)
+
+ Note:
+ OSError may or may not be raised
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def unregister(self, fileobj):
+ """Unregister a file object.
+
+ Parameters:
+ fileobj -- file object or file descriptor
+
+ Returns:
+ SelectorKey instance
+
+ Raises:
+ KeyError if fileobj is not registered
+
+ Note:
+ If fileobj is registered but has since been closed this does
+ *not* raise OSError (even if the wrapped syscall does)
+ """
+ raise NotImplementedError
+
+ def modify(self, fileobj, events, data=None):
+ """Change a registered file object monitored events or attached data.
+
+ Parameters:
+ fileobj -- file object or file descriptor
+ events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
+ data -- attached data
+
+ Returns:
+ SelectorKey instance
+
+ Raises:
+ Anything that unregister() or register() raises
+ """
+ self.unregister(fileobj)
+ return self.register(fileobj, events, data)
+
+ @abstractmethod
+ def select(self, timeout=None):
+ """Perform the actual selection, until some monitored file objects are
+ ready or a timeout expires.
+
+ Parameters:
+ timeout -- if timeout > 0, this specifies the maximum wait time, in
+ seconds
+ if timeout <= 0, the select() call won't block, and will
+ report the currently ready file objects
+ if timeout is None, select() will block until a monitored
+ file object becomes ready
+
+ Returns:
+ list of (key, events) for ready file objects
+ `events` is a bitwise mask of EVENT_READ|EVENT_WRITE
+ """
+ raise NotImplementedError
+
+ def close(self):
+ """Close the selector.
+
+ This must be called to make sure that any underlying resource is freed.
+ """
+ pass
+
+ def get_key(self, fileobj):
+ """Return the key associated to a registered file object.
+
+ Returns:
+ SelectorKey for this file object
+ """
+ mapping = self.get_map()
+ if mapping is None:
+ raise RuntimeError('Selector is closed')
+ try:
+ return mapping[fileobj]
+ except KeyError:
+ raise KeyError("{!r} is not registered".format(fileobj)) from None
+
+ @abstractmethod
+ def get_map(self):
+ """Return a mapping of file objects to selector keys."""
+ raise NotImplementedError
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ self.close()
+
+
+class _BaseSelectorImpl(BaseSelector):
+ """Base selector implementation."""
+
+ def __init__(self):
+ # this maps file descriptors to keys
+ self._fd_to_key = {}
+ # read-only mapping returned by get_map()
+ self._map = _SelectorMapping(self)
+
+ def _fileobj_lookup(self, fileobj):
+ """Return a file descriptor from a file object.
+
+ This wraps _fileobj_to_fd() to do an exhaustive search in case
+ the object is invalid but we still have it in our map. This
+ is used by unregister() so we can unregister an object that
+ was previously registered even if it is closed. It is also
+ used by _SelectorMapping.
+ """
+ try:
+ return _fileobj_to_fd(fileobj)
+ except ValueError:
+ # Do an exhaustive search.
+ for key in self._fd_to_key.values():
+ if key.fileobj is fileobj:
+ return key.fd
+ # Raise ValueError after all.
+ raise
+
+ def register(self, fileobj, events, data=None):
+ if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)):
+ raise ValueError("Invalid events: {!r}".format(events))
+
+ key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
+
+ if key.fd in self._fd_to_key:
+ raise KeyError("{!r} (FD {}) is already registered"
+ .format(fileobj, key.fd))
+
+ self._fd_to_key[key.fd] = key
+ return key
+
+ def unregister(self, fileobj):
+ try:
+ key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
+ except KeyError:
+ raise KeyError("{!r} is not registered".format(fileobj)) from None
+ return key
+
+ def modify(self, fileobj, events, data=None):
+ # TODO: Subclasses can probably optimize this even further.
+ try:
+ key = self._fd_to_key[self._fileobj_lookup(fileobj)]
+ except KeyError:
+ raise KeyError("{!r} is not registered".format(fileobj)) from None
+ if events != key.events:
+ self.unregister(fileobj)
+ key = self.register(fileobj, events, data)
+ elif data != key.data:
+ # Use a shortcut to update the data.
+ key = key._replace(data=data)
+ self._fd_to_key[key.fd] = key
+ return key
+
+ def close(self):
+ self._fd_to_key.clear()
+ self._map = None
+
+ def get_map(self):
+ return self._map
+
+ def _key_from_fd(self, fd):
+ """Return the key associated to a given file descriptor.
+
+ Parameters:
+ fd -- file descriptor
+
+ Returns:
+ corresponding key, or None if not found
+ """
+ try:
+ return self._fd_to_key[fd]
+ except KeyError:
+ return None
+
+
+class SelectSelector(_BaseSelectorImpl):
+ """Select-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._readers = set()
+ self._writers = set()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ if events & EVENT_READ:
+ self._readers.add(key.fd)
+ if events & EVENT_WRITE:
+ self._writers.add(key.fd)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ self._readers.discard(key.fd)
+ self._writers.discard(key.fd)
+ return key
+
+ if sys.platform == 'win32':
+ def _select(self, r, w, _, timeout=None):
+ r, w, x = select.select(r, w, w, timeout)
+ return r, w + x, []
+ else:
+ _select = select.select
+
+ def select(self, timeout=None):
+ timeout = None if timeout is None else max(timeout, 0)
+ ready = []
+ try:
+ r, w, _ = self._select(self._readers, self._writers, [], timeout)
+ except InterruptedError:
+ return ready
+ r = set(r)
+ w = set(w)
+ for fd in r | w:
+ events = 0
+ if fd in r:
+ events |= EVENT_READ
+ if fd in w:
+ events |= EVENT_WRITE
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key, events & key.events))
+ return ready
+
+
+if hasattr(select, 'poll'):
+
+ class PollSelector(_BaseSelectorImpl):
+ """Poll-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._poll = select.poll()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ poll_events = 0
+ if events & EVENT_READ:
+ poll_events |= select.POLLIN
+ if events & EVENT_WRITE:
+ poll_events |= select.POLLOUT
+ self._poll.register(key.fd, poll_events)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ self._poll.unregister(key.fd)
+ return key
+
+ def select(self, timeout=None):
+ if timeout is None:
+ timeout = None
+ elif timeout <= 0:
+ timeout = 0
+ else:
+ # poll() has a resolution of 1 millisecond, round away from
+ # zero to wait *at least* timeout seconds.
+ timeout = math.ceil(timeout * 1e3)
+ ready = []
+ try:
+ fd_event_list = self._poll.poll(timeout)
+ except InterruptedError:
+ return ready
+ for fd, event in fd_event_list:
+ events = 0
+ if event & ~select.POLLIN:
+ events |= EVENT_WRITE
+ if event & ~select.POLLOUT:
+ events |= EVENT_READ
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key, events & key.events))
+ return ready
+
+
+if hasattr(select, 'epoll'):
+
+ class EpollSelector(_BaseSelectorImpl):
+ """Epoll-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._epoll = select.epoll()
+
+ def fileno(self):
+ return self._epoll.fileno()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ epoll_events = 0
+ if events & EVENT_READ:
+ epoll_events |= select.EPOLLIN
+ if events & EVENT_WRITE:
+ epoll_events |= select.EPOLLOUT
+ self._epoll.register(key.fd, epoll_events)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ try:
+ self._epoll.unregister(key.fd)
+ except OSError:
+ # This can happen if the FD was closed since it
+ # was registered.
+ pass
+ return key
+
+ def select(self, timeout=None):
+ if timeout is None:
+ timeout = -1
+ elif timeout <= 0:
+ timeout = 0
+ else:
+ # epoll_wait() has a resolution of 1 millisecond, round away
+ # from zero to wait *at least* timeout seconds.
+ timeout = math.ceil(timeout * 1e3) * 1e-3
+
+ # epoll_wait() expects `maxevents` to be greater than zero;
+ # we want to make sure that `select()` can be called when no
+ # FD is registered.
+ max_ev = max(len(self._fd_to_key), 1)
+
+ ready = []
+ try:
+ fd_event_list = self._epoll.poll(timeout, max_ev)
+ except InterruptedError:
+ return ready
+ for fd, event in fd_event_list:
+ events = 0
+ if event & ~select.EPOLLIN:
+ events |= EVENT_WRITE
+ if event & ~select.EPOLLOUT:
+ events |= EVENT_READ
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key, events & key.events))
+ return ready
+
+ def close(self):
+ self._epoll.close()
+ super().close()
+
+
+if hasattr(select, 'devpoll'):
+
+ class DevpollSelector(_BaseSelectorImpl):
+ """Solaris /dev/poll selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._devpoll = select.devpoll()
+
+ def fileno(self):
+ return self._devpoll.fileno()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ poll_events = 0
+ if events & EVENT_READ:
+ poll_events |= select.POLLIN
+ if events & EVENT_WRITE:
+ poll_events |= select.POLLOUT
+ self._devpoll.register(key.fd, poll_events)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ self._devpoll.unregister(key.fd)
+ return key
+
+ def select(self, timeout=None):
+ if timeout is None:
+ timeout = None
+ elif timeout <= 0:
+ timeout = 0
+ else:
+ # devpoll() has a resolution of 1 millisecond, round away from
+ # zero to wait *at least* timeout seconds.
+ timeout = math.ceil(timeout * 1e3)
+ ready = []
+ try:
+ fd_event_list = self._devpoll.poll(timeout)
+ except InterruptedError:
+ return ready
+ for fd, event in fd_event_list:
+ events = 0
+ if event & ~select.POLLIN:
+ events |= EVENT_WRITE
+ if event & ~select.POLLOUT:
+ events |= EVENT_READ
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key, events & key.events))
+ return ready
+
+ def close(self):
+ self._devpoll.close()
+ super().close()
+
+
+if hasattr(select, 'kqueue'):
+
+ class KqueueSelector(_BaseSelectorImpl):
+ """Kqueue-based selector."""
+
+ def __init__(self):
+ super().__init__()
+ self._kqueue = select.kqueue()
+
+ def fileno(self):
+ return self._kqueue.fileno()
+
+ def register(self, fileobj, events, data=None):
+ key = super().register(fileobj, events, data)
+ if events & EVENT_READ:
+ kev = select.kevent(key.fd, select.KQ_FILTER_READ,
+ select.KQ_EV_ADD)
+ self._kqueue.control([kev], 0, 0)
+ if events & EVENT_WRITE:
+ kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
+ select.KQ_EV_ADD)
+ self._kqueue.control([kev], 0, 0)
+ return key
+
+ def unregister(self, fileobj):
+ key = super().unregister(fileobj)
+ if key.events & EVENT_READ:
+ kev = select.kevent(key.fd, select.KQ_FILTER_READ,
+ select.KQ_EV_DELETE)
+ try:
+ self._kqueue.control([kev], 0, 0)
+ except OSError:
+ # This can happen if the FD was closed since it
+ # was registered.
+ pass
+ if key.events & EVENT_WRITE:
+ kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
+ select.KQ_EV_DELETE)
+ try:
+ self._kqueue.control([kev], 0, 0)
+ except OSError:
+ # See comment above.
+ pass
+ return key
+
+ def select(self, timeout=None):
+ timeout = None if timeout is None else max(timeout, 0)
+            # If max_ev is 0, kqueue.control() ignores the timeout, so use
+            # at least 1 to let select() block when no FD is registered
+            # (mirrors the epoll code above).
+            max_ev = max(len(self._fd_to_key), 1)
+ ready = []
+ try:
+ kev_list = self._kqueue.control(None, max_ev, timeout)
+ except InterruptedError:
+ return ready
+ for kev in kev_list:
+ fd = kev.ident
+ flag = kev.filter
+ events = 0
+ if flag == select.KQ_FILTER_READ:
+ events |= EVENT_READ
+ if flag == select.KQ_FILTER_WRITE:
+ events |= EVENT_WRITE
+
+ key = self._key_from_fd(fd)
+ if key:
+ ready.append((key, events & key.events))
+ return ready
+
+ def close(self):
+ self._kqueue.close()
+ super().close()
+
+
+# Choose the best implementation, roughly:
+# epoll|kqueue|devpoll > poll > select.
+# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
+if 'KqueueSelector' in globals():
+ DefaultSelector = KqueueSelector
+elif 'EpollSelector' in globals():
+ DefaultSelector = EpollSelector
+elif 'DevpollSelector' in globals():
+ DefaultSelector = DevpollSelector
+elif 'PollSelector' in globals():
+ DefaultSelector = PollSelector
+else:
+ DefaultSelector = SelectSelector
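+
+
+if __name__ == '__main__':  # pragma: no cover
+    # Editor's sketch, not part of the original module: exercise whichever
+    # DefaultSelector was chosen above with a socketpair (POSIX; also
+    # available on Windows since Python 3.5). All names below are local
+    # to this demo.
+    import socket
+
+    a, b = socket.socketpair()
+    sel = DefaultSelector()
+    sel.register(b, EVENT_READ, data='peer-b readable')
+    a.send(b'ping')
+    for key, events in sel.select(timeout=1.0):
+        assert events & EVENT_READ
+        print(key.data, key.fileobj.recv(4))  # 'peer-b readable' b'ping'
+    sel.unregister(b)
+    sel.close()
+    a.close()
+    b.close()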
diff --git a/trollius/sslproto.py b/trollius/sslproto.py
new file mode 100644
index 0000000..235855e
--- /dev/null
+++ b/trollius/sslproto.py
@@ -0,0 +1,668 @@
+import collections
+import sys
+import warnings
+try:
+ import ssl
+except ImportError: # pragma: no cover
+ ssl = None
+
+from . import protocols
+from . import transports
+from .log import logger
+
+
+def _create_transport_context(server_side, server_hostname):
+ if server_side:
+ raise ValueError('Server side SSL needs a valid SSLContext')
+
+ # Client side may pass ssl=True to use a default
+ # context; in that case the sslcontext passed is None.
+ # The default is secure for client connections.
+ if hasattr(ssl, 'create_default_context'):
+ # Python 3.4+: use up-to-date strong settings.
+ sslcontext = ssl.create_default_context()
+ if not server_hostname:
+ sslcontext.check_hostname = False
+ else:
+ # Fallback for Python 3.3.
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.options |= ssl.OP_NO_SSLv3
+ sslcontext.set_default_verify_paths()
+ sslcontext.verify_mode = ssl.CERT_REQUIRED
+ return sslcontext
+
+
+def _is_sslproto_available():
+ return hasattr(ssl, "MemoryBIO")
+
+
+# States of an _SSLPipe.
+_UNWRAPPED = "UNWRAPPED"
+_DO_HANDSHAKE = "DO_HANDSHAKE"
+_WRAPPED = "WRAPPED"
+_SHUTDOWN = "SHUTDOWN"
+
+
+class _SSLPipe(object):
+ """An SSL "Pipe".
+
+ An SSL pipe allows you to communicate with an SSL/TLS protocol instance
+ through memory buffers. It can be used to implement a security layer for an
+ existing connection where you don't have access to the connection's file
+ descriptor, or for some reason you don't want to use it.
+
+    An SSL pipe can be in "wrapped" or "unwrapped" mode. In unwrapped mode,
+ data is passed through untransformed. In wrapped mode, application level
+ data is encrypted to SSL record level data and vice versa. The SSL record
+ level is the lowest level in the SSL protocol suite and is what travels
+ as-is over the wire.
+
+    An _SSLPipe is initially in "unwrapped" mode. To start SSL, call
+    do_handshake(). To shut SSL down again, call unwrap().
+ """
+
+ max_size = 256 * 1024 # Buffer size passed to read()
+
+ def __init__(self, context, server_side, server_hostname=None):
+ """
+ The *context* argument specifies the ssl.SSLContext to use.
+
+ The *server_side* argument indicates whether this is a server side or
+ client side transport.
+
+ The optional *server_hostname* argument can be used to specify the
+ hostname you are connecting to. You may only specify this parameter if
+ the _ssl module supports Server Name Indication (SNI).
+ """
+ self._context = context
+ self._server_side = server_side
+ self._server_hostname = server_hostname
+ self._state = _UNWRAPPED
+ self._incoming = ssl.MemoryBIO()
+ self._outgoing = ssl.MemoryBIO()
+ self._sslobj = None
+ self._need_ssldata = False
+ self._handshake_cb = None
+ self._shutdown_cb = None
+
+ @property
+ def context(self):
+ """The SSL context passed to the constructor."""
+ return self._context
+
+ @property
+ def ssl_object(self):
+ """The internal ssl.SSLObject instance.
+
+ Return None if the pipe is not wrapped.
+ """
+ return self._sslobj
+
+ @property
+ def need_ssldata(self):
+ """Whether more record level data is needed to complete a handshake
+ that is currently in progress."""
+ return self._need_ssldata
+
+ @property
+ def wrapped(self):
+ """
+ Whether a security layer is currently in effect.
+
+ Return False during handshake.
+ """
+ return self._state == _WRAPPED
+
+ def do_handshake(self, callback=None):
+ """Start the SSL handshake.
+
+        Return a list of ssldata: each element is a buffer of record level
+        data that must be sent to the remote endpoint.
+
+ The optional *callback* argument can be used to install a callback that
+ will be called when the handshake is complete. The callback will be
+ called with None if successful, else an exception instance.
+ """
+ if self._state != _UNWRAPPED:
+ raise RuntimeError('handshake in progress or completed')
+ self._sslobj = self._context.wrap_bio(
+ self._incoming, self._outgoing,
+ server_side=self._server_side,
+ server_hostname=self._server_hostname)
+ self._state = _DO_HANDSHAKE
+ self._handshake_cb = callback
+ ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
+ assert len(appdata) == 0
+ return ssldata
+
+ def shutdown(self, callback=None):
+ """Start the SSL shutdown sequence.
+
+        Return a list of ssldata: each element is a buffer of record level
+        data that must be sent to the remote endpoint.
+
+ The optional *callback* argument can be used to install a callback that
+ will be called when the shutdown is complete. The callback will be
+ called without arguments.
+ """
+ if self._state == _UNWRAPPED:
+ raise RuntimeError('no security layer present')
+ if self._state == _SHUTDOWN:
+ raise RuntimeError('shutdown in progress')
+ assert self._state in (_WRAPPED, _DO_HANDSHAKE)
+ self._state = _SHUTDOWN
+ self._shutdown_cb = callback
+ ssldata, appdata = self.feed_ssldata(b'')
+ assert appdata == [] or appdata == [b'']
+ return ssldata
+
+ def feed_eof(self):
+ """Send a potentially "ragged" EOF.
+
+ This method will raise an SSL_ERROR_EOF exception if the EOF is
+ unexpected.
+ """
+ self._incoming.write_eof()
+ ssldata, appdata = self.feed_ssldata(b'')
+ assert appdata == [] or appdata == [b'']
+
+ def feed_ssldata(self, data, only_handshake=False):
+ """Feed SSL record level data into the pipe.
+
+ The data must be a bytes instance. It is OK to send an empty bytes
+ instance. This can be used to get ssldata for a handshake initiated by
+ this endpoint.
+
+ Return a (ssldata, appdata) tuple. The ssldata element is a list of
+ buffers containing SSL data that needs to be sent to the remote SSL.
+
+ The appdata element is a list of buffers containing plaintext data that
+ needs to be forwarded to the application. The appdata list may contain
+ an empty buffer indicating an SSL "close_notify" alert. This alert must
+ be acknowledged by calling shutdown().
+ """
+ if self._state == _UNWRAPPED:
+ # If unwrapped, pass plaintext data straight through.
+ if data:
+ appdata = [data]
+ else:
+ appdata = []
+ return ([], appdata)
+
+ self._need_ssldata = False
+ if data:
+ self._incoming.write(data)
+
+ ssldata = []
+ appdata = []
+ try:
+ if self._state == _DO_HANDSHAKE:
+ # Call do_handshake() until it doesn't raise anymore.
+ self._sslobj.do_handshake()
+ self._state = _WRAPPED
+ if self._handshake_cb:
+ self._handshake_cb(None)
+ if only_handshake:
+ return (ssldata, appdata)
+ # Handshake done: execute the wrapped block
+
+ if self._state == _WRAPPED:
+ # Main state: read data from SSL until close_notify
+ while True:
+ chunk = self._sslobj.read(self.max_size)
+ appdata.append(chunk)
+ if not chunk: # close_notify
+ break
+
+ elif self._state == _SHUTDOWN:
+ # Call shutdown() until it doesn't raise anymore.
+ self._sslobj.unwrap()
+ self._sslobj = None
+ self._state = _UNWRAPPED
+ if self._shutdown_cb:
+ self._shutdown_cb()
+
+ elif self._state == _UNWRAPPED:
+ # Drain possible plaintext data after close_notify.
+ appdata.append(self._incoming.read())
+ except (ssl.SSLError, ssl.CertificateError) as exc:
+ if getattr(exc, 'errno', None) not in (
+ ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
+ ssl.SSL_ERROR_SYSCALL):
+ if self._state == _DO_HANDSHAKE and self._handshake_cb:
+ self._handshake_cb(exc)
+ raise
+ self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
+
+ # Check for record level data that needs to be sent back.
+ # Happens for the initial handshake and renegotiations.
+ if self._outgoing.pending:
+ ssldata.append(self._outgoing.read())
+ return (ssldata, appdata)
+
+ def feed_appdata(self, data, offset=0):
+ """Feed plaintext data into the pipe.
+
+ Return an (ssldata, offset) tuple. The ssldata element is a list of
+ buffers containing record level data that needs to be sent to the
+ remote SSL instance. The offset is the number of plaintext bytes that
+ were processed, which may be less than the length of data.
+
+ NOTE: In case of short writes, this call MUST be retried with the SAME
+ buffer passed into the *data* argument (i.e. the id() must be the
+ same). This is an OpenSSL requirement. A further particularity is that
+ a short write will always have offset == 0, because the _ssl module
+ does not enable partial writes. And even though the offset is zero,
+ there will still be encrypted data in ssldata.
+ """
+ assert 0 <= offset <= len(data)
+ if self._state == _UNWRAPPED:
+ # pass through data in unwrapped mode
+ if offset < len(data):
+ ssldata = [data[offset:]]
+ else:
+ ssldata = []
+ return (ssldata, len(data))
+
+ ssldata = []
+ view = memoryview(data)
+ while True:
+ self._need_ssldata = False
+ try:
+ if offset < len(view):
+ offset += self._sslobj.write(view[offset:])
+ except ssl.SSLError as exc:
+ # It is not allowed to call write() after unwrap() until the
+ # close_notify is acknowledged. We return the condition to the
+ # caller as a short write.
+ if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
+ exc.errno = ssl.SSL_ERROR_WANT_READ
+ if exc.errno not in (ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE,
+ ssl.SSL_ERROR_SYSCALL):
+ raise
+ self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
+
+ # See if there's any record level data back for us.
+ if self._outgoing.pending:
+ ssldata.append(self._outgoing.read())
+ if offset == len(view) or self._need_ssldata:
+ break
+ return (ssldata, offset)
+
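+
+# Editor's note (sketch, not part of the original module): an _SSLPipe is
+# driven by hand. Below, `send` stands for a hypothetical callable that
+# ships bytes to the peer and `received` for bytes read back from it:
+#
+#     pipe = _SSLPipe(ssl.create_default_context(), server_side=False,
+#                     server_hostname='example.com')
+#     for chunk in pipe.do_handshake():
+#         send(chunk)                       # initial handshake bytes
+#     ssldata, appdata = pipe.feed_ssldata(received)
+#     for chunk in ssldata:                 # handshake continues
+#         send(chunk)
+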
+
+class _SSLProtocolTransport(transports._FlowControlMixin,
+ transports.Transport):
+
+ def __init__(self, loop, ssl_protocol, app_protocol):
+ self._loop = loop
+ self._ssl_protocol = ssl_protocol
+ self._app_protocol = app_protocol
+ self._closed = False
+
+ def get_extra_info(self, name, default=None):
+ """Get optional transport information."""
+ return self._ssl_protocol._get_extra_info(name, default)
+
+ def close(self):
+ """Close the transport.
+
+ Buffered data will be flushed asynchronously. No more data
+ will be received. After all buffered data is flushed, the
+        protocol's connection_lost() method will (eventually) be called
+ with None as its argument.
+ """
+ self._closed = True
+ self._ssl_protocol._start_shutdown()
+
+    # On Python 3.3 and older, objects with a destructor that are part of a
+    # reference cycle are never destroyed. That's no longer the case on
+    # Python 3.4 thanks to PEP 442.
+ if sys.version_info >= (3, 4):
+ def __del__(self):
+ if not self._closed:
+ warnings.warn("unclosed transport %r" % self, ResourceWarning)
+ self.close()
+
+ def pause_reading(self):
+ """Pause the receiving end.
+
+ No data will be passed to the protocol's data_received()
+ method until resume_reading() is called.
+ """
+ self._ssl_protocol._transport.pause_reading()
+
+ def resume_reading(self):
+ """Resume the receiving end.
+
+ Data received will once again be passed to the protocol's
+ data_received() method.
+ """
+ self._ssl_protocol._transport.resume_reading()
+
+ def set_write_buffer_limits(self, high=None, low=None):
+ """Set the high- and low-water limits for write flow control.
+
+ These two values control when to call the protocol's
+ pause_writing() and resume_writing() methods. If specified,
+ the low-water limit must be less than or equal to the
+ high-water limit. Neither value can be negative.
+
+ The defaults are implementation-specific. If only the
+        high-water limit is given, the low-water limit defaults to an
+        implementation-specific value less than or equal to the
+ high-water limit. Setting high to zero forces low to zero as
+ well, and causes pause_writing() to be called whenever the
+ buffer becomes non-empty. Setting low to zero causes
+ resume_writing() to be called only once the buffer is empty.
+ Use of zero for either limit is generally sub-optimal as it
+ reduces opportunities for doing I/O and computation
+ concurrently.
+ """
+ self._ssl_protocol._transport.set_write_buffer_limits(high, low)
+
+ def get_write_buffer_size(self):
+ """Return the current size of the write buffer."""
+ return self._ssl_protocol._transport.get_write_buffer_size()
+
+ def write(self, data):
+ """Write some data bytes to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+ """
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ raise TypeError("data: expecting a bytes-like instance, got {!r}"
+ .format(type(data).__name__))
+ if not data:
+ return
+ self._ssl_protocol._write_appdata(data)
+
+ def can_write_eof(self):
+ """Return True if this transport supports write_eof(), False if not."""
+ return False
+
+ def abort(self):
+ """Close the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method will (eventually) be
+ called with None as its argument.
+ """
+ self._ssl_protocol._abort()
+
+
+class SSLProtocol(protocols.Protocol):
+ """SSL protocol.
+
+ Implementation of SSL on top of a socket using incoming and outgoing
+ buffers which are ssl.MemoryBIO objects.
+ """
+
+ def __init__(self, loop, app_protocol, sslcontext, waiter,
+ server_side=False, server_hostname=None):
+ if ssl is None:
+ raise RuntimeError('stdlib ssl module not available')
+
+ if not sslcontext:
+ sslcontext = _create_transport_context(server_side, server_hostname)
+
+ self._server_side = server_side
+ if server_hostname and not server_side:
+ self._server_hostname = server_hostname
+ else:
+ self._server_hostname = None
+ self._sslcontext = sslcontext
+        # SSL-specific extra info. More info is set when the handshake
+ # completes.
+ self._extra = dict(sslcontext=sslcontext)
+
+ # App data write buffering
+ self._write_backlog = collections.deque()
+ self._write_buffer_size = 0
+
+ self._waiter = waiter
+ self._loop = loop
+ self._app_protocol = app_protocol
+ self._app_transport = _SSLProtocolTransport(self._loop,
+ self, self._app_protocol)
+ self._sslpipe = None
+ self._session_established = False
+ self._in_handshake = False
+ self._in_shutdown = False
+ self._transport = None
+
+ def _wakeup_waiter(self, exc=None):
+ if self._waiter is None:
+ return
+ if not self._waiter.cancelled():
+ if exc is not None:
+ self._waiter.set_exception(exc)
+ else:
+ self._waiter.set_result(None)
+ self._waiter = None
+
+ def connection_made(self, transport):
+ """Called when the low-level connection is made.
+
+ Start the SSL handshake.
+ """
+ self._transport = transport
+ self._sslpipe = _SSLPipe(self._sslcontext,
+ self._server_side,
+ self._server_hostname)
+ self._start_handshake()
+
+ def connection_lost(self, exc):
+ """Called when the low-level connection is lost or closed.
+
+ The argument is an exception object or None (the latter
+ meaning a regular EOF is received or the connection was
+ aborted or closed).
+ """
+ if self._session_established:
+ self._session_established = False
+ self._loop.call_soon(self._app_protocol.connection_lost, exc)
+ self._transport = None
+ self._app_transport = None
+
+ def pause_writing(self):
+ """Called when the low-level transport's buffer goes over
+ the high-water mark.
+ """
+ self._app_protocol.pause_writing()
+
+ def resume_writing(self):
+ """Called when the low-level transport's buffer drains below
+ the low-water mark.
+ """
+ self._app_protocol.resume_writing()
+
+ def data_received(self, data):
+ """Called when some SSL data is received.
+
+ The argument is a bytes object.
+ """
+ try:
+ ssldata, appdata = self._sslpipe.feed_ssldata(data)
+ except ssl.SSLError as e:
+ if self._loop.get_debug():
+ logger.warning('%r: SSL error %s (reason %s)',
+ self, e.errno, e.reason)
+ self._abort()
+ return
+
+ for chunk in ssldata:
+ self._transport.write(chunk)
+
+ for chunk in appdata:
+ if chunk:
+ self._app_protocol.data_received(chunk)
+ else:
+ self._start_shutdown()
+ break
+
+ def eof_received(self):
+ """Called when the other end of the low-level stream
+ is half-closed.
+
+ If this returns a false value (including None), the transport
+ will close itself. If it returns a true value, closing the
+ transport is up to the protocol.
+ """
+ try:
+ if self._loop.get_debug():
+ logger.debug("%r received EOF", self)
+
+ self._wakeup_waiter(ConnectionResetError)
+
+ if not self._in_handshake:
+ keep_open = self._app_protocol.eof_received()
+ if keep_open:
+ logger.warning('returning true from eof_received() '
+ 'has no effect when using ssl')
+ finally:
+ self._transport.close()
+
+ def _get_extra_info(self, name, default=None):
+ if name in self._extra:
+ return self._extra[name]
+ else:
+ return self._transport.get_extra_info(name, default)
+
+ def _start_shutdown(self):
+ if self._in_shutdown:
+ return
+ self._in_shutdown = True
+ self._write_appdata(b'')
+
+ def _write_appdata(self, data):
+ self._write_backlog.append((data, 0))
+ self._write_buffer_size += len(data)
+ self._process_write_backlog()
+
+ def _start_handshake(self):
+ if self._loop.get_debug():
+ logger.debug("%r starts SSL handshake", self)
+ self._handshake_start_time = self._loop.time()
+ else:
+ self._handshake_start_time = None
+ self._in_handshake = True
+ # (b'', 1) is a special value in _process_write_backlog() to do
+ # the SSL handshake
+ self._write_backlog.append((b'', 1))
+ self._loop.call_soon(self._process_write_backlog)
+
+ def _on_handshake_complete(self, handshake_exc):
+ self._in_handshake = False
+
+ sslobj = self._sslpipe.ssl_object
+ try:
+ if handshake_exc is not None:
+ raise handshake_exc
+
+ peercert = sslobj.getpeercert()
+ if not hasattr(self._sslcontext, 'check_hostname'):
+                # Python < 3.4: verify the hostname here if requested.
+                # Python 3.4+ uses check_hostname and verifies the
+                # hostname in do_handshake() itself.
+ if (self._server_hostname
+ and self._sslcontext.verify_mode != ssl.CERT_NONE):
+ ssl.match_hostname(peercert, self._server_hostname)
+ except BaseException as exc:
+ if self._loop.get_debug():
+ if isinstance(exc, ssl.CertificateError):
+ logger.warning("%r: SSL handshake failed "
+ "on verifying the certificate",
+ self, exc_info=True)
+ else:
+ logger.warning("%r: SSL handshake failed",
+ self, exc_info=True)
+ self._transport.close()
+ if isinstance(exc, Exception):
+ self._wakeup_waiter(exc)
+ return
+ else:
+ raise
+
+ if self._loop.get_debug():
+ dt = self._loop.time() - self._handshake_start_time
+ logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)
+
+ # Add extra info that becomes available after handshake.
+ self._extra.update(peercert=peercert,
+ cipher=sslobj.cipher(),
+ compression=sslobj.compression(),
+ )
+ self._app_protocol.connection_made(self._app_transport)
+ self._wakeup_waiter()
+ self._session_established = True
+        # In case transport.write() was already called, don't call
+        # _process_write_backlog() immediately, but schedule it:
+ # _on_handshake_complete() can be called indirectly from
+ # _process_write_backlog(), and _process_write_backlog() is not
+ # reentrant.
+ self._loop.call_soon(self._process_write_backlog)
+
+ def _process_write_backlog(self):
+ # Try to make progress on the write backlog.
+ if self._transport is None:
+ return
+
+ try:
+ for i in range(len(self._write_backlog)):
+ data, offset = self._write_backlog[0]
+ if data:
+ ssldata, offset = self._sslpipe.feed_appdata(data, offset)
+ elif offset:
+ ssldata = self._sslpipe.do_handshake(self._on_handshake_complete)
+ offset = 1
+ else:
+ ssldata = self._sslpipe.shutdown(self._finalize)
+ offset = 1
+
+ for chunk in ssldata:
+ self._transport.write(chunk)
+
+ if offset < len(data):
+ self._write_backlog[0] = (data, offset)
+ # A short write means that a write is blocked on a read
+ # We need to enable reading if it is paused!
+ assert self._sslpipe.need_ssldata
+ if self._transport._paused:
+ self._transport.resume_reading()
+ break
+
+ # An entire chunk from the backlog was processed. We can
+ # delete it and reduce the outstanding buffer size.
+ del self._write_backlog[0]
+ self._write_buffer_size -= len(data)
+ except BaseException as exc:
+ if self._in_handshake:
+ self._on_handshake_complete(exc)
+ else:
+ self._fatal_error(exc, 'Fatal error on SSL transport')
+
+ def _fatal_error(self, exc, message='Fatal error on transport'):
+ # Should be called from exception handler only.
+ if isinstance(exc, (BrokenPipeError, ConnectionResetError)):
+ if self._loop.get_debug():
+ logger.debug("%r: %s", self, message, exc_info=True)
+ else:
+ self._loop.call_exception_handler({
+ 'message': message,
+ 'exception': exc,
+ 'transport': self._transport,
+ 'protocol': self,
+ })
+ if self._transport:
+ self._transport._force_close(exc)
+
+ def _finalize(self):
+ if self._transport is not None:
+ self._transport.close()
+
+ def _abort(self):
+ if self._transport is not None:
+ try:
+ self._transport.abort()
+ finally:
+ self._finalize()
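+
+
+# Editor's note (sketch, not part of the original module): applications do
+# not use SSLProtocol directly; the event loop wraps it around a plain
+# socket transport, roughly as below. `loop`, `sock` and `app_protocol`
+# are hypothetical, futures.Future comes from the .futures submodule, and
+# _make_socket_transport() is a private selector event loop method.
+#
+#     waiter = futures.Future(loop=loop)
+#     ssl_proto = SSLProtocol(loop, app_protocol, None, waiter,
+#                             server_side=False,
+#                             server_hostname='example.com')
+#     transport = loop._make_socket_transport(sock, ssl_proto)
+#     yield from waiter        # resolves when the handshake completes
+#     app_transport = ssl_proto._app_transport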
diff --git a/trollius/streams.py b/trollius/streams.py
new file mode 100644
index 0000000..176c65e
--- /dev/null
+++ b/trollius/streams.py
@@ -0,0 +1,501 @@
+"""Stream-related things."""
+
+__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
+ 'open_connection', 'start_server',
+ 'IncompleteReadError',
+ ]
+
+import socket
+import sys
+
+if hasattr(socket, 'AF_UNIX'):
+ __all__.extend(['open_unix_connection', 'start_unix_server'])
+
+from . import coroutines
+from . import events
+from . import futures
+from . import protocols
+from .coroutines import coroutine
+from .log import logger
+
+
+_DEFAULT_LIMIT = 2**16
+_PY35 = sys.version_info >= (3, 5)
+
+
+class IncompleteReadError(EOFError):
+ """
+ Incomplete read error. Attributes:
+
+    - partial: bytes read before the end of the stream was reached
+ - expected: total number of expected bytes
+ """
+ def __init__(self, partial, expected):
+ EOFError.__init__(self, "%s bytes read on a total of %s expected bytes"
+ % (len(partial), expected))
+ self.partial = partial
+ self.expected = expected
+
+
+@coroutine
+def open_connection(host=None, port=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """A wrapper for create_connection() returning a (reader, writer) pair.
+
+ The reader returned is a StreamReader instance; the writer is a
+ StreamWriter instance.
+
+ The arguments are all the usual arguments to create_connection()
+ except protocol_factory; most common are positional host and port,
+ with various optional keyword arguments following.
+
+ Additional optional keyword arguments are loop (to set the event loop
+ instance to use) and limit (to set the buffer limit passed to the
+ StreamReader).
+
+ (If you want to customize the StreamReader and/or
+ StreamReaderProtocol classes, just copy the code -- there's
+ really nothing special here except some convenience.)
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, loop=loop)
+ transport, _ = yield from loop.create_connection(
+ lambda: protocol, host, port, **kwds)
+ writer = StreamWriter(transport, protocol, reader, loop)
+ return reader, writer
+
+
+@coroutine
+def start_server(client_connected_cb, host=None, port=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Start a socket server, call back for each client connected.
+
+ The first parameter, `client_connected_cb`, takes two parameters:
+ client_reader, client_writer. client_reader is a StreamReader
+ object, while client_writer is a StreamWriter object. This
+ parameter can either be a plain callback function or a coroutine;
+ if it is a coroutine, it will be automatically converted into a
+ Task.
+
+ The rest of the arguments are all the usual arguments to
+ loop.create_server() except protocol_factory; most common are
+ positional host and port, with various optional keyword arguments
+ following. The return value is the same as loop.create_server().
+
+ Additional optional keyword arguments are loop (to set the event loop
+ instance to use) and limit (to set the buffer limit passed to the
+ StreamReader).
+
+ The return value is the same as loop.create_server(), i.e. a
+ Server object which can be used to stop the service.
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+
+ def factory():
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, client_connected_cb,
+ loop=loop)
+ return protocol
+
+ return (yield from loop.create_server(factory, host, port, **kwds))
+
+
+if hasattr(socket, 'AF_UNIX'):
+ # UNIX Domain Sockets are supported on this platform
+
+ @coroutine
+ def open_unix_connection(path=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Similar to `open_connection` but works with UNIX Domain Sockets."""
+ if loop is None:
+ loop = events.get_event_loop()
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, loop=loop)
+ transport, _ = yield from loop.create_unix_connection(
+ lambda: protocol, path, **kwds)
+ writer = StreamWriter(transport, protocol, reader, loop)
+ return reader, writer
+
+
+ @coroutine
+ def start_unix_server(client_connected_cb, path=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Similar to `start_server` but works with UNIX Domain Sockets."""
+ if loop is None:
+ loop = events.get_event_loop()
+
+ def factory():
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, client_connected_cb,
+ loop=loop)
+ return protocol
+
+ return (yield from loop.create_unix_server(factory, path, **kwds))
+
+
+class FlowControlMixin(protocols.Protocol):
+ """Reusable flow control logic for StreamWriter.drain().
+
+    This implements the protocol methods pause_writing(),
+    resume_writing() and connection_lost(). If the subclass overrides
+    these it must call the super methods.
+
+    StreamWriter.drain() must wait for the _drain_helper() coroutine.
+ """
+
+ def __init__(self, loop=None):
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._paused = False
+ self._drain_waiter = None
+ self._connection_lost = False
+
+ def pause_writing(self):
+ assert not self._paused
+ self._paused = True
+ if self._loop.get_debug():
+ logger.debug("%r pauses writing", self)
+
+ def resume_writing(self):
+ assert self._paused
+ self._paused = False
+ if self._loop.get_debug():
+ logger.debug("%r resumes writing", self)
+
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ waiter.set_result(None)
+
+ def connection_lost(self, exc):
+ self._connection_lost = True
+ # Wake up the writer if currently paused.
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ if waiter is None:
+ return
+ self._drain_waiter = None
+ if waiter.done():
+ return
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
+
+ @coroutine
+ def _drain_helper(self):
+ if self._connection_lost:
+ raise ConnectionResetError('Connection lost')
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ assert waiter is None or waiter.cancelled()
+ waiter = futures.Future(loop=self._loop)
+ self._drain_waiter = waiter
+ yield from waiter
+
+
+class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
+ """Helper class to adapt between Protocol and StreamReader.
+
+ (This is a helper class instead of making StreamReader itself a
+    Protocol subclass, because the StreamReader has other potential
+    uses, and to prevent the user of the StreamReader from accidentally
+    calling inappropriate methods of the protocol.)
+ """
+
+ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
+ super().__init__(loop=loop)
+ self._stream_reader = stream_reader
+ self._stream_writer = None
+ self._client_connected_cb = client_connected_cb
+
+ def connection_made(self, transport):
+ self._stream_reader.set_transport(transport)
+ if self._client_connected_cb is not None:
+ self._stream_writer = StreamWriter(transport, self,
+ self._stream_reader,
+ self._loop)
+ res = self._client_connected_cb(self._stream_reader,
+ self._stream_writer)
+ if coroutines.iscoroutine(res):
+ self._loop.create_task(res)
+
+ def connection_lost(self, exc):
+ if exc is None:
+ self._stream_reader.feed_eof()
+ else:
+ self._stream_reader.set_exception(exc)
+ super().connection_lost(exc)
+
+ def data_received(self, data):
+ self._stream_reader.feed_data(data)
+
+ def eof_received(self):
+ self._stream_reader.feed_eof()
+
+
+class StreamWriter:
+ """Wraps a Transport.
+
+ This exposes write(), writelines(), [can_]write_eof(),
+ get_extra_info() and close(). It adds drain() which returns an
+ optional Future on which you can wait for flow control. It also
+ adds a transport property which references the Transport
+ directly.
+ """
+
+ def __init__(self, transport, protocol, reader, loop):
+ self._transport = transport
+ self._protocol = protocol
+        # drain() expects that the reader has an exception() method
+ assert reader is None or isinstance(reader, StreamReader)
+ self._reader = reader
+ self._loop = loop
+
+ def __repr__(self):
+ info = [self.__class__.__name__, 'transport=%r' % self._transport]
+ if self._reader is not None:
+ info.append('reader=%r' % self._reader)
+ return '<%s>' % ' '.join(info)
+
+ @property
+ def transport(self):
+ return self._transport
+
+ def write(self, data):
+ self._transport.write(data)
+
+ def writelines(self, data):
+ self._transport.writelines(data)
+
+ def write_eof(self):
+ return self._transport.write_eof()
+
+ def can_write_eof(self):
+ return self._transport.can_write_eof()
+
+ def close(self):
+ return self._transport.close()
+
+ def get_extra_info(self, name, default=None):
+ return self._transport.get_extra_info(name, default)
+
+ @coroutine
+ def drain(self):
+ """Flush the write buffer.
+
+ The intended use is to write
+
+ w.write(data)
+ yield from w.drain()
+ """
+ if self._reader is not None:
+ exc = self._reader.exception()
+ if exc is not None:
+ raise exc
+ yield from self._protocol._drain_helper()
+
+
+class StreamReader:
+
+ def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
+ # The line length limit is a security feature;
+ # it also doubles as half the buffer limit.
+ self._limit = limit
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._buffer = bytearray()
+ self._eof = False # Whether we're done.
+ self._waiter = None # A future used by _wait_for_data()
+ self._exception = None
+ self._transport = None
+ self._paused = False
+
+ def exception(self):
+ return self._exception
+
+ def set_exception(self, exc):
+ self._exception = exc
+
+ waiter = self._waiter
+ if waiter is not None:
+ self._waiter = None
+ if not waiter.cancelled():
+ waiter.set_exception(exc)
+
+ def _wakeup_waiter(self):
+ """Wakeup read() or readline() function waiting for data or EOF."""
+ waiter = self._waiter
+ if waiter is not None:
+ self._waiter = None
+ if not waiter.cancelled():
+ waiter.set_result(None)
+
+ def set_transport(self, transport):
+ assert self._transport is None, 'Transport already set'
+ self._transport = transport
+
+ def _maybe_resume_transport(self):
+ if self._paused and len(self._buffer) <= self._limit:
+ self._paused = False
+ self._transport.resume_reading()
+
+ def feed_eof(self):
+ self._eof = True
+ self._wakeup_waiter()
+
+ def at_eof(self):
+ """Return True if the buffer is empty and 'feed_eof' was called."""
+ return self._eof and not self._buffer
+
+ def feed_data(self, data):
+ assert not self._eof, 'feed_data after feed_eof'
+
+ if not data:
+ return
+
+ self._buffer.extend(data)
+ self._wakeup_waiter()
+
+ if (self._transport is not None and
+ not self._paused and
+ len(self._buffer) > 2*self._limit):
+ try:
+ self._transport.pause_reading()
+ except NotImplementedError:
+ # The transport can't be paused.
+ # We'll just have to buffer all data.
+ # Forget the transport so we don't keep trying.
+ self._transport = None
+ else:
+ self._paused = True
+
+ @coroutine
+ def _wait_for_data(self, func_name):
+ """Wait until feed_data() or feed_eof() is called."""
+ # StreamReader uses a future to link the protocol feed_data() method
+ # to a read coroutine. Running two read coroutines at the same time
+        # would have unexpected behaviour: it would not be possible to know
+        # which coroutine would get the next data.
+ if self._waiter is not None:
+ raise RuntimeError('%s() called while another coroutine is '
+ 'already waiting for incoming data' % func_name)
+
+ self._waiter = futures.Future(loop=self._loop)
+ try:
+ yield from self._waiter
+ finally:
+ self._waiter = None
+
+ @coroutine
+ def readline(self):
+ if self._exception is not None:
+ raise self._exception
+
+ line = bytearray()
+ not_enough = True
+
+ while not_enough:
+ while self._buffer and not_enough:
+ ichar = self._buffer.find(b'\n')
+ if ichar < 0:
+ line.extend(self._buffer)
+ self._buffer.clear()
+ else:
+ ichar += 1
+ line.extend(self._buffer[:ichar])
+ del self._buffer[:ichar]
+ not_enough = False
+
+ if len(line) > self._limit:
+ self._maybe_resume_transport()
+ raise ValueError('Line is too long')
+
+ if self._eof:
+ break
+
+ if not_enough:
+ yield from self._wait_for_data('readline')
+
+ self._maybe_resume_transport()
+ return bytes(line)
+
+ @coroutine
+ def read(self, n=-1):
+ if self._exception is not None:
+ raise self._exception
+
+ if not n:
+ return b''
+
+ if n < 0:
+ # This used to just loop creating a new waiter hoping to
+ # collect everything in self._buffer, but that would
+ # deadlock if the subprocess sends more than self.limit
+ # bytes. So just call self.read(self._limit) until EOF.
+ blocks = []
+ while True:
+ block = yield from self.read(self._limit)
+ if not block:
+ break
+ blocks.append(block)
+ return b''.join(blocks)
+ else:
+ if not self._buffer and not self._eof:
+ yield from self._wait_for_data('read')
+
+ if n < 0 or len(self._buffer) <= n:
+ data = bytes(self._buffer)
+ self._buffer.clear()
+ else:
+ # n > 0 and len(self._buffer) > n
+ data = bytes(self._buffer[:n])
+ del self._buffer[:n]
+
+ self._maybe_resume_transport()
+ return data
+
+ @coroutine
+ def readexactly(self, n):
+ if self._exception is not None:
+ raise self._exception
+
+ # There used to be "optimized" code here. It created its own
+ # Future and waited until self._buffer had at least the n
+ # bytes, then called read(n). Unfortunately, this could pause
+ # the transport if the argument was larger than the pause
+ # limit (which is twice self._limit). So now we just read()
+ # into a local buffer.
+
+ blocks = []
+ while n > 0:
+ block = yield from self.read(n)
+ if not block:
+ partial = b''.join(blocks)
+ raise IncompleteReadError(partial, len(partial) + n)
+ blocks.append(block)
+ n -= len(block)
+
+ return b''.join(blocks)
+
+ if _PY35:
+ @coroutine
+ def __aiter__(self):
+ return self
+
+ @coroutine
+ def __anext__(self):
+ val = yield from self.readline()
+ if val == b'':
+ raise StopAsyncIteration
+ return val
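+
+
+if __name__ == '__main__':  # pragma: no cover
+    # Editor's sketch, not part of the original module: a loopback echo
+    # server and client built from start_server() and open_connection().
+    @coroutine
+    def handle(reader, writer):
+        data = yield from reader.readline()
+        writer.write(data)              # echo the line back
+        yield from writer.drain()
+        writer.close()
+
+    @coroutine
+    def demo(loop):
+        server = yield from start_server(handle, '127.0.0.1', 0, loop=loop)
+        port = server.sockets[0].getsockname()[1]
+        reader, writer = yield from open_connection('127.0.0.1', port,
+                                                    loop=loop)
+        writer.write(b'ping\n')
+        line = yield from reader.readline()
+        writer.close()
+        server.close()
+        return line
+
+    loop = events.get_event_loop()
+    print(loop.run_until_complete(demo(loop)))  # b'ping\n'
+    loop.close()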
diff --git a/trollius/subprocess.py b/trollius/subprocess.py
new file mode 100644
index 0000000..4600a9f
--- /dev/null
+++ b/trollius/subprocess.py
@@ -0,0 +1,215 @@
+__all__ = ['create_subprocess_exec', 'create_subprocess_shell']
+
+import collections
+import subprocess
+
+from . import events
+from . import futures
+from . import protocols
+from . import streams
+from . import tasks
+from .coroutines import coroutine
+from .log import logger
+
+
+PIPE = subprocess.PIPE
+STDOUT = subprocess.STDOUT
+DEVNULL = subprocess.DEVNULL
+
+
+class SubprocessStreamProtocol(streams.FlowControlMixin,
+ protocols.SubprocessProtocol):
+ """Like StreamReaderProtocol, but for a subprocess."""
+
+ def __init__(self, limit, loop):
+ super().__init__(loop=loop)
+ self._limit = limit
+ self.stdin = self.stdout = self.stderr = None
+ self._transport = None
+
+ def __repr__(self):
+ info = [self.__class__.__name__]
+ if self.stdin is not None:
+ info.append('stdin=%r' % self.stdin)
+ if self.stdout is not None:
+ info.append('stdout=%r' % self.stdout)
+ if self.stderr is not None:
+ info.append('stderr=%r' % self.stderr)
+ return '<%s>' % ' '.join(info)
+
+ def connection_made(self, transport):
+ self._transport = transport
+
+ stdout_transport = transport.get_pipe_transport(1)
+ if stdout_transport is not None:
+ self.stdout = streams.StreamReader(limit=self._limit,
+ loop=self._loop)
+ self.stdout.set_transport(stdout_transport)
+
+ stderr_transport = transport.get_pipe_transport(2)
+ if stderr_transport is not None:
+ self.stderr = streams.StreamReader(limit=self._limit,
+ loop=self._loop)
+ self.stderr.set_transport(stderr_transport)
+
+ stdin_transport = transport.get_pipe_transport(0)
+ if stdin_transport is not None:
+ self.stdin = streams.StreamWriter(stdin_transport,
+ protocol=self,
+ reader=None,
+ loop=self._loop)
+
+ def pipe_data_received(self, fd, data):
+ if fd == 1:
+ reader = self.stdout
+ elif fd == 2:
+ reader = self.stderr
+ else:
+ reader = None
+ if reader is not None:
+ reader.feed_data(data)
+
+ def pipe_connection_lost(self, fd, exc):
+ if fd == 0:
+ pipe = self.stdin
+ if pipe is not None:
+ pipe.close()
+ self.connection_lost(exc)
+ return
+ if fd == 1:
+ reader = self.stdout
+ elif fd == 2:
+ reader = self.stderr
+ else:
+ reader = None
+        if reader is not None:
+ if exc is None:
+ reader.feed_eof()
+ else:
+ reader.set_exception(exc)
+
+ def process_exited(self):
+ self._transport.close()
+ self._transport = None
+
+
+class Process:
+ def __init__(self, transport, protocol, loop):
+ self._transport = transport
+ self._protocol = protocol
+ self._loop = loop
+ self.stdin = protocol.stdin
+ self.stdout = protocol.stdout
+ self.stderr = protocol.stderr
+ self.pid = transport.get_pid()
+
+ def __repr__(self):
+ return '<%s %s>' % (self.__class__.__name__, self.pid)
+
+ @property
+ def returncode(self):
+ return self._transport.get_returncode()
+
+ @coroutine
+ def wait(self):
+ """Wait until the process exit and return the process return code.
+
+ This method is a coroutine."""
+ return (yield from self._transport._wait())
+
+ def send_signal(self, signal):
+ self._transport.send_signal(signal)
+
+ def terminate(self):
+ self._transport.terminate()
+
+ def kill(self):
+ self._transport.kill()
+
+ @coroutine
+ def _feed_stdin(self, input):
+ debug = self._loop.get_debug()
+ self.stdin.write(input)
+ if debug:
+ logger.debug('%r communicate: feed stdin (%s bytes)',
+ self, len(input))
+ try:
+ yield from self.stdin.drain()
+ except (BrokenPipeError, ConnectionResetError) as exc:
+ # communicate() ignores BrokenPipeError and ConnectionResetError
+ if debug:
+ logger.debug('%r communicate: stdin got %r', self, exc)
+
+ if debug:
+ logger.debug('%r communicate: close stdin', self)
+ self.stdin.close()
+
+ @coroutine
+ def _noop(self):
+ return None
+
+ @coroutine
+ def _read_stream(self, fd):
+ transport = self._transport.get_pipe_transport(fd)
+ if fd == 2:
+ stream = self.stderr
+ else:
+ assert fd == 1
+ stream = self.stdout
+ if self._loop.get_debug():
+ name = 'stdout' if fd == 1 else 'stderr'
+ logger.debug('%r communicate: read %s', self, name)
+ output = yield from stream.read()
+ if self._loop.get_debug():
+ name = 'stdout' if fd == 1 else 'stderr'
+ logger.debug('%r communicate: close %s', self, name)
+ transport.close()
+ return output
+
+ @coroutine
+ def communicate(self, input=None):
+ if input:
+ stdin = self._feed_stdin(input)
+ else:
+ stdin = self._noop()
+ if self.stdout is not None:
+ stdout = self._read_stream(1)
+ else:
+ stdout = self._noop()
+ if self.stderr is not None:
+ stderr = self._read_stream(2)
+ else:
+ stderr = self._noop()
+ stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr,
+ loop=self._loop)
+ yield from self.wait()
+ return (stdout, stderr)
+
+
+@coroutine
+def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
+ loop=None, limit=streams._DEFAULT_LIMIT, **kwds):
+ if loop is None:
+ loop = events.get_event_loop()
+ protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
+ loop=loop)
+ transport, protocol = yield from loop.subprocess_shell(
+ protocol_factory,
+ cmd, stdin=stdin, stdout=stdout,
+ stderr=stderr, **kwds)
+ return Process(transport, protocol, loop)
+
+@coroutine
+def create_subprocess_exec(program, *args, stdin=None, stdout=None,
+ stderr=None, loop=None,
+ limit=streams._DEFAULT_LIMIT, **kwds):
+ if loop is None:
+ loop = events.get_event_loop()
+ protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
+ loop=loop)
+ transport, protocol = yield from loop.subprocess_exec(
+ protocol_factory,
+ program, *args,
+ stdin=stdin, stdout=stdout,
+ stderr=stderr, **kwds)
+ return Process(transport, protocol, loop)
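+
+
+if __name__ == '__main__':  # pragma: no cover
+    # Editor's sketch, not part of the original module: run a command and
+    # collect its output. Assumes an `echo` binary on PATH (POSIX).
+    @coroutine
+    def demo():
+        proc = yield from create_subprocess_exec('echo', 'hello',
+                                                 stdout=PIPE)
+        stdout, stderr = yield from proc.communicate()
+        return stdout, proc.returncode
+
+    loop = events.get_event_loop()
+    print(loop.run_until_complete(demo()))  # (b'hello\n', 0)
+    loop.close()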
diff --git a/trollius/tasks.py b/trollius/tasks.py
new file mode 100644
index 0000000..d8193ba
--- /dev/null
+++ b/trollius/tasks.py
@@ -0,0 +1,681 @@
+"""Support for tasks, coroutines and the scheduler."""
+
+__all__ = ['Task',
+ 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED',
+ 'wait', 'wait_for', 'as_completed', 'sleep', 'async',
+ 'gather', 'shield', 'ensure_future',
+ ]
+
+import concurrent.futures
+import functools
+import inspect
+import linecache
+import sys
+import types
+import traceback
+import warnings
+import weakref
+
+from . import coroutines
+from . import events
+from . import futures
+from .coroutines import coroutine
+
+_PY34 = (sys.version_info >= (3, 4))
+
+
+class Task(futures.Future):
+ """A coroutine wrapped in a Future."""
+
+    # An important invariant maintained while a Task is not done:
+ #
+ # - Either _fut_waiter is None, and _step() is scheduled;
+ # - or _fut_waiter is some Future, and _step() is *not* scheduled.
+ #
+ # The only transition from the latter to the former is through
+ # _wakeup(). When _fut_waiter is not None, one of its callbacks
+ # must be _wakeup().
+
+ # Weak set containing all tasks alive.
+ _all_tasks = weakref.WeakSet()
+
+ # Dictionary containing tasks that are currently active in
+ # all running event loops. {EventLoop: Task}
+ _current_tasks = {}
+
+    # If False, don't log a message if the task is destroyed while its
+    # status is still pending.
+ _log_destroy_pending = True
+
+ @classmethod
+ def current_task(cls, loop=None):
+ """Return the currently running task in an event loop or None.
+
+ By default the current task for the current event loop is returned.
+
+ None is returned when called not in the context of a Task.
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+ return cls._current_tasks.get(loop)
+
+ @classmethod
+ def all_tasks(cls, loop=None):
+ """Return a set of all tasks for an event loop.
+
+ By default all tasks for the current event loop are returned.
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+ return {t for t in cls._all_tasks if t._loop is loop}
+
+ def __init__(self, coro, *, loop=None):
+ assert coroutines.iscoroutine(coro), repr(coro)
+ super().__init__(loop=loop)
+ if self._source_traceback:
+ del self._source_traceback[-1]
+ self._coro = coro
+ self._fut_waiter = None
+ self._must_cancel = False
+ self._loop.call_soon(self._step)
+ self.__class__._all_tasks.add(self)
+
+ # On Python 3.3 or older, objects with a destructor that are part of a
+ # reference cycle are never destroyed. That's not the case any more on
+ # Python 3.4 thanks to the PEP 442.
+ if _PY34:
+ def __del__(self):
+ if self._state == futures._PENDING and self._log_destroy_pending:
+ context = {
+ 'task': self,
+ 'message': 'Task was destroyed but it is pending!',
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ futures.Future.__del__(self)
+
+ def _repr_info(self):
+ info = super()._repr_info()
+
+ if self._must_cancel:
+ # replace status
+ info[0] = 'cancelling'
+
+ coro = coroutines._format_coroutine(self._coro)
+ info.insert(1, 'coro=<%s>' % coro)
+
+ if self._fut_waiter is not None:
+ info.insert(2, 'wait_for=%r' % self._fut_waiter)
+ return info
+
+ def get_stack(self, *, limit=None):
+ """Return the list of stack frames for this task's coroutine.
+
+ If the coroutine is not done, this returns the stack where it is
+ suspended. If the coroutine has completed successfully or was
+ cancelled, this returns an empty list. If the coroutine was
+ terminated by an exception, this returns the list of traceback
+ frames.
+
+ The frames are always ordered from oldest to newest.
+
+ The optional limit gives the maximum number of frames to
+ return; by default all available frames are returned. Its
+ meaning differs depending on whether a stack or a traceback is
+ returned: the newest frames of a stack are returned, but the
+ oldest frames of a traceback are returned. (This matches the
+ behavior of the traceback module.)
+
+ For reasons beyond our control, only one stack frame is
+ returned for a suspended coroutine.
+ """
+ frames = []
+ f = self._coro.gi_frame
+ if f is not None:
+ while f is not None:
+ if limit is not None:
+ if limit <= 0:
+ break
+ limit -= 1
+ frames.append(f)
+ f = f.f_back
+ frames.reverse()
+ elif self._exception is not None:
+ tb = self._exception.__traceback__
+ while tb is not None:
+ if limit is not None:
+ if limit <= 0:
+ break
+ limit -= 1
+ frames.append(tb.tb_frame)
+ tb = tb.tb_next
+ return frames
+
+ def print_stack(self, *, limit=None, file=None):
+ """Print the stack or traceback for this task's coroutine.
+
+ This produces output similar to that of the traceback module,
+ for the frames retrieved by get_stack(). The limit argument
+ is passed to get_stack(). The file argument is an I/O stream
+ to which the output is written; by default output is written
+ to sys.stderr.
+ """
+ extracted_list = []
+ checked = set()
+ for f in self.get_stack(limit=limit):
+ lineno = f.f_lineno
+ co = f.f_code
+ filename = co.co_filename
+ name = co.co_name
+ if filename not in checked:
+ checked.add(filename)
+ linecache.checkcache(filename)
+ line = linecache.getline(filename, lineno, f.f_globals)
+ extracted_list.append((filename, lineno, name, line))
+ exc = self._exception
+ if not extracted_list:
+ print('No stack for %r' % self, file=file)
+ elif exc is not None:
+ print('Traceback for %r (most recent call last):' % self,
+ file=file)
+ else:
+ print('Stack for %r (most recent call last):' % self,
+ file=file)
+ traceback.print_list(extracted_list, file=file)
+ if exc is not None:
+ for line in traceback.format_exception_only(exc.__class__, exc):
+ print(line, file=file, end='')
+
+ def cancel(self):
+ """Request that this task cancel itself.
+
+ This arranges for a CancelledError to be thrown into the
+ wrapped coroutine on the next cycle through the event loop.
+ The coroutine then has a chance to clean up or even deny
+ the request using try/except/finally.
+
+ Unlike Future.cancel, this does not guarantee that the
+ task will be cancelled: the exception might be caught and
+ acted upon, delaying cancellation of the task or preventing
+ cancellation completely. The task may also return a value or
+ raise a different exception.
+
+ Immediately after this method is called, Task.cancelled() will
+ not return True (unless the task was already cancelled). A
+ task will be marked as cancelled when the wrapped coroutine
+ terminates with a CancelledError exception (even if cancel()
+ was not called).
+ """
+ if self.done():
+ return False
+ if self._fut_waiter is not None:
+ if self._fut_waiter.cancel():
+ # Leave self._fut_waiter; it may be a Task that
+ # catches and ignores the cancellation so we may have
+ # to cancel it again later.
+ return True
+ # It must be the case that self._step is already scheduled.
+ self._must_cancel = True
+ return True
+
+ def _step(self, value=None, exc=None):
+ assert not self.done(), \
+ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc)
+ if self._must_cancel:
+ if not isinstance(exc, futures.CancelledError):
+ exc = futures.CancelledError()
+ self._must_cancel = False
+ coro = self._coro
+ self._fut_waiter = None
+
+ self.__class__._current_tasks[self._loop] = self
+ # Call either coro.throw(exc) or coro.send(value).
+ try:
+ if exc is not None:
+ result = coro.throw(exc)
+ else:
+ result = coro.send(value)
+ except StopIteration as exc:
+ self.set_result(exc.value)
+ except futures.CancelledError as exc:
+ super().cancel() # I.e., Future.cancel(self).
+ except Exception as exc:
+ self.set_exception(exc)
+ except BaseException as exc:
+ self.set_exception(exc)
+ raise
+ else:
+ if isinstance(result, futures.Future):
+ # Yielded Future must come from Future.__iter__().
+ if result._blocking:
+ result._blocking = False
+ result.add_done_callback(self._wakeup)
+ self._fut_waiter = result
+ if self._must_cancel:
+ if self._fut_waiter.cancel():
+ self._must_cancel = False
+ else:
+ self._loop.call_soon(
+ self._step, None,
+ RuntimeError(
+ 'yield was used instead of yield from '
+ 'in task {!r} with {!r}'.format(self, result)))
+ elif result is None:
+ # Bare yield relinquishes control for one event loop iteration.
+ self._loop.call_soon(self._step)
+ elif inspect.isgenerator(result):
+ # Yielding a generator is just wrong.
+ self._loop.call_soon(
+ self._step, None,
+ RuntimeError(
+ 'yield was used instead of yield from for '
+ 'generator in task {!r} with {}'.format(
+ self, result)))
+ else:
+ # Yielding something else is an error.
+ self._loop.call_soon(
+ self._step, None,
+ RuntimeError(
+ 'Task got bad yield: {!r}'.format(result)))
+ finally:
+ self.__class__._current_tasks.pop(self._loop)
+ self = None # Needed to break cycles when an exception occurs.
+
+ def _wakeup(self, future):
+ try:
+ value = future.result()
+ except Exception as exc:
+ # This may also be a cancellation.
+ self._step(None, exc)
+ else:
+ self._step(value, None)
+ self = None # Needed to break cycles when an exception occurs.
+
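+
+# Editor's note (sketch): a Task is normally created through the event
+# loop; `my_coro` below is a hypothetical coroutine function:
+#
+#     task = loop.create_task(my_coro())   # preferred spelling
+#     task = Task(my_coro(), loop=loop)    # lower-level equivalent
+#     task.cancel()                        # request (not force) cancellation
+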
+
+# wait() and as_completed() similar to those in PEP 3148.
+
+FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED
+FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION
+ALL_COMPLETED = concurrent.futures.ALL_COMPLETED
+
+
+@coroutine
+def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED):
+ """Wait for the Futures and coroutines given by fs to complete.
+
+ The sequence futures must not be empty.
+
+ Coroutines will be wrapped in Tasks.
+
+ Returns two sets of Future: (done, pending).
+
+ Usage:
+
+ done, pending = yield from asyncio.wait(fs)
+
+ Note: This does not raise TimeoutError! Futures that aren't done
+ when the timeout occurs are returned in the second set.
+ """
+ if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs):
+ raise TypeError("expect a list of futures, not %s" % type(fs).__name__)
+ if not fs:
+ raise ValueError('Set of coroutines/Futures is empty.')
+ if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED):
+ raise ValueError('Invalid return_when value: {}'.format(return_when))
+
+ if loop is None:
+ loop = events.get_event_loop()
+
+ fs = {ensure_future(f, loop=loop) for f in set(fs)}
+
+ return (yield from _wait(fs, timeout, return_when, loop))
+
+
+def _release_waiter(waiter, *args):
+ if not waiter.done():
+ waiter.set_result(None)
+
+
+@coroutine
+def wait_for(fut, timeout, *, loop=None):
+ """Wait for the single Future or coroutine to complete, with timeout.
+
+ Coroutine will be wrapped in Task.
+
+ Returns result of the Future or coroutine. When a timeout occurs,
+ it cancels the task and raises TimeoutError. To avoid the task
+ cancellation, wrap it in shield().
+
+ If the wait is cancelled, the task is also cancelled.
+
+ This function is a coroutine.
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+
+ if timeout is None:
+ return (yield from fut)
+
+ waiter = futures.Future(loop=loop)
+ timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
+ cb = functools.partial(_release_waiter, waiter)
+
+ fut = ensure_future(fut, loop=loop)
+ fut.add_done_callback(cb)
+
+ try:
+ # wait until the future completes or the timeout
+ try:
+ yield from waiter
+ except futures.CancelledError:
+ fut.remove_done_callback(cb)
+ fut.cancel()
+ raise
+
+ if fut.done():
+ return fut.result()
+ else:
+ fut.remove_done_callback(cb)
+ fut.cancel()
+ raise futures.TimeoutError()
+ finally:
+ timeout_handle.cancel()
+
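+
+# Editor's note (sketch): typical wait_for() use; `op()` is a hypothetical
+# coroutine function:
+#
+#     try:
+#         result = yield from wait_for(op(), 1.0, loop=loop)
+#     except futures.TimeoutError:
+#         pass  # op() was cancelled when the timeout fired
+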
+
+@coroutine
+def _wait(fs, timeout, return_when, loop):
+ """Internal helper for wait() and _wait_for().
+
+ The fs argument must be a collection of Futures.
+ """
+ assert fs, 'Set of Futures is empty.'
+ waiter = futures.Future(loop=loop)
+ timeout_handle = None
+ if timeout is not None:
+ timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
+ counter = len(fs)
+
+ def _on_completion(f):
+ nonlocal counter
+ counter -= 1
+ if (counter <= 0 or
+ return_when == FIRST_COMPLETED or
+ return_when == FIRST_EXCEPTION and (not f.cancelled() and
+ f.exception() is not None)):
+ if timeout_handle is not None:
+ timeout_handle.cancel()
+ if not waiter.done():
+ waiter.set_result(None)
+
+ for f in fs:
+ f.add_done_callback(_on_completion)
+
+ try:
+ yield from waiter
+ finally:
+ if timeout_handle is not None:
+ timeout_handle.cancel()
+
+ done, pending = set(), set()
+ for f in fs:
+ f.remove_done_callback(_on_completion)
+ if f.done():
+ done.add(f)
+ else:
+ pending.add(f)
+ return done, pending
+
+
+# This is *not* a @coroutine! It is just an iterator (yielding Futures).
+def as_completed(fs, *, loop=None, timeout=None):
+ """Return an iterator whose values are coroutines.
+
+ When waiting for the yielded coroutines you'll get the results (or
+ exceptions!) of the original Futures (or coroutines), in the order
+ in which and as soon as they complete.
+
+ This differs from PEP 3148; the proper way to use this is:
+
+ for f in as_completed(fs):
+ result = yield from f # The 'yield from' may raise.
+ # Use result.
+
+ If a timeout is specified, the 'yield from' will raise
+ TimeoutError when the timeout occurs before all Futures are done.
+
+ Note: The futures 'f' are not necessarily members of fs.
+ """
+ if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs):
+ raise TypeError("expect a list of futures, not %s" % type(fs).__name__)
+ loop = loop if loop is not None else events.get_event_loop()
+ todo = {ensure_future(f, loop=loop) for f in set(fs)}
+ from .queues import Queue # Import here to avoid circular import problem.
+ done = Queue(loop=loop)
+ timeout_handle = None
+
+ def _on_timeout():
+ for f in todo:
+ f.remove_done_callback(_on_completion)
+ done.put_nowait(None) # Queue a dummy value for _wait_for_one().
+ todo.clear() # Can't do todo.remove(f) in the loop.
+
+ def _on_completion(f):
+ if not todo:
+ return # _on_timeout() was here first.
+ todo.remove(f)
+ done.put_nowait(f)
+ if not todo and timeout_handle is not None:
+ timeout_handle.cancel()
+
+ @coroutine
+ def _wait_for_one():
+ f = yield from done.get()
+ if f is None:
+ # Dummy value from _on_timeout().
+ raise futures.TimeoutError
+ return f.result() # May raise f.exception().
+
+ for f in todo:
+ f.add_done_callback(_on_completion)
+ if todo and timeout is not None:
+ timeout_handle = loop.call_later(timeout, _on_timeout)
+ for _ in range(len(todo)):
+ yield _wait_for_one()
+
+
+@coroutine
+def sleep(delay, result=None, *, loop=None):
+ """Coroutine that completes after a given time (in seconds)."""
+ future = futures.Future(loop=loop)
+ h = future._loop.call_later(delay,
+ future._set_result_unless_cancelled, result)
+ try:
+ return (yield from future)
+ finally:
+ h.cancel()
+
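+# Illustrative usage sketch (not part of the original module):
+#
+#     @coroutine
+#     def delayed_pong():
+#         result = yield from sleep(1.0, result='pong')
+#         return result              # 'pong', after ~1 second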
+
+def async(coro_or_future, *, loop=None):
+ """Wrap a coroutine in a future.
+
+ If the argument is a Future, it is returned directly.
+
+ This function is deprecated in 3.5. Use asyncio.ensure_future() instead.
+ """
+
+ warnings.warn("asyncio.async() function is deprecated, use ensure_future()",
+ DeprecationWarning)
+
+ return ensure_future(coro_or_future, loop=loop)
+
+
+def ensure_future(coro_or_future, *, loop=None):
+ """Wrap a coroutine in a future.
+
+ If the argument is a Future, it is returned directly.
+ """
+ if isinstance(coro_or_future, futures.Future):
+ if loop is not None and loop is not coro_or_future._loop:
+ raise ValueError('loop argument must agree with Future')
+ return coro_or_future
+ elif coroutines.iscoroutine(coro_or_future):
+ if loop is None:
+ loop = events.get_event_loop()
+ task = loop.create_task(coro_or_future)
+ if task._source_traceback:
+ del task._source_traceback[-1]
+ return task
+ else:
+ raise TypeError('A Future or coroutine is required')
+
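+# Illustrative usage sketch (not part of the original module; "work" is
+# a hypothetical coroutine function):
+#
+#     task = ensure_future(work())    # schedules work() as a Task
+#     ...                             # do something else meanwhile
+#     result = yield from task        # wait for the result later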
+
+class _GatheringFuture(futures.Future):
+ """Helper for gather().
+
+ This overrides cancel() to cancel all the children and act more
+ like Task.cancel(), which doesn't immediately mark itself as
+ cancelled.
+ """
+
+ def __init__(self, children, *, loop=None):
+ super().__init__(loop=loop)
+ self._children = children
+
+ def cancel(self):
+ if self.done():
+ return False
+ for child in self._children:
+ child.cancel()
+ return True
+
+
+def gather(*coros_or_futures, loop=None, return_exceptions=False):
+ """Return a future aggregating results from the given coroutines
+ or futures.
+
+ All futures must share the same event loop. If all the tasks are
+ done successfully, the returned future's result is the list of
+ results (in the order of the original sequence, not necessarily
+ the order of results arrival). If *return_exceptions* is True,
+ exceptions in the tasks are treated the same as successful
+ results, and gathered in the result list; otherwise, the first
+ raised exception will be immediately propagated to the returned
+ future.
+
+ Cancellation: if the outer Future is cancelled, all children (that
+ have not completed yet) are also cancelled. If any child is
+ cancelled, this is treated as if it raised CancelledError --
+ the outer Future is *not* cancelled in this case. (This is to
+    prevent the cancellation of one child from causing other children
+    to be cancelled.)
+ """
+ if not coros_or_futures:
+ outer = futures.Future(loop=loop)
+ outer.set_result([])
+ return outer
+
+ arg_to_fut = {}
+ for arg in set(coros_or_futures):
+ if not isinstance(arg, futures.Future):
+ fut = ensure_future(arg, loop=loop)
+ if loop is None:
+ loop = fut._loop
+ # The caller cannot control this future, the "destroy pending task"
+ # warning should not be emitted.
+ fut._log_destroy_pending = False
+ else:
+ fut = arg
+ if loop is None:
+ loop = fut._loop
+ elif fut._loop is not loop:
+ raise ValueError("futures are tied to different event loops")
+ arg_to_fut[arg] = fut
+
+ children = [arg_to_fut[arg] for arg in coros_or_futures]
+ nchildren = len(children)
+ outer = _GatheringFuture(children, loop=loop)
+ nfinished = 0
+ results = [None] * nchildren
+
+ def _done_callback(i, fut):
+ nonlocal nfinished
+ if outer.done():
+ if not fut.cancelled():
+ # Mark exception retrieved.
+ fut.exception()
+ return
+
+ if fut.cancelled():
+ res = futures.CancelledError()
+ if not return_exceptions:
+ outer.set_exception(res)
+ return
+ elif fut._exception is not None:
+ res = fut.exception() # Mark exception retrieved.
+ if not return_exceptions:
+ outer.set_exception(res)
+ return
+ else:
+ res = fut._result
+ results[i] = res
+ nfinished += 1
+ if nfinished == nchildren:
+ outer.set_result(results)
+
+ for i, fut in enumerate(children):
+ fut.add_done_callback(functools.partial(_done_callback, i))
+ return outer
+
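+# Illustrative usage sketch (not part of the original module; "f1" and
+# "f2" are hypothetical coroutine functions): with return_exceptions=True,
+# failures are delivered as values in the result list instead of raised.
+#
+#     results = yield from gather(f1(), f2(), return_exceptions=True)
+#     for res in results:
+#         if isinstance(res, Exception):
+#             ...                     # handle the failure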
+
+def shield(arg, *, loop=None):
+ """Wait for a future, shielding it from cancellation.
+
+ The statement
+
+ res = yield from shield(something())
+
+ is exactly equivalent to the statement
+
+ res = yield from something()
+
+ *except* that if the coroutine containing it is cancelled, the
+ task running in something() is not cancelled. From the POV of
+ something(), the cancellation did not happen. But its caller is
+ still cancelled, so the yield-from expression still raises
+ CancelledError. Note: If something() is cancelled by other means
+ this will still cancel shield().
+
+ If you want to completely ignore cancellation (not recommended)
+ you can combine shield() with a try/except clause, as follows:
+
+ try:
+ res = yield from shield(something())
+ except CancelledError:
+ res = None
+ """
+ inner = ensure_future(arg, loop=loop)
+ if inner.done():
+ # Shortcut.
+ return inner
+ loop = inner._loop
+ outer = futures.Future(loop=loop)
+
+ def _done_callback(inner):
+ if outer.cancelled():
+ if not inner.cancelled():
+ # Mark inner's result as retrieved.
+ inner.exception()
+ return
+
+ if inner.cancelled():
+ outer.cancel()
+ else:
+ exc = inner.exception()
+ if exc is not None:
+ outer.set_exception(exc)
+ else:
+ outer.set_result(inner.result())
+
+ inner.add_done_callback(_done_callback)
+ return outer
diff --git a/trollius/test_support.py b/trollius/test_support.py
new file mode 100644
index 0000000..0fadfad
--- /dev/null
+++ b/trollius/test_support.py
@@ -0,0 +1,308 @@
+# Subset of test.support from CPython 3.5, just what we need to run the
+# asyncio test suite. The code is copied from CPython 3.5 so as not to
+# depend on the test module, which is rarely installed.
+
+# Ignore symbol TEST_HOME_DIR: test_events works without it
+
+import functools
+import gc
+import os
+import platform
+import re
+import socket
+import subprocess
+import sys
+import time
+import unittest
+
+
+# Fallback for test.support.TestFailed, which bind_port() below raises;
+# the real test.support, when available, is star-imported at the end of
+# this module.
+class TestFailed(Exception):
+    pass
+
+
+# A constant likely larger than the underlying OS pipe buffer size, to
+# make writes blocking.
+# Windows limit seems to be around 512 B, and many Unix kernels have a
+# 64 KiB pipe buffer size or 16 * PAGE_SIZE: take a few megs to be sure.
+# (see issue #17835 for a discussion of this number).
+PIPE_MAX_SIZE = 4 * 1024 * 1024 + 1
+
+def strip_python_stderr(stderr):
+ """Strip the stderr of a Python process from potential debug output
+ emitted by the interpreter.
+
+ This will typically be run on the result of the communicate() method
+ of a subprocess.Popen object.
+ """
+ stderr = re.sub(br"\[\d+ refs, \d+ blocks\]\r?\n?", b"", stderr).strip()
+ return stderr
+
+
+# Executing the interpreter in a subprocess
+def _assert_python(expected_success, *args, **env_vars):
+ if '__isolated' in env_vars:
+ isolated = env_vars.pop('__isolated')
+ else:
+ isolated = not env_vars
+ cmd_line = [sys.executable, '-X', 'faulthandler']
+ if isolated and sys.version_info >= (3, 4):
+ # isolated mode: ignore Python environment variables, ignore user
+ # site-packages, and don't add the current directory to sys.path
+ cmd_line.append('-I')
+ elif not env_vars:
+ # ignore Python environment variables
+ cmd_line.append('-E')
+ # Need to preserve the original environment, for in-place testing of
+ # shared library builds.
+ env = os.environ.copy()
+    # But a special flag can be set to override this -- in that case, the
+    # caller is responsible for passing the full environment.
+ if env_vars.pop('__cleanenv', None):
+ env = {}
+ env.update(env_vars)
+ cmd_line.extend(args)
+ p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ env=env)
+ try:
+ out, err = p.communicate()
+ finally:
+ subprocess._cleanup()
+ p.stdout.close()
+ p.stderr.close()
+ rc = p.returncode
+ err = strip_python_stderr(err)
+ if (rc and expected_success) or (not rc and not expected_success):
+ raise AssertionError(
+ "Process return code is %d, "
+ "stderr follows:\n%s" % (rc, err.decode('ascii', 'ignore')))
+ return rc, out, err
+
+
+def assert_python_ok(*args, **env_vars):
+ """
+ Assert that running the interpreter with `args` and optional environment
+ variables `env_vars` succeeds (rc == 0) and return a (return code, stdout,
+ stderr) tuple.
+
+    If the __cleanenv keyword is set, env_vars is used as a fresh environment.
+
+ Python is started in isolated mode (command line option -I),
+ except if the __isolated keyword is set to False.
+ """
+ return _assert_python(True, *args, **env_vars)
+
+
+is_jython = sys.platform.startswith('java')
+
+def gc_collect():
+ """Force as many objects as possible to be collected.
+
+ In non-CPython implementations of Python, this is needed because timely
+ deallocation is not guaranteed by the garbage collector. (Even in CPython
+ this can be the case in case of reference cycles.) This means that __del__
+ methods may be called later than expected and weakrefs may remain alive for
+ longer than expected. This function tries its best to force all garbage
+ objects to disappear.
+ """
+ gc.collect()
+ if is_jython:
+ time.sleep(0.1)
+ gc.collect()
+ gc.collect()
+
+
+HOST = "127.0.0.1"
+HOSTv6 = "::1"
+
+
+def _is_ipv6_enabled():
+ """Check whether IPv6 is enabled on this host."""
+ if socket.has_ipv6:
+ sock = None
+ try:
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ sock.bind((HOSTv6, 0))
+ return True
+ except OSError:
+ pass
+ finally:
+ if sock:
+ sock.close()
+ return False
+
+IPV6_ENABLED = _is_ipv6_enabled()
+
+
+def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
+ """Returns an unused port that should be suitable for binding. This is
+ achieved by creating a temporary socket with the same family and type as
+ the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to
+ the specified host address (defaults to 0.0.0.0) with the port set to 0,
+ eliciting an unused ephemeral port from the OS. The temporary socket is
+ then closed and deleted, and the ephemeral port is returned.
+
+ Either this method or bind_port() should be used for any tests where a
+ server socket needs to be bound to a particular port for the duration of
+ the test. Which one to use depends on whether the calling code is creating
+ a python socket, or if an unused port needs to be provided in a constructor
+ or passed to an external program (i.e. the -accept argument to openssl's
+ s_server mode). Always prefer bind_port() over find_unused_port() where
+ possible. Hard coded ports should *NEVER* be used. As soon as a server
+ socket is bound to a hard coded port, the ability to run multiple instances
+ of the test simultaneously on the same host is compromised, which makes the
+ test a ticking time bomb in a buildbot environment. On Unix buildbots, this
+ may simply manifest as a failed test, which can be recovered from without
+ intervention in most cases, but on Windows, the entire python process can
+ completely and utterly wedge, requiring someone to log in to the buildbot
+ and manually kill the affected process.
+
+ (This is easy to reproduce on Windows, unfortunately, and can be traced to
+ the SO_REUSEADDR socket option having different semantics on Windows versus
+ Unix/Linux. On Unix, you can't have two AF_INET SOCK_STREAM sockets bind,
+ listen and then accept connections on identical host/ports. An EADDRINUSE
+ OSError will be raised at some point (depending on the platform and
+ the order bind and listen were called on each socket).
+
+ However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE
+ will ever be raised when attempting to bind two identical host/ports. When
+ accept() is called on each socket, the second caller's process will steal
+ the port from the first caller, leaving them both in an awkwardly wedged
+ state where they'll no longer respond to any signals or graceful kills, and
+ must be forcibly killed via OpenProcess()/TerminateProcess().
+
+ The solution on Windows is to use the SO_EXCLUSIVEADDRUSE socket option
+ instead of SO_REUSEADDR, which effectively affords the same semantics as
+    SO_REUSEADDR on Unix. Given the preponderance of Unix developers in the
+    open-source world relative to Windows developers, this is a common
+    mistake. A quick
+ look over OpenSSL's 0.9.8g source shows that they use SO_REUSEADDR when
+ openssl.exe is called with the 's_server' option, for example. See
+ http://bugs.python.org/issue2550 for more info. The following site also
+ has a very thorough description about the implications of both REUSEADDR
+ and EXCLUSIVEADDRUSE on Windows:
+ http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx)
+
+ XXX: although this approach is a vast improvement on previous attempts to
+ elicit unused ports, it rests heavily on the assumption that the ephemeral
+ port returned to us by the OS won't immediately be dished back out to some
+ other process when we close and delete our temporary socket but before our
+ calling code has a chance to bind the returned port. We can deal with this
+ issue if/when we come across it.
+ """
+
+ tempsock = socket.socket(family, socktype)
+ port = bind_port(tempsock)
+ tempsock.close()
+ del tempsock
+ return port
+
+def bind_port(sock, host=HOST):
+ """Bind the socket to a free port and return the port number. Relies on
+ ephemeral ports in order to ensure we are using an unbound port. This is
+ important as many tests may be running simultaneously, especially in a
+ buildbot environment. This method raises an exception if the sock.family
+ is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR
+ or SO_REUSEPORT set on it. Tests should *never* set these socket options
+ for TCP/IP sockets. The only case for setting these options is testing
+ multicasting via multiple UDP sockets.
+
+ Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e.
+ on Windows), it will be set on the socket. This will prevent anyone else
+ from bind()'ing to our host/port for the duration of the test.
+ """
+
+ if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM:
+ if hasattr(socket, 'SO_REUSEADDR'):
+ if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1:
+ raise TestFailed("tests should never set the SO_REUSEADDR "
+ "socket option on TCP/IP sockets!")
+ if hasattr(socket, 'SO_REUSEPORT'):
+ try:
+ reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT)
+ if reuse == 1:
+ raise TestFailed("tests should never set the SO_REUSEPORT "
+ "socket option on TCP/IP sockets!")
+ except OSError:
+ # Python's socket module was compiled using modern headers
+ # thus defining SO_REUSEPORT but this process is running
+ # under an older kernel that does not support SO_REUSEPORT.
+ pass
+ if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'):
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
+
+ sock.bind((host, 0))
+ port = sock.getsockname()[1]
+ return port
+
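+# Illustrative usage sketch (not part of the original module): a test
+# that needs a listening server socket on a free port would do:
+#
+#     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+#     port = bind_port(sock)          # now bound to ('127.0.0.1', port)
+#     sock.listen(1)
+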
+def requires_mac_ver(*min_version):
+ """Decorator raising SkipTest if the OS is Mac OS X and the OS X
+ version if less than min_version.
+
+ For example, @requires_mac_ver(10, 5) raises SkipTest if the OS X version
+ is lesser than 10.5.
+ """
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kw):
+ if sys.platform == 'darwin':
+ version_txt = platform.mac_ver()[0]
+ try:
+ version = tuple(map(int, version_txt.split('.')))
+ except ValueError:
+ pass
+ else:
+ if version < min_version:
+ min_version_txt = '.'.join(map(str, min_version))
+ raise unittest.SkipTest(
+ "Mac OS X %s or higher required, not %s"
+ % (min_version_txt, version_txt))
+ return func(*args, **kw)
+ wrapper.min_version = min_version
+ return wrapper
+ return decorator
+
+def _requires_unix_version(sysname, min_version):
+ """Decorator raising SkipTest if the OS is `sysname` and the version is
+ less than `min_version`.
+
+ For example, @_requires_unix_version('FreeBSD', (7, 2)) raises SkipTest if
+ the FreeBSD version is less than 7.2.
+ """
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kw):
+ if platform.system() == sysname:
+ version_txt = platform.release().split('-', 1)[0]
+ try:
+ version = tuple(map(int, version_txt.split('.')))
+ except ValueError:
+ pass
+ else:
+ if version < min_version:
+ min_version_txt = '.'.join(map(str, min_version))
+ raise unittest.SkipTest(
+ "%s version %s or higher required, not %s"
+ % (sysname, min_version_txt, version_txt))
+ return func(*args, **kw)
+ wrapper.min_version = min_version
+ return wrapper
+ return decorator
+
+def requires_freebsd_version(*min_version):
+ """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version
+ is less than `min_version`.
+
+ For example, @requires_freebsd_version(7, 2) raises SkipTest if the FreeBSD
+ version is less than 7.2.
+ """
+ return _requires_unix_version('FreeBSD', min_version)
+
+# Use test.support if available
+try:
+ from test.support import *
+except ImportError:
+ pass
+
+# Use test.script_helper if available
+try:
+ from test.support.script_helper import assert_python_ok
+except ImportError:
+ try:
+ from test.script_helper import assert_python_ok
+ except ImportError:
+ pass
diff --git a/trollius/test_utils.py b/trollius/test_utils.py
new file mode 100644
index 0000000..8cee95b
--- /dev/null
+++ b/trollius/test_utils.py
@@ -0,0 +1,446 @@
+"""Utilities shared by tests."""
+
+import collections
+import contextlib
+import io
+import logging
+import os
+import re
+import socket
+import socketserver
+import sys
+import tempfile
+import threading
+import time
+import unittest
+from unittest import mock
+
+from http.server import HTTPServer
+from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
+
+try:
+ import ssl
+except ImportError: # pragma: no cover
+ ssl = None
+
+from . import base_events
+from . import events
+from . import futures
+from . import selectors
+from . import tasks
+from .coroutines import coroutine
+from .log import logger
+
+
+if sys.platform == 'win32': # pragma: no cover
+ from .windows_utils import socketpair
+else:
+ from socket import socketpair # pragma: no cover
+
+
+def dummy_ssl_context():
+ if ssl is None:
+ return None
+ else:
+ return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+
+
+def run_briefly(loop):
+ @coroutine
+ def once():
+ pass
+ gen = once()
+ t = loop.create_task(gen)
+ # Don't log a warning if the task is not done after run_until_complete().
+ # It occurs if the loop is stopped or if a task raises a BaseException.
+ t._log_destroy_pending = False
+ try:
+ loop.run_until_complete(t)
+ finally:
+ gen.close()
+
+
+def run_until(loop, pred, timeout=30):
+ deadline = time.time() + timeout
+ while not pred():
+ if timeout is not None:
+ timeout = deadline - time.time()
+ if timeout <= 0:
+ raise futures.TimeoutError()
+ loop.run_until_complete(tasks.sleep(0.001, loop=loop))
+
+
+def run_once(loop):
+ """loop.stop() schedules _raise_stop_error()
+ and run_forever() runs until _raise_stop_error() callback.
+ this wont work if test waits for some IO events, because
+ _raise_stop_error() runs before any of io events callbacks.
+ """
+ loop.stop()
+ loop.run_forever()
+
+
+class SilentWSGIRequestHandler(WSGIRequestHandler):
+
+ def get_stderr(self):
+ return io.StringIO()
+
+ def log_message(self, format, *args):
+ pass
+
+
+class SilentWSGIServer(WSGIServer):
+
+ request_timeout = 2
+
+ def get_request(self):
+ request, client_addr = super().get_request()
+ request.settimeout(self.request_timeout)
+ return request, client_addr
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+class SSLWSGIServerMixin:
+
+ def finish_request(self, request, client_address):
+ # The relative location of our test directory (which
+ # contains the ssl key and certificate files) differs
+ # between the stdlib and stand-alone asyncio.
+ # Prefer our own if we can find it.
+ here = os.path.join(os.path.dirname(__file__), '..', 'tests')
+ if not os.path.isdir(here):
+ here = os.path.join(os.path.dirname(os.__file__),
+ 'test', 'test_asyncio')
+ keyfile = os.path.join(here, 'ssl_key.pem')
+ certfile = os.path.join(here, 'ssl_cert.pem')
+ ssock = ssl.wrap_socket(request,
+ keyfile=keyfile,
+ certfile=certfile,
+ server_side=True)
+ try:
+ self.RequestHandlerClass(ssock, client_address, self)
+ ssock.close()
+ except OSError:
+ # maybe socket has been closed by peer
+ pass
+
+
+class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
+ pass
+
+
+def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
+
+ def app(environ, start_response):
+ status = '200 OK'
+ headers = [('Content-type', 'text/plain')]
+ start_response(status, headers)
+ return [b'Test message']
+
+ # Run the test WSGI server in a separate thread in order not to
+ # interfere with event handling in the main thread
+ server_class = server_ssl_cls if use_ssl else server_cls
+ httpd = server_class(address, SilentWSGIRequestHandler)
+ httpd.set_app(app)
+ httpd.address = httpd.server_address
+ server_thread = threading.Thread(
+ target=lambda: httpd.serve_forever(poll_interval=0.05))
+ server_thread.start()
+ try:
+ yield httpd
+ finally:
+ httpd.shutdown()
+ httpd.server_close()
+ server_thread.join()
+
+
+if hasattr(socket, 'AF_UNIX'):
+
+ class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
+
+ def server_bind(self):
+ socketserver.UnixStreamServer.server_bind(self)
+ self.server_name = '127.0.0.1'
+ self.server_port = 80
+
+
+ class UnixWSGIServer(UnixHTTPServer, WSGIServer):
+
+ request_timeout = 2
+
+ def server_bind(self):
+ UnixHTTPServer.server_bind(self)
+ self.setup_environ()
+
+ def get_request(self):
+ request, client_addr = super().get_request()
+ request.settimeout(self.request_timeout)
+ # Code in the stdlib expects that get_request
+ # will return a socket and a tuple (host, port).
+ # However, this isn't true for UNIX sockets,
+ # as the second return value will be a path;
+ # hence we return some fake data sufficient
+ # to get the tests going
+ return request, ('127.0.0.1', '')
+
+
+ class SilentUnixWSGIServer(UnixWSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+ class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
+ pass
+
+
+    def gen_unix_socket_path():
+        # Create and delete a temporary file to obtain a filesystem path
+        # that is very likely to be free for a Unix domain socket.
+        with tempfile.NamedTemporaryFile() as file:
+            return file.name
+
+
+ @contextlib.contextmanager
+ def unix_socket_path():
+ path = gen_unix_socket_path()
+ try:
+ yield path
+ finally:
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+
+
+ @contextlib.contextmanager
+ def run_test_unix_server(*, use_ssl=False):
+ with unix_socket_path() as path:
+ yield from _run_test_server(address=path, use_ssl=use_ssl,
+ server_cls=SilentUnixWSGIServer,
+ server_ssl_cls=UnixSSLWSGIServer)
+
+
+@contextlib.contextmanager
+def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+ yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
+ server_cls=SilentWSGIServer,
+ server_ssl_cls=SSLWSGIServer)
+
+
+def make_test_protocol(base):
+ dct = {}
+ for name in dir(base):
+ if name.startswith('__') and name.endswith('__'):
+ # skip magic names
+ continue
+ dct[name] = MockCallback(return_value=None)
+ return type('TestProtocol', (base,) + base.__bases__, dct)()
+
+
+class TestSelector(selectors.BaseSelector):
+
+ def __init__(self):
+ self.keys = {}
+
+ def register(self, fileobj, events, data=None):
+ key = selectors.SelectorKey(fileobj, 0, events, data)
+ self.keys[fileobj] = key
+ return key
+
+ def unregister(self, fileobj):
+ return self.keys.pop(fileobj)
+
+ def select(self, timeout):
+ return []
+
+ def get_map(self):
+ return self.keys
+
+
+class TestLoop(base_events.BaseEventLoop):
+ """Loop for unittests.
+
+ It manages self time directly.
+ If something scheduled to be executed later then
+ on next loop iteration after all ready handlers done
+ generator passed to __init__ is calling.
+
+ Generator should be like this:
+
+ def gen():
+ ...
+ when = yield ...
+ ... = yield time_advance
+
+ Value returned by yield is absolute time of next scheduled handler.
+ Value passed to yield is time advance to move loop's time forward.
+ """
+
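+    # Illustrative sketch (not part of the original module): a time
+    # generator for a test that schedules one handler at t=10 and then
+    # advances the clock past it.
+    #
+    #     def gen():
+    #         when = yield            # receives 10, the scheduled time
+    #         assert when == 10
+    #         yield 10                # advance the loop's clock by 10
+    #
+    #     loop = TestLoop(gen)
+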
+ def __init__(self, gen=None):
+ super().__init__()
+
+ if gen is None:
+ def gen():
+ yield
+ self._check_on_close = False
+ else:
+ self._check_on_close = True
+
+ self._gen = gen()
+ next(self._gen)
+ self._time = 0
+ self._clock_resolution = 1e-9
+ self._timers = []
+ self._selector = TestSelector()
+
+ self.readers = {}
+ self.writers = {}
+ self.reset_counters()
+
+ def time(self):
+ return self._time
+
+ def advance_time(self, advance):
+ """Move test time forward."""
+ if advance:
+ self._time += advance
+
+ def close(self):
+ super().close()
+ if self._check_on_close:
+ try:
+ self._gen.send(0)
+ except StopIteration:
+ pass
+ else: # pragma: no cover
+ raise AssertionError("Time generator is not finished")
+
+ def add_reader(self, fd, callback, *args):
+ self.readers[fd] = events.Handle(callback, args, self)
+
+ def remove_reader(self, fd):
+ self.remove_reader_count[fd] += 1
+ if fd in self.readers:
+ del self.readers[fd]
+ return True
+ else:
+ return False
+
+ def assert_reader(self, fd, callback, *args):
+ assert fd in self.readers, 'fd {} is not registered'.format(fd)
+ handle = self.readers[fd]
+ assert handle._callback == callback, '{!r} != {!r}'.format(
+ handle._callback, callback)
+ assert handle._args == args, '{!r} != {!r}'.format(
+ handle._args, args)
+
+ def add_writer(self, fd, callback, *args):
+ self.writers[fd] = events.Handle(callback, args, self)
+
+ def remove_writer(self, fd):
+ self.remove_writer_count[fd] += 1
+ if fd in self.writers:
+ del self.writers[fd]
+ return True
+ else:
+ return False
+
+ def assert_writer(self, fd, callback, *args):
+ assert fd in self.writers, 'fd {} is not registered'.format(fd)
+ handle = self.writers[fd]
+ assert handle._callback == callback, '{!r} != {!r}'.format(
+ handle._callback, callback)
+ assert handle._args == args, '{!r} != {!r}'.format(
+ handle._args, args)
+
+ def reset_counters(self):
+ self.remove_reader_count = collections.defaultdict(int)
+ self.remove_writer_count = collections.defaultdict(int)
+
+ def _run_once(self):
+ super()._run_once()
+ for when in self._timers:
+ advance = self._gen.send(when)
+ self.advance_time(advance)
+ self._timers = []
+
+ def call_at(self, when, callback, *args):
+ self._timers.append(when)
+ return super().call_at(when, callback, *args)
+
+ def _process_events(self, event_list):
+ return
+
+ def _write_to_self(self):
+ pass
+
+
+def MockCallback(**kwargs):
+ return mock.Mock(spec=['__call__'], **kwargs)
+
+
+class MockPattern(str):
+ """A regex based str with a fuzzy __eq__.
+
+ Use this helper with 'mock.assert_called_with', or anywhere
+ where a regex comparison between strings is needed.
+
+ For instance:
+ mock_call.assert_called_with(MockPattern('spam.*ham'))
+ """
+ def __eq__(self, other):
+ return bool(re.search(str(self), other, re.S))
+
+
+def get_function_source(func):
+ source = events._get_function_source(func)
+ if source is None:
+ raise ValueError("unable to get the source of %r" % (func,))
+ return source
+
+
+class TestCase(unittest.TestCase):
+ def set_event_loop(self, loop, *, cleanup=True):
+ assert loop is not None
+ # ensure that the event loop is passed explicitly in asyncio
+ events.set_event_loop(None)
+ if cleanup:
+ self.addCleanup(loop.close)
+
+ def new_test_loop(self, gen=None):
+ loop = TestLoop(gen)
+ self.set_event_loop(loop)
+ return loop
+
+ def tearDown(self):
+ events.set_event_loop(None)
+
+ # Detect CPython bug #23353: ensure that yield/yield-from is not used
+ # in an except block of a generator
+ self.assertEqual(sys.exc_info(), (None, None, None))
+
+
+@contextlib.contextmanager
+def disable_logger():
+ """Context manager to disable asyncio logger.
+
+ For example, it can be used to ignore warnings in debug mode.
+ """
+ old_level = logger.level
+ try:
+ logger.setLevel(logging.CRITICAL+1)
+ yield
+ finally:
+ logger.setLevel(old_level)
+
+def mock_nonblocking_socket():
+ """Create a mock of a non-blocking socket."""
+ sock = mock.Mock(socket.socket)
+ sock.gettimeout.return_value = 0.0
+ return sock
+
+
+def force_legacy_ssl_support():
+ return mock.patch('asyncio.sslproto._is_sslproto_available',
+ return_value=False)
diff --git a/trollius/transports.py b/trollius/transports.py
new file mode 100644
index 0000000..22df3c7
--- /dev/null
+++ b/trollius/transports.py
@@ -0,0 +1,300 @@
+"""Abstract Transport class."""
+
+import sys
+
+_PY34 = sys.version_info >= (3, 4)
+
+__all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport',
+ 'Transport', 'DatagramTransport', 'SubprocessTransport',
+ ]
+
+
+class BaseTransport:
+ """Base class for transports."""
+
+ def __init__(self, extra=None):
+ if extra is None:
+ extra = {}
+ self._extra = extra
+
+ def get_extra_info(self, name, default=None):
+ """Get optional transport information."""
+ return self._extra.get(name, default)
+
+ def close(self):
+ """Close the transport.
+
+ Buffered data will be flushed asynchronously. No more data
+ will be received. After all buffered data is flushed, the
+        protocol's connection_lost() method will (eventually) be called
+ with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class ReadTransport(BaseTransport):
+ """Interface for read-only transports."""
+
+ def pause_reading(self):
+ """Pause the receiving end.
+
+ No data will be passed to the protocol's data_received()
+ method until resume_reading() is called.
+ """
+ raise NotImplementedError
+
+ def resume_reading(self):
+ """Resume the receiving end.
+
+ Data received will once again be passed to the protocol's
+ data_received() method.
+ """
+ raise NotImplementedError
+
+
+class WriteTransport(BaseTransport):
+ """Interface for write-only transports."""
+
+ def set_write_buffer_limits(self, high=None, low=None):
+ """Set the high- and low-water limits for write flow control.
+
+ These two values control when to call the protocol's
+ pause_writing() and resume_writing() methods. If specified,
+ the low-water limit must be less than or equal to the
+ high-water limit. Neither value can be negative.
+
+ The defaults are implementation-specific. If only the
+        high-water limit is given, the low-water limit defaults to an
+ implementation-specific value less than or equal to the
+ high-water limit. Setting high to zero forces low to zero as
+ well, and causes pause_writing() to be called whenever the
+ buffer becomes non-empty. Setting low to zero causes
+ resume_writing() to be called only once the buffer is empty.
+ Use of zero for either limit is generally sub-optimal as it
+ reduces opportunities for doing I/O and computation
+ concurrently.
+ """
+ raise NotImplementedError
+
+ def get_write_buffer_size(self):
+ """Return the current size of the write buffer."""
+ raise NotImplementedError
+
+ def write(self, data):
+ """Write some data bytes to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+ """
+ raise NotImplementedError
+
+ def writelines(self, list_of_data):
+ """Write a list (or any iterable) of data bytes to the transport.
+
+ The default implementation concatenates the arguments and
+ calls write() on the result.
+ """
+ if not _PY34:
+ # In Python 3.3, bytes.join() doesn't handle memoryview.
+ list_of_data = (
+ bytes(data) if isinstance(data, memoryview) else data
+ for data in list_of_data)
+ self.write(b''.join(list_of_data))
+
+ def write_eof(self):
+ """Close the write end after flushing buffered data.
+
+ (This is like typing ^D into a UNIX program reading from stdin.)
+
+ Data may still be received.
+ """
+ raise NotImplementedError
+
+ def can_write_eof(self):
+ """Return True if this transport supports write_eof(), False if not."""
+ raise NotImplementedError
+
+ def abort(self):
+ """Close the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method will (eventually) be
+ called with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class Transport(ReadTransport, WriteTransport):
+ """Interface representing a bidirectional transport.
+
+ There may be several implementations, but typically, the user does
+ not implement new transports; rather, the platform provides some
+ useful transports that are implemented using the platform's best
+ practices.
+
+ The user never instantiates a transport directly; they call a
+ utility function, passing it a protocol factory and other
+ information necessary to create the transport and protocol. (E.g.
+ EventLoop.create_connection() or EventLoop.create_server().)
+
+ The utility function will asynchronously create a transport and a
+ protocol and hook them up by calling the protocol's
+ connection_made() method, passing it the transport.
+
+    The implementation here raises NotImplementedError for every method
+ except writelines(), which calls write() in a loop.
+ """
+
+
+class DatagramTransport(BaseTransport):
+ """Interface for datagram (UDP) transports."""
+
+ def sendto(self, data, addr=None):
+ """Send data to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+        addr is the target socket address.
+        If addr is None, the target address given at transport creation
+        time is used.
+ """
+ raise NotImplementedError
+
+ def abort(self):
+ """Close the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method will (eventually) be
+ called with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class SubprocessTransport(BaseTransport):
+
+ def get_pid(self):
+ """Get subprocess id."""
+ raise NotImplementedError
+
+ def get_returncode(self):
+ """Get subprocess returncode.
+
+ See also
+ http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode
+ """
+ raise NotImplementedError
+
+ def get_pipe_transport(self, fd):
+ """Get transport for pipe with number fd."""
+ raise NotImplementedError
+
+ def send_signal(self, signal):
+ """Send signal to subprocess.
+
+ See also:
+ docs.python.org/3/library/subprocess#subprocess.Popen.send_signal
+ """
+ raise NotImplementedError
+
+ def terminate(self):
+ """Stop the subprocess.
+
+ Alias for close() method.
+
+ On Posix OSs the method sends SIGTERM to the subprocess.
+ On Windows the Win32 API function TerminateProcess()
+ is called to stop the subprocess.
+
+ See also:
+ http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate
+ """
+ raise NotImplementedError
+
+ def kill(self):
+ """Kill the subprocess.
+
+ On Posix OSs the function sends SIGKILL to the subprocess.
+ On Windows kill() is an alias for terminate().
+
+ See also:
+ http://docs.python.org/3/library/subprocess#subprocess.Popen.kill
+ """
+ raise NotImplementedError
+
+
+class _FlowControlMixin(Transport):
+ """All the logic for (write) flow control in a mix-in base class.
+
+ The subclass must implement get_write_buffer_size(). It must call
+ _maybe_pause_protocol() whenever the write buffer size increases,
+ and _maybe_resume_protocol() whenever it decreases. It may also
+ override set_write_buffer_limits() (e.g. to specify different
+ defaults).
+
+ The subclass constructor must call super().__init__(extra). This
+ will call set_write_buffer_limits().
+
+ The user may call set_write_buffer_limits() and
+ get_write_buffer_size(), and their protocol's pause_writing() and
+ resume_writing() may be called.
+ """
+
+ def __init__(self, extra=None, loop=None):
+ super().__init__(extra)
+ assert loop is not None
+ self._loop = loop
+ self._protocol_paused = False
+ self._set_write_buffer_limits()
+
+ def _maybe_pause_protocol(self):
+ size = self.get_write_buffer_size()
+ if size <= self._high_water:
+ return
+ if not self._protocol_paused:
+ self._protocol_paused = True
+ try:
+ self._protocol.pause_writing()
+ except Exception as exc:
+ self._loop.call_exception_handler({
+ 'message': 'protocol.pause_writing() failed',
+ 'exception': exc,
+ 'transport': self,
+ 'protocol': self._protocol,
+ })
+
+ def _maybe_resume_protocol(self):
+ if (self._protocol_paused and
+ self.get_write_buffer_size() <= self._low_water):
+ self._protocol_paused = False
+ try:
+ self._protocol.resume_writing()
+ except Exception as exc:
+ self._loop.call_exception_handler({
+ 'message': 'protocol.resume_writing() failed',
+ 'exception': exc,
+ 'transport': self,
+ 'protocol': self._protocol,
+ })
+
+ def get_write_buffer_limits(self):
+ return (self._low_water, self._high_water)
+
+ def _set_write_buffer_limits(self, high=None, low=None):
+ if high is None:
+ if low is None:
+ high = 64*1024
+ else:
+ high = 4*low
+ if low is None:
+ low = high // 4
+ if not high >= low >= 0:
+ raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
+ (high, low))
+ self._high_water = high
+ self._low_water = low
+
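+    # Worked example (illustrative, not part of the original module):
+    # with no arguments, high defaults to 64 KiB and low to high // 4
+    # (16 KiB); with only low=8192 given, high becomes 4 * low (32 KiB).
+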
+ def set_write_buffer_limits(self, high=None, low=None):
+ self._set_write_buffer_limits(high=high, low=low)
+ self._maybe_pause_protocol()
+
+ def get_write_buffer_size(self):
+ raise NotImplementedError
diff --git a/trollius/unix_events.py b/trollius/unix_events.py
new file mode 100644
index 0000000..75e7c9c
--- /dev/null
+++ b/trollius/unix_events.py
@@ -0,0 +1,998 @@
+"""Selector event loop for Unix with signal handling."""
+
+import errno
+import os
+import signal
+import socket
+import stat
+import subprocess
+import sys
+import threading
+import warnings
+
+
+from . import base_events
+from . import base_subprocess
+from . import constants
+from . import coroutines
+from . import events
+from . import futures
+from . import selector_events
+from . import selectors
+from . import transports
+from .coroutines import coroutine
+from .log import logger
+
+
+__all__ = ['SelectorEventLoop',
+ 'AbstractChildWatcher', 'SafeChildWatcher',
+ 'FastChildWatcher', 'DefaultEventLoopPolicy',
+ ]
+
+if sys.platform == 'win32': # pragma: no cover
+ raise ImportError('Signals are not really supported on Windows')
+
+
+def _sighandler_noop(signum, frame):
+ """Dummy signal handler."""
+ pass
+
+
+class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
+ """Unix event loop.
+
+ Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
+ """
+
+ def __init__(self, selector=None):
+ super().__init__(selector)
+ self._signal_handlers = {}
+
+ def _socketpair(self):
+ return socket.socketpair()
+
+ def close(self):
+ super().close()
+ for sig in list(self._signal_handlers):
+ self.remove_signal_handler(sig)
+
+ def _process_self_data(self, data):
+ for signum in data:
+ if not signum:
+ # ignore null bytes written by _write_to_self()
+ continue
+ self._handle_signal(signum)
+
+ def add_signal_handler(self, sig, callback, *args):
+ """Add a handler for a signal. UNIX only.
+
+ Raise ValueError if the signal number is invalid or uncatchable.
+ Raise RuntimeError if there is a problem setting up the handler.
+ """
+ if (coroutines.iscoroutine(callback)
+ or coroutines.iscoroutinefunction(callback)):
+ raise TypeError("coroutines cannot be used "
+ "with add_signal_handler()")
+ self._check_signal(sig)
+ self._check_closed()
+ try:
+ # set_wakeup_fd() raises ValueError if this is not the
+ # main thread. By calling it early we ensure that an
+ # event loop running in another thread cannot add a signal
+ # handler.
+ signal.set_wakeup_fd(self._csock.fileno())
+ except (ValueError, OSError) as exc:
+ raise RuntimeError(str(exc))
+
+ handle = events.Handle(callback, args, self)
+ self._signal_handlers[sig] = handle
+
+ try:
+ # Register a dummy signal handler to ask Python to write the signal
+            # number in the wakeup file descriptor. _process_self_data() will
+ # read signal numbers from this file descriptor to handle signals.
+ signal.signal(sig, _sighandler_noop)
+
+ # Set SA_RESTART to limit EINTR occurrences.
+ signal.siginterrupt(sig, False)
+ except OSError as exc:
+ del self._signal_handlers[sig]
+ if not self._signal_handlers:
+ try:
+ signal.set_wakeup_fd(-1)
+ except (ValueError, OSError) as nexc:
+ logger.info('set_wakeup_fd(-1) failed: %s', nexc)
+
+ if exc.errno == errno.EINVAL:
+ raise RuntimeError('sig {} cannot be caught'.format(sig))
+ else:
+ raise
+
+ def _handle_signal(self, sig):
+ """Internal helper that is the actual signal handler."""
+ handle = self._signal_handlers.get(sig)
+ if handle is None:
+ return # Assume it's some race condition.
+ if handle._cancelled:
+ self.remove_signal_handler(sig) # Remove it properly.
+ else:
+ self._add_callback_signalsafe(handle)
+
+ def remove_signal_handler(self, sig):
+ """Remove a handler for a signal. UNIX only.
+
+ Return True if a signal handler was removed, False if not.
+ """
+ self._check_signal(sig)
+ try:
+ del self._signal_handlers[sig]
+ except KeyError:
+ return False
+
+ if sig == signal.SIGINT:
+ handler = signal.default_int_handler
+ else:
+ handler = signal.SIG_DFL
+
+ try:
+ signal.signal(sig, handler)
+ except OSError as exc:
+ if exc.errno == errno.EINVAL:
+ raise RuntimeError('sig {} cannot be caught'.format(sig))
+ else:
+ raise
+
+ if not self._signal_handlers:
+ try:
+ signal.set_wakeup_fd(-1)
+ except (ValueError, OSError) as exc:
+ logger.info('set_wakeup_fd(-1) failed: %s', exc)
+
+ return True
+
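+    # Illustrative usage sketch (not part of the original module):
+    # stopping the loop cleanly on SIGTERM.
+    #
+    #     loop.add_signal_handler(signal.SIGTERM, loop.stop)
+    #     ...
+    #     loop.remove_signal_handler(signal.SIGTERM)
+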
+ def _check_signal(self, sig):
+ """Internal helper to validate a signal.
+
+ Raise ValueError if the signal number is invalid or uncatchable.
+ Raise RuntimeError if there is a problem setting up the handler.
+ """
+ if not isinstance(sig, int):
+ raise TypeError('sig must be an int, not {!r}'.format(sig))
+
+ if not (1 <= sig < signal.NSIG):
+ raise ValueError(
+ 'sig {} out of range(1, {})'.format(sig, signal.NSIG))
+
+ def _make_read_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra)
+
+ def _make_write_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra)
+
+ @coroutine
+ def _make_subprocess_transport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ extra=None, **kwargs):
+ with events.get_child_watcher() as watcher:
+ waiter = futures.Future(loop=self)
+ transp = _UnixSubprocessTransport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ waiter=waiter, extra=extra,
+ **kwargs)
+
+ watcher.add_child_handler(transp.get_pid(),
+ self._child_watcher_callback, transp)
+ try:
+ yield from waiter
+ except Exception as exc:
+ # Workaround CPython bug #23353: using yield/yield-from in an
+            # except block of a generator doesn't properly clear
+ # sys.exc_info()
+ err = exc
+ else:
+ err = None
+
+ if err is not None:
+ transp.close()
+ yield from transp._wait()
+ raise err
+
+ return transp
+
+ def _child_watcher_callback(self, pid, returncode, transp):
+ self.call_soon_threadsafe(transp._process_exited, returncode)
+
+ @coroutine
+ def create_unix_connection(self, protocol_factory, path, *,
+ ssl=None, sock=None,
+ server_hostname=None):
+ assert server_hostname is None or isinstance(server_hostname, str)
+ if ssl:
+ if server_hostname is None:
+ raise ValueError(
+ 'you have to pass server_hostname when using ssl')
+ else:
+ if server_hostname is not None:
+ raise ValueError('server_hostname is only meaningful with ssl')
+
+ if path is not None:
+ if sock is not None:
+ raise ValueError(
+ 'path and sock can not be specified at the same time')
+
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
+ try:
+ sock.setblocking(False)
+ yield from self.sock_connect(sock, path)
+ except:
+ sock.close()
+ raise
+
+ else:
+ if sock is None:
+ raise ValueError('no path and sock were specified')
+ sock.setblocking(False)
+
+ transport, protocol = yield from self._create_connection_transport(
+ sock, protocol_factory, ssl, server_hostname)
+ return transport, protocol
+
+ @coroutine
+ def create_unix_server(self, protocol_factory, path=None, *,
+ sock=None, backlog=100, ssl=None):
+ if isinstance(ssl, bool):
+ raise TypeError('ssl argument must be an SSLContext or None')
+
+ if path is not None:
+ if sock is not None:
+ raise ValueError(
+ 'path and sock can not be specified at the same time')
+
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+ try:
+ sock.bind(path)
+ except OSError as exc:
+ sock.close()
+ if exc.errno == errno.EADDRINUSE:
+ # Let's improve the error message by adding
+                    # the exact address with which it occurs.
+ msg = 'Address {!r} is already in use'.format(path)
+ raise OSError(errno.EADDRINUSE, msg) from None
+ else:
+ raise
+ except:
+ sock.close()
+ raise
+ else:
+ if sock is None:
+ raise ValueError(
+ 'path was not specified, and no sock specified')
+
+ if sock.family != socket.AF_UNIX:
+ raise ValueError(
+ 'A UNIX Domain Socket was expected, got {!r}'.format(sock))
+
+ server = base_events.Server(self, [sock])
+ sock.listen(backlog)
+ sock.setblocking(False)
+ self._start_serving(protocol_factory, sock, ssl, server)
+ return server
+
+
+if hasattr(os, 'set_blocking'):
+ def _set_nonblocking(fd):
+ os.set_blocking(fd, False)
+else:
+ import fcntl
+
+ def _set_nonblocking(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+ flags = flags | os.O_NONBLOCK
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+
+
+class _UnixReadPipeTransport(transports.ReadTransport):
+
+ max_size = 256 * 1024 # max bytes we read in one event loop iteration
+
+ def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
+ super().__init__(extra)
+ self._extra['pipe'] = pipe
+ self._loop = loop
+ self._pipe = pipe
+ self._fileno = pipe.fileno()
+ mode = os.fstat(self._fileno).st_mode
+ if not (stat.S_ISFIFO(mode) or
+ stat.S_ISSOCK(mode) or
+ stat.S_ISCHR(mode)):
+ raise ValueError("Pipe transport is for pipes/sockets only.")
+ _set_nonblocking(self._fileno)
+ self._protocol = protocol
+ self._closing = False
+ self._loop.call_soon(self._protocol.connection_made, self)
+ # only start reading when connection_made() has been called
+ self._loop.call_soon(self._loop.add_reader,
+ self._fileno, self._read_ready)
+ if waiter is not None:
+ # only wake up the waiter when connection_made() has been called
+ self._loop.call_soon(waiter._set_result_unless_cancelled, None)
+
+ def __repr__(self):
+ info = [self.__class__.__name__]
+ if self._pipe is None:
+ info.append('closed')
+ elif self._closing:
+ info.append('closing')
+ info.append('fd=%s' % self._fileno)
+ if self._pipe is not None:
+ polling = selector_events._test_selector_event(
+ self._loop._selector,
+ self._fileno, selectors.EVENT_READ)
+ if polling:
+ info.append('polling')
+ else:
+ info.append('idle')
+ else:
+ info.append('closed')
+ return '<%s>' % ' '.join(info)
+
+ def _read_ready(self):
+ try:
+ data = os.read(self._fileno, self.max_size)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except OSError as exc:
+ self._fatal_error(exc, 'Fatal read error on pipe transport')
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ if self._loop.get_debug():
+ logger.info("%r was closed by peer", self)
+ self._closing = True
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._protocol.eof_received)
+ self._loop.call_soon(self._call_connection_lost, None)
+
+ def pause_reading(self):
+ self._loop.remove_reader(self._fileno)
+
+ def resume_reading(self):
+ self._loop.add_reader(self._fileno, self._read_ready)
+
+ def close(self):
+ if not self._closing:
+ self._close(None)
+
+    # On Python 3.3 and older, objects with a destructor that are part of a
+    # reference cycle are never destroyed. This is no longer the case on
+    # Python 3.4, thanks to PEP 442.
+ if sys.version_info >= (3, 4):
+ def __del__(self):
+ if self._pipe is not None:
+ warnings.warn("unclosed transport %r" % self, ResourceWarning)
+ self._pipe.close()
+
+ def _fatal_error(self, exc, message='Fatal error on pipe transport'):
+ # should be called by exception handler only
+ if (isinstance(exc, OSError) and exc.errno == errno.EIO):
+ if self._loop.get_debug():
+ logger.debug("%r: %s", self, message, exc_info=True)
+ else:
+ self._loop.call_exception_handler({
+ 'message': message,
+ 'exception': exc,
+ 'transport': self,
+ 'protocol': self._protocol,
+ })
+ self._close(exc)
+
+ def _close(self, exc):
+ self._closing = True
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._pipe.close()
+ self._pipe = None
+ self._protocol = None
+ self._loop = None
+
+
+class _UnixWritePipeTransport(transports._FlowControlMixin,
+ transports.WriteTransport):
+
+ def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
+ super().__init__(extra, loop)
+ self._extra['pipe'] = pipe
+ self._pipe = pipe
+ self._fileno = pipe.fileno()
+ mode = os.fstat(self._fileno).st_mode
+ is_socket = stat.S_ISSOCK(mode)
+ if not (is_socket or
+ stat.S_ISFIFO(mode) or
+ stat.S_ISCHR(mode)):
+ raise ValueError("Pipe transport is only for "
+ "pipes, sockets and character devices")
+ _set_nonblocking(self._fileno)
+ self._protocol = protocol
+ self._buffer = []
+ self._conn_lost = 0
+ self._closing = False # Set when close() or write_eof() called.
+
+ self._loop.call_soon(self._protocol.connection_made, self)
+
+ # On AIX, the reader trick (to be notified when the read end of the
+ # socket is closed) only works for sockets. On other platforms it
+ # works for pipes and sockets. (Exception: OS X 10.4? Issue #19294.)
+ if is_socket or not sys.platform.startswith("aix"):
+ # only start reading when connection_made() has been called
+ self._loop.call_soon(self._loop.add_reader,
+ self._fileno, self._read_ready)
+
+ if waiter is not None:
+ # only wake up the waiter when connection_made() has been called
+ self._loop.call_soon(waiter._set_result_unless_cancelled, None)
+
+ def __repr__(self):
+ info = [self.__class__.__name__]
+ if self._pipe is None:
+ info.append('closed')
+ elif self._closing:
+ info.append('closing')
+ info.append('fd=%s' % self._fileno)
+ if self._pipe is not None:
+ polling = selector_events._test_selector_event(
+ self._loop._selector,
+ self._fileno, selectors.EVENT_WRITE)
+ if polling:
+ info.append('polling')
+ else:
+ info.append('idle')
+
+ bufsize = self.get_write_buffer_size()
+ info.append('bufsize=%s' % bufsize)
+ else:
+ info.append('closed')
+ return '<%s>' % ' '.join(info)
+
+ def get_write_buffer_size(self):
+ return sum(len(data) for data in self._buffer)
+
+ def _read_ready(self):
+ # Pipe was closed by peer.
+ if self._loop.get_debug():
+ logger.info("%r was closed by peer", self)
+ if self._buffer:
+ self._close(BrokenPipeError())
+ else:
+ self._close()
+
+ def write(self, data):
+ assert isinstance(data, (bytes, bytearray, memoryview)), repr(data)
+ if isinstance(data, bytearray):
+ data = memoryview(data)
+ if not data:
+ return
+
+ if self._conn_lost or self._closing:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ logger.warning('pipe closed by peer or '
+ 'os.write(pipe, data) raised exception.')
+ self._conn_lost += 1
+ return
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ n = os.write(self._fileno, data)
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except Exception as exc:
+ self._conn_lost += 1
+ self._fatal_error(exc, 'Fatal write error on pipe transport')
+ return
+ if n == len(data):
+ return
+ elif n > 0:
+ data = data[n:]
+ self._loop.add_writer(self._fileno, self._write_ready)
+
+ self._buffer.append(data)
+ self._maybe_pause_protocol()
+
+ def _write_ready(self):
+ data = b''.join(self._buffer)
+ assert data, 'Data should not be empty'
+
+ self._buffer.clear()
+ try:
+ n = os.write(self._fileno, data)
+ except (BlockingIOError, InterruptedError):
+ self._buffer.append(data)
+ except Exception as exc:
+ self._conn_lost += 1
+            # Remove the writer here: _fatal_error() won't do it
+            # because _buffer is empty.
+ self._loop.remove_writer(self._fileno)
+ self._fatal_error(exc, 'Fatal write error on pipe transport')
+ else:
+ if n == len(data):
+ self._loop.remove_writer(self._fileno)
+ self._maybe_resume_protocol() # May append to buffer.
+ if not self._buffer and self._closing:
+ self._loop.remove_reader(self._fileno)
+ self._call_connection_lost(None)
+ return
+ elif n > 0:
+ data = data[n:]
+
+ self._buffer.append(data) # Try again later.
+
+ def can_write_eof(self):
+ return True
+
+ def write_eof(self):
+ if self._closing:
+ return
+ assert self._pipe
+ self._closing = True
+ if not self._buffer:
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._call_connection_lost, None)
+
+ def close(self):
+ if self._pipe is not None and not self._closing:
+            # write_eof() is all we need to close the write pipe
+ self.write_eof()
+
+    # On Python 3.3 and older, objects with a destructor that are part of a
+    # reference cycle are never destroyed. This is no longer the case on
+    # Python 3.4, thanks to PEP 442.
+ if sys.version_info >= (3, 4):
+ def __del__(self):
+ if self._pipe is not None:
+ warnings.warn("unclosed transport %r" % self, ResourceWarning)
+ self._pipe.close()
+
+ def abort(self):
+ self._close(None)
+
+ def _fatal_error(self, exc, message='Fatal error on pipe transport'):
+ # should be called by exception handler only
+ if isinstance(exc, (BrokenPipeError, ConnectionResetError)):
+ if self._loop.get_debug():
+ logger.debug("%r: %s", self, message, exc_info=True)
+ else:
+ self._loop.call_exception_handler({
+ 'message': message,
+ 'exception': exc,
+ 'transport': self,
+ 'protocol': self._protocol,
+ })
+ self._close(exc)
+
+ def _close(self, exc=None):
+ self._closing = True
+ if self._buffer:
+ self._loop.remove_writer(self._fileno)
+ self._buffer.clear()
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._pipe.close()
+ self._pipe = None
+ self._protocol = None
+ self._loop = None
+
+
+if hasattr(os, 'set_inheritable'):
+ # Python 3.4 and newer
+ _set_inheritable = os.set_inheritable
+else:
+ import fcntl
+
+ def _set_inheritable(fd, inheritable):
+ cloexec_flag = getattr(fcntl, 'FD_CLOEXEC', 1)
+
+ old = fcntl.fcntl(fd, fcntl.F_GETFD)
+ if not inheritable:
+ fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag)
+ else:
+ fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag)
+
+
+class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport):
+
+ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs):
+ stdin_w = None
+ if stdin == subprocess.PIPE:
+ # Use a socket pair for stdin, since not all platforms
+ # support selecting read events on the write end of a
+ # socket (which we use in order to detect closing of the
+ # other end). Notably this is needed on AIX, and works
+ # just fine on other platforms.
+ stdin, stdin_w = self._loop._socketpair()
+
+ # Mark the write end of the stdin pipe as non-inheritable,
+ # needed by close_fds=False on Python 3.3 and older
+            # (Python 3.4 implements PEP 446: socketpair() returns
+ # non-inheritable sockets)
+ _set_inheritable(stdin_w.fileno(), False)
+ self._proc = subprocess.Popen(
+ args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr,
+ universal_newlines=False, bufsize=bufsize, **kwargs)
+ if stdin_w is not None:
+ stdin.close()
+ self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize)
+
+
+class AbstractChildWatcher:
+ """Abstract base class for monitoring child processes.
+
+ Objects derived from this class monitor a collection of subprocesses and
+ report their termination or interruption by a signal.
+
+ New callbacks are registered with .add_child_handler(). Starting a new
+ process must be done within a 'with' block to allow the watcher to suspend
+ its activity until the new process if fully registered (this is needed to
+ prevent a race condition in some implementations).
+
+ Example:
+ with watcher:
+ proc = subprocess.Popen("sleep 1")
+ watcher.add_child_handler(proc.pid, callback)
+
+ Notes:
+ Implementations of this class must be thread-safe.
+
+ Since child watcher objects may catch the SIGCHLD signal and call
+ waitpid(-1), there should be only one active object per process.
+ """
+
+ def add_child_handler(self, pid, callback, *args):
+ """Register a new child handler.
+
+ Arrange for callback(pid, returncode, *args) to be called when
+ process 'pid' terminates. Specifying another callback for the same
+ process replaces the previous handler.
+
+ Note: callback() must be thread-safe.
+ """
+ raise NotImplementedError()
+
+ def remove_child_handler(self, pid):
+ """Removes the handler for process 'pid'.
+
+ The function returns True if the handler was successfully removed,
+ False if there was nothing to remove."""
+
+ raise NotImplementedError()
+
+ def attach_loop(self, loop):
+ """Attach the watcher to an event loop.
+
+ If the watcher was previously attached to an event loop, then it is
+ first detached before attaching to the new loop.
+
+ Note: loop may be None.
+ """
+ raise NotImplementedError()
+
+ def close(self):
+ """Close the watcher.
+
+ This must be called to make sure that any underlying resource is freed.
+ """
+ raise NotImplementedError()
+
+ def __enter__(self):
+ """Enter the watcher's context and allow starting new processes
+
+ This function must return self"""
+ raise NotImplementedError()
+
+ def __exit__(self, a, b, c):
+ """Exit the watcher's context"""
+ raise NotImplementedError()
+
+
+class BaseChildWatcher(AbstractChildWatcher):
+
+ def __init__(self):
+ self._loop = None
+
+ def close(self):
+ self.attach_loop(None)
+
+ def _do_waitpid(self, expected_pid):
+ raise NotImplementedError()
+
+ def _do_waitpid_all(self):
+ raise NotImplementedError()
+
+ def attach_loop(self, loop):
+ assert loop is None or isinstance(loop, events.AbstractEventLoop)
+
+ if self._loop is not None:
+ self._loop.remove_signal_handler(signal.SIGCHLD)
+
+ self._loop = loop
+ if loop is not None:
+ loop.add_signal_handler(signal.SIGCHLD, self._sig_chld)
+
+ # Prevent a race condition in case a child terminated
+ # during the switch.
+ self._do_waitpid_all()
+
+ def _sig_chld(self):
+ try:
+ self._do_waitpid_all()
+ except Exception as exc:
+ # self._loop should always be available here
+ # as '_sig_chld' is added as a signal handler
+ # in 'attach_loop'
+ self._loop.call_exception_handler({
+ 'message': 'Unknown exception in SIGCHLD handler',
+ 'exception': exc,
+ })
+
+ def _compute_returncode(self, status):
+ if os.WIFSIGNALED(status):
+ # The child process died because of a signal.
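+ # For example, a child killed by SIGTERM yields -15 here,
+ # matching subprocess.Popen.returncode semantics.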
+ return -os.WTERMSIG(status)
+ elif os.WIFEXITED(status):
+ # The child process exited (e.g. sys.exit()).
+ return os.WEXITSTATUS(status)
+ else:
+ # The child exited, but we don't understand its status.
+ # This shouldn't happen, but if it does, let's just
+ # return that status; perhaps that helps debug it.
+ return status
+
+
+class SafeChildWatcher(BaseChildWatcher):
+ """'Safe' child watcher implementation.
+
+ This implementation avoids disrupting other code spawning processes by
+ explicitly polling each registered process in the SIGCHLD handler,
+ instead of calling os.waitpid(-1).
+
+ This is a safe solution, but it has significant overhead when handling
+ a large number of children (O(n) each time SIGCHLD is raised).
+ """
+
+ def __init__(self):
+ super().__init__()
+ self._callbacks = {}
+
+ def close(self):
+ self._callbacks.clear()
+ super().close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, a, b, c):
+ pass
+
+ def add_child_handler(self, pid, callback, *args):
+ self._callbacks[pid] = (callback, args)
+
+ # Prevent a race condition in case the child is already terminated.
+ self._do_waitpid(pid)
+
+ def remove_child_handler(self, pid):
+ try:
+ del self._callbacks[pid]
+ return True
+ except KeyError:
+ return False
+
+ def _do_waitpid_all(self):
+
+ for pid in list(self._callbacks):
+ self._do_waitpid(pid)
+
+ def _do_waitpid(self, expected_pid):
+ assert expected_pid > 0
+
+ try:
+ pid, status = os.waitpid(expected_pid, os.WNOHANG)
+ except ChildProcessError:
+ # The child process is already reaped
+ # (may happen if waitpid() is called elsewhere).
+ pid = expected_pid
+ returncode = 255
+ logger.warning(
+ "Unknown child process pid %d, will report returncode 255",
+ pid)
+ else:
+ if pid == 0:
+ # The child process is still alive.
+ return
+
+ returncode = self._compute_returncode(status)
+ if self._loop.get_debug():
+ logger.debug('process %s exited with returncode %s',
+ expected_pid, returncode)
+
+ try:
+ callback, args = self._callbacks.pop(pid)
+ except KeyError: # pragma: no cover
+ # May happen if .remove_child_handler() is called
+ # after os.waitpid() returns.
+ if self._loop.get_debug():
+ logger.warning("Child watcher got an unexpected pid: %r",
+ pid, exc_info=True)
+ else:
+ callback(pid, returncode, *args)
+
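+# Usage sketch (illustrative, an assumption about typical wiring; not part
+# of this module): the watcher must be attached, from the main thread, to a
+# loop that is running so that SIGCHLD can be delivered.
+#
+#     watcher = SafeChildWatcher()
+#     watcher.attach_loop(loop)
+#     with watcher:
+#         proc = subprocess.Popen(["sleep", "1"])
+#         watcher.add_child_handler(proc.pid,
+#                                   lambda pid, returncode: print(returncode))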
+
+class FastChildWatcher(BaseChildWatcher):
+ """'Fast' child watcher implementation.
+
+ This implementation reaps every terminated process by calling
+ os.waitpid(-1) directly, possibly breaking other code spawning
+ processes and waiting for their termination.
+
+ There is no noticeable overhead when handling a large number of
+ children (O(1) each time a child terminates).
+ """
+ def __init__(self):
+ super().__init__()
+ self._callbacks = {}
+ self._lock = threading.Lock()
+ self._zombies = {}
+ self._forks = 0
+
+ def close(self):
+ self._callbacks.clear()
+ self._zombies.clear()
+ super().close()
+
+ def __enter__(self):
+ with self._lock:
+ self._forks += 1
+
+ return self
+
+ def __exit__(self, a, b, c):
+ with self._lock:
+ self._forks -= 1
+
+ if self._forks or not self._zombies:
+ return
+
+ collateral_victims = str(self._zombies)
+ self._zombies.clear()
+
+ logger.warning(
+ "Caught subprocesses termination from unknown pids: %s",
+ collateral_victims)
+
+ def add_child_handler(self, pid, callback, *args):
+ assert self._forks, "Must use the context manager"
+ with self._lock:
+ try:
+ returncode = self._zombies.pop(pid)
+ except KeyError:
+ # The child is running.
+ self._callbacks[pid] = callback, args
+ return
+
+ # The child is dead already. We can fire the callback.
+ callback(pid, returncode, *args)
+
+ def remove_child_handler(self, pid):
+ try:
+ del self._callbacks[pid]
+ return True
+ except KeyError:
+ return False
+
+ def _do_waitpid_all(self):
+ # Because of signal coalescing, we must keep calling waitpid() as
+ # long as we're able to reap a child.
+ while True:
+ try:
+ pid, status = os.waitpid(-1, os.WNOHANG)
+ except ChildProcessError:
+ # No more child processes exist.
+ return
+ else:
+ if pid == 0:
+ # A child process is still alive.
+ return
+
+ returncode = self._compute_returncode(status)
+
+ with self._lock:
+ try:
+ callback, args = self._callbacks.pop(pid)
+ except KeyError:
+ # unknown child
+ if self._forks:
+ # It may not be registered yet.
+ self._zombies[pid] = returncode
+ if self._loop.get_debug():
+ logger.debug('unknown process %s exited '
+ 'with returncode %s',
+ pid, returncode)
+ continue
+ callback = None
+ else:
+ if self._loop.get_debug():
+ logger.debug('process %s exited with returncode %s',
+ pid, returncode)
+
+ if callback is None:
+ logger.warning(
+ "Caught subprocess termination from unknown pid: "
+ "%d -> %d", pid, returncode)
+ else:
+ callback(pid, returncode, *args)
+
+
+class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
+ """UNIX event loop policy with a watcher for child processes."""
+ _loop_factory = _UnixSelectorEventLoop
+
+ def __init__(self):
+ super().__init__()
+ self._watcher = None
+
+ def _init_watcher(self):
+ with events._lock:
+ if self._watcher is None: # pragma: no branch
+ self._watcher = SafeChildWatcher()
+ if isinstance(threading.current_thread(),
+ threading._MainThread):
+ self._watcher.attach_loop(self._local._loop)
+
+ def set_event_loop(self, loop):
+ """Set the event loop.
+
+ As a side effect, if a child watcher was set before, then calling
+ .set_event_loop() from the main thread will call .attach_loop(loop) on
+ the child watcher.
+ """
+
+ super().set_event_loop(loop)
+
+ if self._watcher is not None and \
+ isinstance(threading.current_thread(), threading._MainThread):
+ self._watcher.attach_loop(loop)
+
+ def get_child_watcher(self):
+ """Get the watcher for child processes.
+
+ If not yet set, a SafeChildWatcher object is automatically created.
+ """
+ if self._watcher is None:
+ self._init_watcher()
+
+ return self._watcher
+
+ def set_child_watcher(self, watcher):
+ """Set the watcher for child processes."""
+
+ assert watcher is None or isinstance(watcher, AbstractChildWatcher)
+
+ if self._watcher is not None:
+ self._watcher.close()
+
+ self._watcher = watcher
+
+SelectorEventLoop = _UnixSelectorEventLoop
+DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy
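+
+# Illustrative policy round-trip (an editor's sketch, not part of this
+# module); get_child_watcher() lazily creates a SafeChildWatcher:
+#
+#     policy = DefaultEventLoopPolicy()
+#     loop = policy.new_event_loop()
+#     policy.set_event_loop(loop)
+#     watcher = policy.get_child_watcher()  # SafeChildWatcher instance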
diff --git a/trollius/windows_events.py b/trollius/windows_events.py
new file mode 100644
index 0000000..922594f
--- /dev/null
+++ b/trollius/windows_events.py
@@ -0,0 +1,774 @@
+"""Selector and proactor event loops for Windows."""
+
+import _winapi
+import errno
+import math
+import socket
+import struct
+import weakref
+
+from . import events
+from . import base_subprocess
+from . import futures
+from . import proactor_events
+from . import selector_events
+from . import tasks
+from . import windows_utils
+from . import _overlapped
+from .coroutines import coroutine
+from .log import logger
+
+
+__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor',
+ 'DefaultEventLoopPolicy',
+ ]
+
+
+NULL = 0
+INFINITE = 0xffffffff
+ERROR_CONNECTION_REFUSED = 1225
+ERROR_CONNECTION_ABORTED = 1236
+
+# Initial delay in seconds for connect_pipe() before retrying to connect
+CONNECT_PIPE_INIT_DELAY = 0.001
+
+# Maximum delay in seconds for connect_pipe() before retrying to connect
+CONNECT_PIPE_MAX_DELAY = 0.100
+
+
+class _OverlappedFuture(futures.Future):
+ """Subclass of Future which represents an overlapped operation.
+
+ Cancelling it will immediately cancel the overlapped operation.
+ """
+
+ def __init__(self, ov, *, loop=None):
+ super().__init__(loop=loop)
+ if self._source_traceback:
+ del self._source_traceback[-1]
+ self._ov = ov
+
+ def _repr_info(self):
+ info = super()._repr_info()
+ if self._ov is not None:
+ state = 'pending' if self._ov.pending else 'completed'
+ info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address))
+ return info
+
+ def _cancel_overlapped(self):
+ if self._ov is None:
+ return
+ try:
+ self._ov.cancel()
+ except OSError as exc:
+ context = {
+ 'message': 'Cancelling an overlapped future failed',
+ 'exception': exc,
+ 'future': self,
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ self._ov = None
+
+ def cancel(self):
+ self._cancel_overlapped()
+ return super().cancel()
+
+ def set_exception(self, exception):
+ super().set_exception(exception)
+ self._cancel_overlapped()
+
+ def set_result(self, result):
+ super().set_result(result)
+ self._ov = None
+
+
+class _BaseWaitHandleFuture(futures.Future):
+ """Subclass of Future which represents a wait handle."""
+
+ def __init__(self, ov, handle, wait_handle, *, loop=None):
+ super().__init__(loop=loop)
+ if self._source_traceback:
+ del self._source_traceback[-1]
+ # Keep a reference to the Overlapped object to keep it alive until the
+ # wait is unregistered
+ self._ov = ov
+ self._handle = handle
+ self._wait_handle = wait_handle
+
+ # Should we call UnregisterWaitEx() if the wait completes
+ # or is cancelled?
+ self._registered = True
+
+ def _poll(self):
+ # non-blocking wait: use a timeout of 0 millisecond
+ return (_winapi.WaitForSingleObject(self._handle, 0) ==
+ _winapi.WAIT_OBJECT_0)
+
+ def _repr_info(self):
+ info = super()._repr_info()
+ info.append('handle=%#x' % self._handle)
+ if self._handle is not None:
+ state = 'signaled' if self._poll() else 'waiting'
+ info.append(state)
+ if self._wait_handle is not None:
+ info.append('wait_handle=%#x' % self._wait_handle)
+ return info
+
+ def _unregister_wait_cb(self, fut):
+ # The wait was unregistered: it's not safe to destroy the Overlapped
+ # object
+ self._ov = None
+
+ def _unregister_wait(self):
+ if not self._registered:
+ return
+ self._registered = False
+
+ wait_handle = self._wait_handle
+ self._wait_handle = None
+ try:
+ _overlapped.UnregisterWait(wait_handle)
+ except OSError as exc:
+ if exc.winerror != _overlapped.ERROR_IO_PENDING:
+ context = {
+ 'message': 'Failed to unregister the wait handle',
+ 'exception': exc,
+ 'future': self,
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ return
+ # ERROR_IO_PENDING means that the unregister is pending
+
+ self._unregister_wait_cb(None)
+
+ def cancel(self):
+ self._unregister_wait()
+ return super().cancel()
+
+ def set_exception(self, exception):
+ self._unregister_wait()
+ super().set_exception(exception)
+
+ def set_result(self, result):
+ self._unregister_wait()
+ super().set_result(result)
+
+
+class _WaitCancelFuture(_BaseWaitHandleFuture):
+ """Subclass of Future which represents a wait for the cancellation of a
+ _WaitHandleFuture using an event.
+ """
+
+ def __init__(self, ov, event, wait_handle, *, loop=None):
+ super().__init__(ov, event, wait_handle, loop=loop)
+
+ self._done_callback = None
+
+ def cancel(self):
+ raise RuntimeError("_WaitCancelFuture must not be cancelled")
+
+ def _schedule_callbacks(self):
+ super(_WaitCancelFuture, self)._schedule_callbacks()
+ if self._done_callback is not None:
+ self._done_callback(self)
+
+
+class _WaitHandleFuture(_BaseWaitHandleFuture):
+ def __init__(self, ov, handle, wait_handle, proactor, *, loop=None):
+ super().__init__(ov, handle, wait_handle, loop=loop)
+ self._proactor = proactor
+ self._unregister_proactor = True
+ self._event = _overlapped.CreateEvent(None, True, False, None)
+ self._event_fut = None
+
+ def _unregister_wait_cb(self, fut):
+ if self._event is not None:
+ _winapi.CloseHandle(self._event)
+ self._event = None
+ self._event_fut = None
+
+ # If the wait was cancelled, the wait may never be signalled, so
+ # it's required to unregister it. Otherwise, IocpProactor.close() will
+ # wait forever for an event which will never come.
+ #
+ # If the IocpProactor already received the event, it's safe to call
+ # _unregister() because we kept a reference to the Overlapped object
+ # which is used as an unique key.
+ self._proactor._unregister(self._ov)
+ self._proactor = None
+
+ super()._unregister_wait_cb(fut)
+
+ def _unregister_wait(self):
+ if not self._registered:
+ return
+ self._registered = False
+
+ wait_handle = self._wait_handle
+ self._wait_handle = None
+ try:
+ _overlapped.UnregisterWaitEx(wait_handle, self._event)
+ except OSError as exc:
+ if exc.winerror != _overlapped.ERROR_IO_PENDING:
+ context = {
+ 'message': 'Failed to unregister the wait handle',
+ 'exception': exc,
+ 'future': self,
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ return
+ # ERROR_IO_PENDING is not an error, the wait was unregistered
+
+ self._event_fut = self._proactor._wait_cancel(self._event,
+ self._unregister_wait_cb)
+
+
+class PipeServer(object):
+ """Class representing a pipe server.
+
+ This is much like a bound, listening socket.
+ """
+ def __init__(self, address):
+ self._address = address
+ self._free_instances = weakref.WeakSet()
+ # initialize the pipe attribute before calling _server_pipe_handle()
+ # because this function can raise an exception and the destructor calls
+ # the close() method
+ self._pipe = None
+ self._accept_pipe_future = None
+ self._pipe = self._server_pipe_handle(True)
+
+ def _get_unconnected_pipe(self):
+ # Create a new instance and return the previous one. This ensures
+ # that (until the server is closed) there is always at least one
+ # pipe handle for the address, so a client attempting to connect
+ # will not fail with FileNotFoundError.
+ tmp, self._pipe = self._pipe, self._server_pipe_handle(False)
+ return tmp
+
+ def _server_pipe_handle(self, first):
+ # Return a wrapper for a new pipe handle.
+ if self.closed():
+ return None
+ flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED
+ if first:
+ flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
+ h = _winapi.CreateNamedPipe(
+ self._address, flags,
+ _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
+ _winapi.PIPE_WAIT,
+ _winapi.PIPE_UNLIMITED_INSTANCES,
+ windows_utils.BUFSIZE, windows_utils.BUFSIZE,
+ _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL)
+ pipe = windows_utils.PipeHandle(h)
+ self._free_instances.add(pipe)
+ return pipe
+
+ def closed(self):
+ return (self._address is None)
+
+ def close(self):
+ if self._accept_pipe_future is not None:
+ self._accept_pipe_future.cancel()
+ self._accept_pipe_future = None
+ # Close all instances which have not been connected to by a client.
+ if self._address is not None:
+ for pipe in self._free_instances:
+ pipe.close()
+ self._pipe = None
+ self._address = None
+ self._free_instances.clear()
+
+ __del__ = close
+
+
+class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop):
+ """Windows version of selector event loop."""
+
+ def _socketpair(self):
+ return windows_utils.socketpair()
+
+
+class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
+ """Windows version of proactor event loop using IOCP."""
+
+ def __init__(self, proactor=None):
+ if proactor is None:
+ proactor = IocpProactor()
+ super().__init__(proactor)
+
+ def _socketpair(self):
+ return windows_utils.socketpair()
+
+ @coroutine
+ def create_pipe_connection(self, protocol_factory, address):
+ f = self._proactor.connect_pipe(address)
+ pipe = yield from f
+ protocol = protocol_factory()
+ trans = self._make_duplex_pipe_transport(pipe, protocol,
+ extra={'addr': address})
+ return trans, protocol
+
+ @coroutine
+ def start_serving_pipe(self, protocol_factory, address):
+ server = PipeServer(address)
+
+ def loop_accept_pipe(f=None):
+ pipe = None
+ try:
+ if f:
+ pipe = f.result()
+ server._free_instances.discard(pipe)
+
+ if server.closed():
+ # A client connected before the server was closed:
+ # drop the client (close the pipe) and exit
+ pipe.close()
+ return
+
+ protocol = protocol_factory()
+ self._make_duplex_pipe_transport(
+ pipe, protocol, extra={'addr': address})
+
+ pipe = server._get_unconnected_pipe()
+ if pipe is None:
+ return
+
+ f = self._proactor.accept_pipe(pipe)
+ except OSError as exc:
+ if pipe and pipe.fileno() != -1:
+ self.call_exception_handler({
+ 'message': 'Pipe accept failed',
+ 'exception': exc,
+ 'pipe': pipe,
+ })
+ pipe.close()
+ elif self._debug:
+ logger.warning("Accept pipe failed on pipe %r",
+ pipe, exc_info=True)
+ except futures.CancelledError:
+ if pipe:
+ pipe.close()
+ else:
+ server._accept_pipe_future = f
+ f.add_done_callback(loop_accept_pipe)
+
+ self.call_soon(loop_accept_pipe)
+ return [server]
+
+ @coroutine
+ def _make_subprocess_transport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ extra=None, **kwargs):
+ waiter = futures.Future(loop=self)
+ transp = _WindowsSubprocessTransport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ waiter=waiter, extra=extra,
+ **kwargs)
+ try:
+ yield from waiter
+ except Exception as exc:
+ # Workaround CPython bug #23353: using yield/yield-from in an
+ # except block of a generator doesn't clear properly sys.exc_info()
+ err = exc
+ else:
+ err = None
+
+ if err is not None:
+ transp.close()
+ yield from transp._wait()
+ raise err
+
+ return transp
+
+
+class IocpProactor:
+ """Proactor implementation using IOCP."""
+
+ def __init__(self, concurrency=0xffffffff):
+ self._loop = None
+ self._results = []
+ self._iocp = _overlapped.CreateIoCompletionPort(
+ _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency)
+ self._cache = {}
+ self._registered = weakref.WeakSet()
+ self._unregistered = []
+ self._stopped_serving = weakref.WeakSet()
+
+ def __repr__(self):
+ return ('<%s overlapped#=%s result#=%s>'
+ % (self.__class__.__name__, len(self._cache),
+ len(self._results)))
+
+ def set_loop(self, loop):
+ self._loop = loop
+
+ def select(self, timeout=None):
+ if not self._results:
+ self._poll(timeout)
+ tmp = self._results
+ self._results = []
+ return tmp
+
+ def _result(self, value):
+ fut = futures.Future(loop=self._loop)
+ fut.set_result(value)
+ return fut
+
+ def recv(self, conn, nbytes, flags=0):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+ try:
+ if isinstance(conn, socket.socket):
+ ov.WSARecv(conn.fileno(), nbytes, flags)
+ else:
+ ov.ReadFile(conn.fileno(), nbytes)
+ except BrokenPipeError:
+ return self._result(b'')
+
+ def finish_recv(trans, key, ov):
+ try:
+ return ov.getresult()
+ except OSError as exc:
+ if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+ raise ConnectionResetError(*exc.args)
+ else:
+ raise
+
+ return self._register(ov, conn, finish_recv)
+
+ def send(self, conn, buf, flags=0):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+ if isinstance(conn, socket.socket):
+ ov.WSASend(conn.fileno(), buf, flags)
+ else:
+ ov.WriteFile(conn.fileno(), buf)
+
+ def finish_send(trans, key, ov):
+ try:
+ return ov.getresult()
+ except OSError as exc:
+ if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+ raise ConnectionResetError(*exc.args)
+ else:
+ raise
+
+ return self._register(ov, conn, finish_send)
+
+ def accept(self, listener):
+ self._register_with_iocp(listener)
+ conn = self._get_accept_socket(listener.family)
+ ov = _overlapped.Overlapped(NULL)
+ ov.AcceptEx(listener.fileno(), conn.fileno())
+
+ def finish_accept(trans, key, ov):
+ ov.getresult()
+ # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work.
+ buf = struct.pack('@P', listener.fileno())
+ conn.setsockopt(socket.SOL_SOCKET,
+ _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf)
+ conn.settimeout(listener.gettimeout())
+ return conn, conn.getpeername()
+
+ @coroutine
+ def accept_coro(future, conn):
+ # Coroutine closing the accept socket if the future is cancelled
+ try:
+ yield from future
+ except futures.CancelledError:
+ conn.close()
+ raise
+
+ future = self._register(ov, listener, finish_accept)
+ coro = accept_coro(future, conn)
+ tasks.ensure_future(coro, loop=self._loop)
+ return future
+
+ def connect(self, conn, address):
+ self._register_with_iocp(conn)
+ # The socket needs to be locally bound before we call ConnectEx().
+ try:
+ _overlapped.BindLocal(conn.fileno(), conn.family)
+ except OSError as e:
+ if e.winerror != errno.WSAEINVAL:
+ raise
+ # Probably already locally bound; check using getsockname().
+ if conn.getsockname()[1] == 0:
+ raise
+ ov = _overlapped.Overlapped(NULL)
+ ov.ConnectEx(conn.fileno(), address)
+
+ def finish_connect(trans, key, ov):
+ ov.getresult()
+ # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work.
+ conn.setsockopt(socket.SOL_SOCKET,
+ _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0)
+ return conn
+
+ return self._register(ov, conn, finish_connect)
+
+ def accept_pipe(self, pipe):
+ self._register_with_iocp(pipe)
+ ov = _overlapped.Overlapped(NULL)
+ connected = ov.ConnectNamedPipe(pipe.fileno())
+
+ if connected:
+ # ConnectNamedPipe() failed with ERROR_PIPE_CONNECTED, which means
+ # that the pipe is connected. There is no need to wait for the
+ # completion of the connection.
+ return self._result(pipe)
+
+ def finish_accept_pipe(trans, key, ov):
+ ov.getresult()
+ return pipe
+
+ return self._register(ov, pipe, finish_accept_pipe)
+
+ @coroutine
+ def connect_pipe(self, address):
+ delay = CONNECT_PIPE_INIT_DELAY
+ while True:
+ # Unfortunately there is no way to do an overlapped connect to a
+ # pipe. Call CreateFile() in a loop until it no longer fails with
+ # ERROR_PIPE_BUSY.
+ try:
+ handle = _overlapped.ConnectPipe(address)
+ break
+ except OSError as exc:
+ if exc.winerror != _overlapped.ERROR_PIPE_BUSY:
+ raise
+
+ # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later
+ delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY)
+ yield from tasks.sleep(delay, loop=self._loop)
+
+ return windows_utils.PipeHandle(handle)
+
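+ # With CONNECT_PIPE_INIT_DELAY=0.001 and CONNECT_PIPE_MAX_DELAY=0.100,
+ # the retry sleeps in connect_pipe() above double from 0.002 s up to
+ # the 0.100 s cap.
+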
+ def wait_for_handle(self, handle, timeout=None):
+ """Wait for a handle.
+
+ Return a Future object. The result of the future is True if the wait
+ completed, or False if the wait did not complete (on timeout).
+ """
+ return self._wait_for_handle(handle, timeout, False)
+
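+ # Illustrative use (an assumption, not part of this module): from a
+ # coroutine running on a ProactorEventLoop, wait up to 5 seconds for a
+ # process handle to be signalled.
+ #
+ #     fut = loop._proactor.wait_for_handle(int(proc._handle), timeout=5.0)
+ #     finished = yield from fut  # True if signalled, False on timeout
+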
+ def _wait_cancel(self, event, done_callback):
+ fut = self._wait_for_handle(event, None, True)
+ # add_done_callback() cannot be used because the wait may only complete
+ # in IocpProactor.close(), while the event loop is not running.
+ fut._done_callback = done_callback
+ return fut
+
+ def _wait_for_handle(self, handle, timeout, _is_cancel):
+ if timeout is None:
+ ms = _winapi.INFINITE
+ else:
+ # RegisterWaitForSingleObject() has a resolution of 1 millisecond,
+ # round away from zero to wait *at least* timeout seconds.
+ ms = math.ceil(timeout * 1e3)
+
+ # We only create ov so we can use ov.address as a key for the cache.
+ ov = _overlapped.Overlapped(NULL)
+ wait_handle = _overlapped.RegisterWaitWithQueue(
+ handle, self._iocp, ov.address, ms)
+ if _is_cancel:
+ f = _WaitCancelFuture(ov, handle, wait_handle, loop=self._loop)
+ else:
+ f = _WaitHandleFuture(ov, handle, wait_handle, self,
+ loop=self._loop)
+ if f._source_traceback:
+ del f._source_traceback[-1]
+
+ def finish_wait_for_handle(trans, key, ov):
+ # Note that this second wait means that we should only use
+ # this with handle types where a successful wait has no
+ # effect. So events or processes are all right, but locks
+ # or semaphores are not. Also note that if the handle is
+ # signalled and then quickly reset, we may return
+ # False even though we have not timed out.
+ return f._poll()
+
+ self._cache[ov.address] = (f, ov, 0, finish_wait_for_handle)
+ return f
+
+ def _register_with_iocp(self, obj):
+ # To get notifications of finished ops on this object sent to the
+ # completion port, we must register the handle.
+ if obj not in self._registered:
+ self._registered.add(obj)
+ _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0)
+ # XXX We could also use SetFileCompletionNotificationModes()
+ # to avoid sending notifications to completion port of ops
+ # that succeed immediately.
+
+ def _register(self, ov, obj, callback):
+ # Return a future which will be set with the result of the
+ # operation when it completes. The future's value is actually
+ # the value returned by callback().
+ f = _OverlappedFuture(ov, loop=self._loop)
+ if f._source_traceback:
+ del f._source_traceback[-1]
+ if not ov.pending:
+ # The operation has completed, so no need to postpone the
+ # work. We cannot take this short cut if we need the
+ # NumberOfBytes, CompletionKey values returned by
+ # PostQueuedCompletionStatus().
+ try:
+ value = callback(None, None, ov)
+ except OSError as e:
+ f.set_exception(e)
+ else:
+ f.set_result(value)
+ # Even if GetOverlappedResult() was called, we have to wait for the
+ # notification of the completion in GetQueuedCompletionStatus().
+ # Register the overlapped operation to keep a reference to the
+ # OVERLAPPED object, otherwise the memory is freed and Windows may
+ # read uninitialized memory.
+
+ # Register the overlapped operation for later. Note that
+ # we only store obj to prevent it from being garbage
+ # collected too early.
+ self._cache[ov.address] = (f, ov, obj, callback)
+ return f
+
+ def _unregister(self, ov):
+ """Unregister an overlapped object.
+
+ Call this method when its future has been cancelled. The event can
+ already be signalled (pending in the proactor event queue). It is also
+ safe if the event is never signalled (because it was cancelled).
+ """
+ self._unregistered.append(ov)
+
+ def _get_accept_socket(self, family):
+ s = socket.socket(family)
+ s.settimeout(0)
+ return s
+
+ def _poll(self, timeout=None):
+ if timeout is None:
+ ms = INFINITE
+ elif timeout < 0:
+ raise ValueError("negative timeout")
+ else:
+ # GetQueuedCompletionStatus() has a resolution of 1 millisecond,
+ # round away from zero to wait *at least* timeout seconds.
+ ms = math.ceil(timeout * 1e3)
+ if ms >= INFINITE:
+ raise ValueError("timeout too big")
+
+ while True:
+ status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms)
+ if status is None:
+ break
+ ms = 0
+
+ err, transferred, key, address = status
+ try:
+ f, ov, obj, callback = self._cache.pop(address)
+ except KeyError:
+ if self._loop.get_debug():
+ self._loop.call_exception_handler({
+ 'message': ('GetQueuedCompletionStatus() returned an '
+ 'unexpected event'),
+ 'status': ('err=%s transferred=%s key=%#x address=%#x'
+ % (err, transferred, key, address)),
+ })
+
+ # key is either zero, or it is used to return a pipe
+ # handle which should be closed to avoid a leak.
+ if key not in (0, _overlapped.INVALID_HANDLE_VALUE):
+ _winapi.CloseHandle(key)
+ continue
+
+ if obj in self._stopped_serving:
+ f.cancel()
+ # Don't call the callback if _register() already read the result or
+ # if the overlapped has been cancelled
+ elif not f.done():
+ try:
+ value = callback(transferred, key, ov)
+ except OSError as e:
+ f.set_exception(e)
+ self._results.append(f)
+ else:
+ f.set_result(value)
+ self._results.append(f)
+
+ # Remove unregistered futures.
+ for ov in self._unregistered:
+ self._cache.pop(ov.address, None)
+ self._unregistered.clear()
+
+ def _stop_serving(self, obj):
+ # obj is a socket or pipe handle. It will be closed in
+ # BaseProactorEventLoop._stop_serving() which will make any
+ # pending operations fail quickly.
+ self._stopped_serving.add(obj)
+
+ def close(self):
+ # Cancel remaining registered operations.
+ for address, (fut, ov, obj, callback) in list(self._cache.items()):
+ if fut.cancelled():
+ # Nothing to do with cancelled futures
+ pass
+ elif isinstance(fut, _WaitCancelFuture):
+ # _WaitCancelFuture must not be cancelled
+ pass
+ else:
+ try:
+ fut.cancel()
+ except OSError as exc:
+ if self._loop is not None:
+ context = {
+ 'message': 'Cancelling a future failed',
+ 'exception': exc,
+ 'future': fut,
+ }
+ if fut._source_traceback:
+ context['source_traceback'] = fut._source_traceback
+ self._loop.call_exception_handler(context)
+
+ while self._cache:
+ if not self._poll(1):
+ logger.debug('taking a long time to close proactor')
+
+ self._results = []
+ if self._iocp is not None:
+ _winapi.CloseHandle(self._iocp)
+ self._iocp = None
+
+ def __del__(self):
+ self.close()
+
+
+class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport):
+
+ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs):
+ self._proc = windows_utils.Popen(
+ args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr,
+ bufsize=bufsize, **kwargs)
+
+ def callback(f):
+ returncode = self._proc.poll()
+ self._process_exited(returncode)
+
+ f = self._loop._proactor.wait_for_handle(int(self._proc._handle))
+ f.add_done_callback(callback)
+
+
+SelectorEventLoop = _WindowsSelectorEventLoop
+
+
+class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
+ _loop_factory = SelectorEventLoop
+
+
+DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy
diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py
new file mode 100644
index 0000000..870cd13
--- /dev/null
+++ b/trollius/windows_utils.py
@@ -0,0 +1,223 @@
+"""
+Various Windows specific bits and pieces
+"""
+
+import sys
+
+if sys.platform != 'win32': # pragma: no cover
+ raise ImportError('win32 only')
+
+import _winapi
+import itertools
+import msvcrt
+import os
+import socket
+import subprocess
+import tempfile
+import warnings
+
+
+__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle']
+
+
+# Constants/globals
+
+
+BUFSIZE = 8192
+PIPE = subprocess.PIPE
+STDOUT = subprocess.STDOUT
+_mmap_counter = itertools.count()
+
+
+if hasattr(socket, 'socketpair'):
+ # Since Python 3.5, socket.socketpair() is also available on Windows.
+ socketpair = socket.socketpair
+else:
+ # Replacement for socket.socketpair()
+ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
+ """A socket pair usable as a self-pipe, for Windows.
+
+ Origin: https://gist.github.com/4325783, by Geert Jansen.
+ Public domain.
+ """
+ if family == socket.AF_INET:
+ host = '127.0.0.1'
+ elif family == socket.AF_INET6:
+ host = '::1'
+ else:
+ raise ValueError("Only AF_INET and AF_INET6 socket address "
+ "families are supported")
+ if type != socket.SOCK_STREAM:
+ raise ValueError("Only SOCK_STREAM socket type is supported")
+ if proto != 0:
+ raise ValueError("Only protocol zero is supported")
+
+ # We create a connected TCP socket. Note the trick with setblocking(False)
+ # that prevents us from having to create a thread.
+ lsock = socket.socket(family, type, proto)
+ try:
+ lsock.bind((host, 0))
+ lsock.listen(1)
+ # On IPv6, ignore flow_info and scope_id
+ addr, port = lsock.getsockname()[:2]
+ csock = socket.socket(family, type, proto)
+ try:
+ csock.setblocking(False)
+ try:
+ csock.connect((addr, port))
+ except (BlockingIOError, InterruptedError):
+ pass
+ csock.setblocking(True)
+ ssock, _ = lsock.accept()
+ except:
+ csock.close()
+ raise
+ finally:
+ lsock.close()
+ return (ssock, csock)
+
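+# Quick self-check (illustrative sketch, not part of this module): the pair
+# behaves like a connected duplex stream.
+#
+#     a, b = socketpair()
+#     a.sendall(b'ping')
+#     assert b.recv(4) == b'ping'
+#     a.close()
+#     b.close()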
+
+# Replacement for os.pipe() using handles instead of fds
+
+
+def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
+ """Like os.pipe() but with overlapped support and using handles not fds."""
+ address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' %
+ (os.getpid(), next(_mmap_counter)))
+
+ if duplex:
+ openmode = _winapi.PIPE_ACCESS_DUPLEX
+ access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
+ obsize, ibsize = bufsize, bufsize
+ else:
+ openmode = _winapi.PIPE_ACCESS_INBOUND
+ access = _winapi.GENERIC_WRITE
+ obsize, ibsize = 0, bufsize
+
+ openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
+
+ if overlapped[0]:
+ openmode |= _winapi.FILE_FLAG_OVERLAPPED
+
+ if overlapped[1]:
+ flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED
+ else:
+ flags_and_attribs = 0
+
+ h1 = h2 = None
+ try:
+ h1 = _winapi.CreateNamedPipe(
+ address, openmode, _winapi.PIPE_WAIT,
+ 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL)
+
+ h2 = _winapi.CreateFile(
+ address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
+ flags_and_attribs, _winapi.NULL)
+
+ ov = _winapi.ConnectNamedPipe(h1, overlapped=True)
+ ov.GetOverlappedResult(True)
+ return h1, h2
+ except:
+ if h1 is not None:
+ _winapi.CloseHandle(h1)
+ if h2 is not None:
+ _winapi.CloseHandle(h2)
+ raise
+
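+# Hedged example (an editor's sketch, not part of this module): with
+# duplex=False, h1 is the readable server end and h2 the writable client
+# end; both can be wrapped in PipeHandle (defined below).
+#
+#     h1, h2 = pipe(overlapped=(True, False))
+#     reader, writer = PipeHandle(h1), PipeHandle(h2)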
+
+# Wrapper for a pipe handle
+
+
+class PipeHandle:
+ """Wrapper for an overlapped pipe handle which is vaguely file-object like.
+
+ The IOCP event loop can use these instead of socket objects.
+ """
+ def __init__(self, handle):
+ self._handle = handle
+
+ def __repr__(self):
+ if self._handle is not None:
+ handle = 'handle=%r' % self._handle
+ else:
+ handle = 'closed'
+ return '<%s %s>' % (self.__class__.__name__, handle)
+
+ @property
+ def handle(self):
+ return self._handle
+
+ def fileno(self):
+ if self._handle is None:
+ raise ValueError("I/O operatioon on closed pipe")
+ return self._handle
+
+ def close(self, *, CloseHandle=_winapi.CloseHandle):
+ if self._handle is not None:
+ CloseHandle(self._handle)
+ self._handle = None
+
+ def __del__(self):
+ if self._handle is not None:
+ warnings.warn("unclosed %r" % self, ResourceWarning)
+ self.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, t, v, tb):
+ self.close()
+
+
+# Replacement for subprocess.Popen using overlapped pipe handles
+
+
+class Popen(subprocess.Popen):
+ """Replacement for subprocess.Popen using overlapped pipe handles.
+
+ The stdin, stdout, stderr are None or instances of PipeHandle.
+ """
+ def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds):
+ assert not kwds.get('universal_newlines')
+ assert kwds.get('bufsize', 0) == 0
+ stdin_rfd = stdout_wfd = stderr_wfd = None
+ stdin_wh = stdout_rh = stderr_rh = None
+ if stdin == PIPE:
+ stdin_rh, stdin_wh = pipe(overlapped=(False, True), duplex=True)
+ stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY)
+ else:
+ stdin_rfd = stdin
+ if stdout == PIPE:
+ stdout_rh, stdout_wh = pipe(overlapped=(True, False))
+ stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0)
+ else:
+ stdout_wfd = stdout
+ if stderr == PIPE:
+ stderr_rh, stderr_wh = pipe(overlapped=(True, False))
+ stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0)
+ elif stderr == STDOUT:
+ stderr_wfd = stdout_wfd
+ else:
+ stderr_wfd = stderr
+ try:
+ super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd,
+ stderr=stderr_wfd, **kwds)
+ except:
+ for h in (stdin_wh, stdout_rh, stderr_rh):
+ if h is not None:
+ _winapi.CloseHandle(h)
+ raise
+ else:
+ if stdin_wh is not None:
+ self.stdin = PipeHandle(stdin_wh)
+ if stdout_rh is not None:
+ self.stdout = PipeHandle(stdout_rh)
+ if stderr_rh is not None:
+ self.stderr = PipeHandle(stderr_rh)
+ finally:
+ if stdin == PIPE:
+ os.close(stdin_rfd)
+ if stdout == PIPE:
+ os.close(stdout_wfd)
+ if stderr == PIPE:
+ os.close(stderr_wfd)
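+
+
+# Usage sketch (illustrative, not part of this module): each redirected
+# stream is a PipeHandle rather than a regular file object, suitable for
+# IocpProactor.recv()/send() on a ProactorEventLoop.
+#
+#     p = Popen(['python', '-c', 'print("hi")'], stdout=PIPE)
+#     # p.stdout.handle can be registered with the proactor for reads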