diff options
author | Victor Stinner <vstinner@redhat.com> | 2015-07-06 22:11:55 +0200 |
---|---|---|
committer | Victor Stinner <vstinner@redhat.com> | 2015-07-06 22:30:54 +0200 |
commit | 1975461f4ccf0d9ffd9e20e4dd4a4650ad6a0c18 (patch) | |
tree | d185d0cf9c4bd65aebf93fcdd24eb15dac4383db /trollius | |
parent | 728a91239123913b9161e8ca174cf86288f8c7d7 (diff) | |
download | trollius-git-1975461f4ccf0d9ffd9e20e4dd4a4650ad6a0c18.tar.gz |
rename asyncio/ directory to trollius/
Diffstat (limited to 'trollius')
-rw-r--r-- | trollius/__init__.py | 50 | ||||
-rw-r--r-- | trollius/base_events.py | 1240 | ||||
-rw-r--r-- | trollius/base_subprocess.py | 275 | ||||
-rw-r--r-- | trollius/constants.py | 7 | ||||
-rw-r--r-- | trollius/coroutines.py | 301 | ||||
-rw-r--r-- | trollius/events.py | 611 | ||||
-rw-r--r-- | trollius/futures.py | 413 | ||||
-rw-r--r-- | trollius/locks.py | 470 | ||||
-rw-r--r-- | trollius/log.py | 7 | ||||
-rw-r--r-- | trollius/proactor_events.py | 547 | ||||
-rw-r--r-- | trollius/protocols.py | 134 | ||||
-rw-r--r-- | trollius/queues.py | 293 | ||||
-rw-r--r-- | trollius/selector_events.py | 1068 | ||||
-rw-r--r-- | trollius/selectors.py | 594 | ||||
-rw-r--r-- | trollius/sslproto.py | 668 | ||||
-rw-r--r-- | trollius/streams.py | 501 | ||||
-rw-r--r-- | trollius/subprocess.py | 215 | ||||
-rw-r--r-- | trollius/tasks.py | 681 | ||||
-rw-r--r-- | trollius/test_support.py | 308 | ||||
-rw-r--r-- | trollius/test_utils.py | 446 | ||||
-rw-r--r-- | trollius/transports.py | 300 | ||||
-rw-r--r-- | trollius/unix_events.py | 998 | ||||
-rw-r--r-- | trollius/windows_events.py | 774 | ||||
-rw-r--r-- | trollius/windows_utils.py | 223 |
24 files changed, 11124 insertions, 0 deletions
diff --git a/trollius/__init__.py b/trollius/__init__.py new file mode 100644 index 0000000..011466b --- /dev/null +++ b/trollius/__init__.py @@ -0,0 +1,50 @@ +"""The asyncio package, tracking PEP 3156.""" + +import sys + +# The selectors module is in the stdlib in Python 3.4 but not in 3.3. +# Do this first, so the other submodules can use "from . import selectors". +# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer. +try: + from . import selectors +except ImportError: + import selectors # Will also be exported. + +if sys.platform == 'win32': + # Similar thing for _overlapped. + try: + from . import _overlapped + except ImportError: + import _overlapped # Will also be exported. + +# This relies on each of the submodules having an __all__ variable. +from .base_events import * +from .coroutines import * +from .events import * +from .futures import * +from .locks import * +from .protocols import * +from .queues import * +from .streams import * +from .subprocess import * +from .tasks import * +from .transports import * + +__all__ = (base_events.__all__ + + coroutines.__all__ + + events.__all__ + + futures.__all__ + + locks.__all__ + + protocols.__all__ + + queues.__all__ + + streams.__all__ + + subprocess.__all__ + + tasks.__all__ + + transports.__all__) + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * + __all__ += windows_events.__all__ +else: + from .unix_events import * # pragma: no cover + __all__ += unix_events.__all__ diff --git a/trollius/base_events.py b/trollius/base_events.py new file mode 100644 index 0000000..5a536a2 --- /dev/null +++ b/trollius/base_events.py @@ -0,0 +1,1240 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of I/O events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import inspect +import logging +import os +import socket +import subprocess +import threading +import time +import traceback +import sys +import warnings + +from . import coroutines +from . import events +from . import futures +from . import tasks +from .coroutines import coroutine +from .log import logger + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + +# Minimum number of _scheduled timer handles before cleanup of +# cancelled handles is performed. +_MIN_SCHEDULED_TIMER_HANDLES = 100 + +# Minimum fraction of _scheduled timer handles that are cancelled +# before cleanup of cancelled handles is performed. +_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5 + +def _format_handle(handle): + cb = handle._callback + if inspect.ismethod(cb) and isinstance(cb.__self__, tasks.Task): + # format the task + return repr(cb.__self__) + else: + return str(handle) + + +def _format_pipe(fd): + if fd == subprocess.PIPE: + return '<pipe>' + elif fd == subprocess.STDOUT: + return '<stdout>' + else: + return repr(fd) + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _check_resolved_address(sock, address): + # Ensure that the address is already resolved to avoid the trap of hanging + # the entire event loop when the address requires doing a DNS lookup. + # + # getaddrinfo() is slow (around 10 us per call): this function should only + # be called in debug mode + family = sock.family + + if family == socket.AF_INET: + host, port = address + elif family == socket.AF_INET6: + host, port = address[:2] + else: + return + + # On Windows, socket.inet_pton() is only available since Python 3.4 + if hasattr(socket, 'inet_pton'): + # getaddrinfo() is slow and has known issue: prefer inet_pton() + # if available + try: + socket.inet_pton(family, host) + except OSError as exc: + raise ValueError("address must be resolved (IP address), " + "got host %r: %s" + % (host, exc)) + else: + # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is + # already resolved. + type_mask = 0 + if hasattr(socket, 'SOCK_NONBLOCK'): + type_mask |= socket.SOCK_NONBLOCK + if hasattr(socket, 'SOCK_CLOEXEC'): + type_mask |= socket.SOCK_CLOEXEC + try: + socket.getaddrinfo(host, port, + family=family, + type=(sock.type & ~type_mask), + proto=sock.proto, + flags=socket.AI_NUMERICHOST) + except socket.gaierror as err: + raise ValueError("address must be resolved (IP address), " + "got host %r: %s" + % (host, err)) + +def _raise_stop_error(*args): + raise _StopError + + +def _run_until_complete_cb(fut): + exc = fut._exception + if (isinstance(exc, BaseException) + and not isinstance(exc, Exception)): + # Issue #22429: run_forever() already finished, no need to + # stop it. + return + _raise_stop_error() + + +class Server(events.AbstractServer): + + def __init__(self, loop, sockets): + self._loop = loop + self.sockets = sockets + self._active_count = 0 + self._waiters = [] + + def __repr__(self): + return '<%s sockets=%r>' % (self.__class__.__name__, self.sockets) + + def _attach(self): + assert self.sockets is not None + self._active_count += 1 + + def _detach(self): + assert self._active_count > 0 + self._active_count -= 1 + if self._active_count == 0 and self.sockets is None: + self._wakeup() + + def close(self): + sockets = self.sockets + if sockets is None: + return + self.sockets = None + for sock in sockets: + self._loop._stop_serving(sock) + if self._active_count == 0: + self._wakeup() + + def _wakeup(self): + waiters = self._waiters + self._waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + @coroutine + def wait_closed(self): + if self.sockets is None or self._waiters is None: + return + waiter = futures.Future(loop=self._loop) + self._waiters.append(waiter) + yield from waiter + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._timer_cancelled_count = 0 + self._closed = False + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + # Identifier of the thread running the event loop, or None if the + # event loop is not running + self._thread_id = None + self._clock_resolution = time.get_clock_info('monotonic').resolution + self._exception_handler = None + self.set_debug((not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG')))) + # In debug mode, if the execution of a callback or a step of a task + # exceed this duration in seconds, the slow callback/task is logged. + self.slow_callback_duration = 0.1 + self._current_handle = None + self._task_factory = None + self._coroutine_wrapper_set = False + + def __repr__(self): + return ('<%s running=%s closed=%s debug=%s>' + % (self.__class__.__name__, self.is_running(), + self.is_closed(), self.get_debug())) + + def create_task(self, coro): + """Schedule a coroutine object. + + Return a task object. + """ + self._check_closed() + if self._task_factory is None: + task = tasks.Task(coro, loop=self) + if task._source_traceback: + del task._source_traceback[-1] + else: + task = self._task_factory(self, coro) + return task + + def set_task_factory(self, factory): + """Set a task factory that will be used by loop.create_task(). + + If factory is None the default task factory will be set. + + If factory is a callable, it should have a signature matching + '(loop, coro)', where 'loop' will be a reference to the active + event loop, 'coro' will be a coroutine object. The callable + must return a Future. + """ + if factory is not None and not callable(factory): + raise TypeError('task factory must be a callable or None') + self._task_factory = factory + + def get_task_factory(self): + """Return a task factory, or None if the default one is in use.""" + return self._task_factory + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _write_to_self(self): + """Write a byte to self-pipe, to wake up the event loop. + + This may be called from a different thread. + + The subclass is responsible for implementing the self-pipe. + """ + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def _check_closed(self): + if self._closed: + raise RuntimeError('Event loop is closed') + + def run_forever(self): + """Run until stop() is called.""" + self._check_closed() + if self.is_running(): + raise RuntimeError('Event loop is running.') + self._set_coroutine_wrapper(self._debug) + self._thread_id = threading.get_ident() + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._thread_id = None + self._set_coroutine_wrapper(False) + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + WARNING: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + self._check_closed() + + new_task = not isinstance(future, futures.Future) + future = tasks.ensure_future(future, loop=self) + if new_task: + # An exception is raised if the future didn't complete, so there + # is no need to log the "destroy pending task" message + future._log_destroy_pending = False + + future.add_done_callback(_run_until_complete_cb) + try: + self.run_forever() + except: + if new_task and future.done() and not future.cancelled(): + # The coroutine raised a BaseException. Consume the exception + # to not log a warning, the caller doesn't have access to the + # local task. + future.exception() + raise + future.remove_done_callback(_run_until_complete_cb) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. Callbacks + scheduled after stop() is called will not run. However, those callbacks + will run if run_forever is called again later. + """ + self.call_soon(_raise_stop_error) + + def close(self): + """Close the event loop. + + This clears the queues and shuts down the executor, + but does not wait for the executor to finish. + + The event loop must not be running. + """ + if self.is_running(): + raise RuntimeError("Cannot close a running event loop") + if self._closed: + return + if self._debug: + logger.debug("Close %r", self) + self._closed = True + self._ready.clear() + self._scheduled.clear() + executor = self._default_executor + if executor is not None: + self._default_executor = None + executor.shutdown(wait=False) + + def is_closed(self): + """Returns True if the event loop was closed.""" + return self._closed + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if not self.is_closed(): + warnings.warn("unclosed event loop %r" % self, ResourceWarning) + if not self.is_running(): + self.close() + + def is_running(self): + """Returns True if the event loop is running.""" + return (self._thread_id is not None) + + def time(self): + """Return the time according to the event loop's clock. + + This is a float expressed in seconds since an epoch, but the + epoch, precision, accuracy and drift are unspecified and may + differ per event loop. + """ + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always relative to the current time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + timer = self.call_at(self.time() + delay, callback, *args) + if timer._source_traceback: + del timer._source_traceback[-1] + return timer + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time. + + Absolute time corresponds to the event loop's time() method. + """ + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with call_at()") + self._check_closed() + if self._debug: + self._check_thread() + timer = events.TimerHandle(when, callback, args, self) + if timer._source_traceback: + del timer._source_traceback[-1] + heapq.heappush(self._scheduled, timer) + timer._scheduled = True + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue: callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if self._debug: + self._check_thread() + handle = self._call_soon(callback, args) + if handle._source_traceback: + del handle._source_traceback[-1] + return handle + + def _call_soon(self, callback, args): + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with call_soon()") + self._check_closed() + handle = events.Handle(callback, args, self) + if handle._source_traceback: + del handle._source_traceback[-1] + self._ready.append(handle) + return handle + + def _check_thread(self): + """Check that the current thread is the thread running the event loop. + + Non-thread-safe methods of this class make this assumption and will + likely behave incorrectly when the assumption is violated. + + Should only be called when (self._debug == True). The caller is + responsible for checking this condition for performance reasons. + """ + if self._thread_id is None: + return + thread_id = threading.get_ident() + if thread_id != self._thread_id: + raise RuntimeError( + "Non-thread-safe operation invoked on an event loop other " + "than the current one") + + def call_soon_threadsafe(self, callback, *args): + """Like call_soon(), but thread-safe.""" + handle = self._call_soon(callback, args) + if handle._source_traceback: + del handle._source_traceback[-1] + self._write_to_self() + return handle + + def run_in_executor(self, executor, func, *args): + if (coroutines.iscoroutine(func) + or coroutines.iscoroutinefunction(func)): + raise TypeError("coroutines cannot be used with run_in_executor()") + self._check_closed() + if isinstance(func, events.Handle): + assert not args + assert not isinstance(func, events.TimerHandle) + if func._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + func, args = func._callback, func._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(func, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def _getaddrinfo_debug(self, host, port, family, type, proto, flags): + msg = ["%s:%r" % (host, port)] + if family: + msg.append('family=%r' % family) + if type: + msg.append('type=%r' % type) + if proto: + msg.append('proto=%r' % proto) + if flags: + msg.append('flags=%r' % flags) + msg = ', '.join(msg) + logger.debug('Get address info %s', msg) + + t0 = self.time() + addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags) + dt = self.time() - t0 + + msg = ('Getting address info %s took %.3f ms: %r' + % (msg, dt * 1e3, addrinfo)) + if dt >= self.slow_callback_duration: + logger.info(msg) + else: + logger.debug(msg) + return addrinfo + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + if self._debug: + return self.run_in_executor(None, self._getaddrinfo_debug, + host, port, family, type, proto, flags) + else: + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + """Connect to a TCP server. + + Create a streaming transport connection to a given Internet host and + port: socket family AF_INET or socket.AF_INET6 depending on host (or + family if specified), socket type SOCK_STREAM. protocol_factory must be + a callable returning a protocol instance. + + This method is a coroutine which will try to establish the connection + in the background. When successful, the coroutine returns a + (transport, protocol) pair. + """ + if server_hostname is not None and not ssl: + raise ValueError('server_hostname is only meaningful with ssl') + + if server_hostname is None and ssl: + # Use host as default for server_hostname. It is an error + # if host is empty or not set, e.g. when an + # already-connected socket was passed or when only a port + # is given. To avoid this error, you can pass + # server_hostname='' -- this will bypass the hostname + # check. (This also means that if host is a numeric + # IP/IPv6 address, we will attempt to verify that exact + # address; this will probably fail, but it is possible to + # create a certificate for a specific IP address, so we + # don't judge it here.) + if not host: + raise ValueError('You must set server_hostname ' + 'when using ssl without a host') + server_hostname = host + + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + if self._debug: + logger.debug("connect %r to %r", sock, address) + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + if self._debug: + # Get the socket from the transport because SSL transport closes + # the old socket and creates a new SSL socket + sock = transport.get_extra_info('socket') + logger.debug("%r connected to %s:%r: (%r, %r)", + sock, host, port, transport, protocol) + return transport, protocol + + @coroutine + def _create_connection_transport(self, sock, protocol_factory, ssl, + server_hostname): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, + server_side=False, server_hostname=server_hostname) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + try: + yield from waiter + except: + transport.close() + raise + + return transport, protocol + + @coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join address by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_datagram_transport(sock, protocol, r_addr, + waiter) + if self._debug: + if local_addr: + logger.info("Datagram endpoint local_addr=%r remote_addr=%r " + "created: (%r, %r)", + local_addr, remote_addr, transport, protocol) + else: + logger.debug("Datagram endpoint remote_addr=%r created: " + "(%r, %r)", + remote_addr, transport, protocol) + + try: + yield from waiter + except: + transport.close() + raise + + return transport, protocol + + @coroutine + def create_server(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """Create a TCP server bound to host and port. + + Return a Server object which can be used to stop the service. + + This method is a coroutine. + """ + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + except socket.error: + # Assume it's a bad family/type/protocol combination. + if self._debug: + logger.warning('create_server() failed to create ' + 'socket.socket(%r, %r, %r)', + af, socktype, proto, exc_info=True) + continue + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError('Neither host/port nor sock were specified') + sockets = [sock] + + server = Server(self, sockets) + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + if self._debug: + logger.info("%r is serving", server) + return server + + @coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter) + + try: + yield from waiter + except: + transport.close() + raise + + if self._debug: + logger.debug('Read pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) + return transport, protocol + + @coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter) + + try: + yield from waiter + except: + transport.close() + raise + + if self._debug: + logger.debug('Write pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) + return transport, protocol + + def _log_subprocess(self, msg, stdin, stdout, stderr): + info = [msg] + if stdin is not None: + info.append('stdin=%s' % _format_pipe(stdin)) + if stdout is not None and stderr == subprocess.STDOUT: + info.append('stdout=stderr=%s' % _format_pipe(stdout)) + else: + if stdout is not None: + info.append('stdout=%s' % _format_pipe(stdout)) + if stderr is not None: + info.append('stderr=%s' % _format_pipe(stderr)) + logger.debug(' '.join(info)) + + @coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + if not isinstance(cmd, (bytes, str)): + raise ValueError("cmd must be a string") + if universal_newlines: + raise ValueError("universal_newlines must be False") + if not shell: + raise ValueError("shell must be True") + if bufsize != 0: + raise ValueError("bufsize must be 0") + protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'run shell command %r' % cmd + self._log_subprocess(debug_log, stdin, stdout, stderr) + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) + return transport, protocol + + @coroutine + def subprocess_exec(self, protocol_factory, program, *args, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=False, + shell=False, bufsize=0, **kwargs): + if universal_newlines: + raise ValueError("universal_newlines must be False") + if shell: + raise ValueError("shell must be False") + if bufsize != 0: + raise ValueError("bufsize must be 0") + popen_args = (program,) + args + for arg in popen_args: + if not isinstance(arg, (str, bytes)): + raise TypeError("program arguments must be " + "a bytes or text string, not %s" + % type(arg).__name__) + protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'execute program %r' % program + self._log_subprocess(debug_log, stdin, stdout, stderr) + transport = yield from self._make_subprocess_transport( + protocol, popen_args, False, stdin, stdout, stderr, + bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) + return transport, protocol + + def set_exception_handler(self, handler): + """Set handler as the new event loop exception handler. + + If handler is None, the default exception handler will + be set. + + If handler is a callable object, it should have a + signature matching '(loop, context)', where 'loop' + will be a reference to the active event loop, 'context' + will be a dict object (see `call_exception_handler()` + documentation for details about context). + """ + if handler is not None and not callable(handler): + raise TypeError('A callable object or None is expected, ' + 'got {!r}'.format(handler)) + self._exception_handler = handler + + def default_exception_handler(self, context): + """Default exception handler. + + This is called when an exception occurs and no exception + handler is set, and can be called by a custom exception + handler that wants to defer to the default behavior. + + The context parameter has the same meaning as in + `call_exception_handler()`. + """ + message = context.get('message') + if not message: + message = 'Unhandled exception in event loop' + + exception = context.get('exception') + if exception is not None: + exc_info = (type(exception), exception, exception.__traceback__) + else: + exc_info = False + + if ('source_traceback' not in context + and self._current_handle is not None + and self._current_handle._source_traceback): + context['handle_traceback'] = self._current_handle._source_traceback + + log_lines = [message] + for key in sorted(context): + if key in {'message', 'exception'}: + continue + value = context[key] + if key == 'source_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Object created at (most recent call last):\n' + value += tb.rstrip() + elif key == 'handle_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Handle created at (most recent call last):\n' + value += tb.rstrip() + else: + value = repr(value) + log_lines.append('{}: {}'.format(key, value)) + + logger.error('\n'.join(log_lines), exc_info=exc_info) + + def call_exception_handler(self, context): + """Call the current event loop's exception handler. + + The context argument is a dict containing the following keys: + + - 'message': Error message; + - 'exception' (optional): Exception object; + - 'future' (optional): Future instance; + - 'handle' (optional): Handle instance; + - 'protocol' (optional): Protocol instance; + - 'transport' (optional): Transport instance; + - 'socket' (optional): Socket instance. + + New keys maybe introduced in the future. + + Note: do not overload this method in an event loop subclass. + For custom exception handling, use the + `set_exception_handler()` method. + """ + if self._exception_handler is None: + try: + self.default_exception_handler(context) + except Exception: + # Second protection layer for unexpected errors + # in the default implementation, as well as for subclassed + # event loops with overloaded "default_exception_handler". + logger.error('Exception in default exception handler', + exc_info=True) + else: + try: + self._exception_handler(self, context) + except Exception as exc: + # Exception in the user set custom exception handler. + try: + # Let's try default handler. + self.default_exception_handler({ + 'message': 'Unhandled error in exception handler', + 'exception': exc, + 'context': context, + }) + except Exception: + # Guard 'default_exception_handler' in case it is + # overloaded. + logger.error('Exception in default exception handler ' + 'while handling an unexpected error ' + 'in custom exception handler', + exc_info=True) + + def _add_callback(self, handle): + """Add a Handle to _scheduled (TimerHandle) or _ready.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + assert not isinstance(handle, events.TimerHandle) + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + if handle._scheduled: + self._timer_cancelled_count += 1 + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + + sched_count = len(self._scheduled) + if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and + self._timer_cancelled_count / sched_count > + _MIN_CANCELLED_TIMER_HANDLES_FRACTION): + # Remove delayed calls that were cancelled if their number + # is too high + new_scheduled = [] + for handle in self._scheduled: + if handle._cancelled: + handle._scheduled = False + else: + new_scheduled.append(handle) + + heapq.heapify(new_scheduled) + self._scheduled = new_scheduled + self._timer_cancelled_count = 0 + else: + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + self._timer_cancelled_count -= 1 + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + timeout = max(0, when - self.time()) + + if self._debug and timeout != 0: + t0 = self.time() + event_list = self._selector.select(timeout) + dt = self.time() - t0 + if dt >= 1.0: + level = logging.INFO + else: + level = logging.DEBUG + nevent = len(event_list) + if timeout is None: + logger.log(level, 'poll took %.3f ms: %s events', + dt * 1e3, nevent) + elif nevent: + logger.log(level, + 'poll %.3f ms took %.3f ms: %s events', + timeout * 1e3, dt * 1e3, nevent) + elif dt >= 1.0: + logger.log(level, + 'poll %.3f ms took %.3f ms: timeout', + timeout * 1e3, dt * 1e3) + else: + event_list = self._selector.select(timeout) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + end_time = self.time() + self._clock_resolution + while self._scheduled: + handle = self._scheduled[0] + if handle._when >= end_time: + break + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is thread-safe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if handle._cancelled: + continue + if self._debug: + try: + self._current_handle = handle + t0 = self.time() + handle._run() + dt = self.time() - t0 + if dt >= self.slow_callback_duration: + logger.warning('Executing %s took %.3f seconds', + _format_handle(handle), dt) + finally: + self._current_handle = None + else: + handle._run() + handle = None # Needed to break cycles when an exception occurs. + + def _set_coroutine_wrapper(self, enabled): + try: + set_wrapper = sys.set_coroutine_wrapper + get_wrapper = sys.get_coroutine_wrapper + except AttributeError: + return + + enabled = bool(enabled) + if self._coroutine_wrapper_set is enabled: + return + + wrapper = coroutines.debug_wrapper + current_wrapper = get_wrapper() + + if enabled: + if current_wrapper not in (None, wrapper): + warnings.warn( + "loop.set_debug(True): cannot set debug coroutine " + "wrapper; another wrapper is already set %r" % + current_wrapper, RuntimeWarning) + else: + set_wrapper(wrapper) + self._coroutine_wrapper_set = True + else: + if current_wrapper not in (None, wrapper): + warnings.warn( + "loop.set_debug(False): cannot unset debug coroutine " + "wrapper; another wrapper was set %r" % + current_wrapper, RuntimeWarning) + else: + set_wrapper(None) + self._coroutine_wrapper_set = False + + def get_debug(self): + return self._debug + + def set_debug(self, enabled): + self._debug = enabled + + if self.is_running(): + self._set_coroutine_wrapper(enabled) diff --git a/trollius/base_subprocess.py b/trollius/base_subprocess.py new file mode 100644 index 0000000..c1477b8 --- /dev/null +++ b/trollius/base_subprocess.py @@ -0,0 +1,275 @@ +import collections +import subprocess +import sys +import warnings + +from . import futures +from . import protocols +from . import transports +from .coroutines import coroutine +from .log import logger + + +class BaseSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + waiter=None, extra=None, **kwargs): + super().__init__(extra) + self._closed = False + self._protocol = protocol + self._loop = loop + self._proc = None + self._pid = None + self._returncode = None + self._exit_waiters = [] + self._pending_calls = collections.deque() + self._pipes = {} + self._finished = False + + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + + # Create the child process: set the _proc attribute + self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, + stderr=stderr, bufsize=bufsize, **kwargs) + self._pid = self._proc.pid + self._extra['subprocess'] = self._proc + + if self._loop.get_debug(): + if isinstance(args, (bytes, str)): + program = args + else: + program = args[0] + logger.debug('process %r created: pid %s', + program, self._pid) + + self._loop.create_task(self._connect_pipes(waiter)) + + def __repr__(self): + info = [self.__class__.__name__] + if self._closed: + info.append('closed') + if self._pid is not None: + info.append('pid=%s' % self._pid) + if self._returncode is not None: + info.append('returncode=%s' % self._returncode) + elif self._pid is not None: + info.append('running') + else: + info.append('not started') + + stdin = self._pipes.get(0) + if stdin is not None: + info.append('stdin=%s' % stdin.pipe) + + stdout = self._pipes.get(1) + stderr = self._pipes.get(2) + if stdout is not None and stderr is stdout: + info.append('stdout=stderr=%s' % stdout.pipe) + else: + if stdout is not None: + info.append('stdout=%s' % stdout.pipe) + if stderr is not None: + info.append('stderr=%s' % stderr.pipe) + + return '<%s>' % ' '.join(info) + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + raise NotImplementedError + + def close(self): + if self._closed: + return + self._closed = True + + for proto in self._pipes.values(): + if proto is None: + continue + proto.pipe.close() + + if (self._proc is not None + # the child process finished? + and self._returncode is None + # the child process finished but the transport was not notified yet? + and self._proc.poll() is None + ): + if self._loop.get_debug(): + logger.warning('Close running child process: kill %r', self) + + try: + self._proc.kill() + except ProcessLookupError: + pass + + # Don't clear the _proc reference yet: _post_init() may still run + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if not self._closed: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self.close() + + def get_pid(self): + return self._pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def _check_proc(self): + if self._proc is None: + raise ProcessLookupError() + + def send_signal(self, signal): + self._check_proc() + self._proc.send_signal(signal) + + def terminate(self): + self._check_proc() + self._proc.terminate() + + def kill(self): + self._check_proc() + self._proc.kill() + + @coroutine + def _connect_pipes(self, waiter): + try: + proc = self._proc + loop = self._loop + + if proc.stdin is not None: + _, pipe = yield from loop.connect_write_pipe( + lambda: WriteSubprocessPipeProto(self, 0), + proc.stdin) + self._pipes[0] = pipe + + if proc.stdout is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 1), + proc.stdout) + self._pipes[1] = pipe + + if proc.stderr is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 2), + proc.stderr) + self._pipes[2] = pipe + + assert self._pending_calls is not None + + loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + loop.call_soon(callback, *data) + self._pending_calls = None + except Exception as exc: + if waiter is not None and not waiter.cancelled(): + waiter.set_exception(exc) + else: + if waiter is not None and not waiter.cancelled(): + waiter.set_result(None) + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + if self._loop.get_debug(): + logger.info('%r exited with return code %r', + self, returncode) + self._returncode = returncode + self._call(self._protocol.process_exited) + self._try_finish() + + # wake up futures waiting for wait() + for waiter in self._exit_waiters: + if not waiter.cancelled(): + waiter.set_result(returncode) + self._exit_waiters = None + + @coroutine + def _wait(self): + """Wait until the process exit and return the process return code. + + This method is a coroutine.""" + if self._returncode is not None: + return self._returncode + + waiter = futures.Future(loop=self._loop) + self._exit_waiters.append(waiter) + return (yield from waiter) + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._call(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._loop = None + self._proc = None + self._protocol = None + + +class WriteSubprocessPipeProto(protocols.BaseProtocol): + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.pipe = None + self.disconnected = False + + def connection_made(self, transport): + self.pipe = transport + + def __repr__(self): + return ('<%s fd=%s pipe=%r>' + % (self.__class__.__name__, self.fd, self.pipe)) + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + self.proc = None + + def pause_writing(self): + self.proc._protocol.pause_writing() + + def resume_writing(self): + self.proc._protocol.resume_writing() + + +class ReadSubprocessPipeProto(WriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) diff --git a/trollius/constants.py b/trollius/constants.py new file mode 100644 index 0000000..f9e1232 --- /dev/null +++ b/trollius/constants.py @@ -0,0 +1,7 @@ +"""Constants.""" + +# After the connection is lost, log warnings after this many write()s. +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 + +# Seconds to wait before retrying accept(). +ACCEPT_RETRY_DELAY = 1 diff --git a/trollius/coroutines.py b/trollius/coroutines.py new file mode 100644 index 0000000..15475f2 --- /dev/null +++ b/trollius/coroutines.py @@ -0,0 +1,301 @@ +__all__ = ['coroutine', + 'iscoroutinefunction', 'iscoroutine'] + +import functools +import inspect +import opcode +import os +import sys +import traceback +import types + +from . import events +from . import futures +from .log import logger + + +_PY35 = sys.version_info >= (3, 5) + + +# Opcode of "yield from" instruction +_YIELD_FROM = opcode.opmap['YIELD_FROM'] + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) + + +try: + _types_coroutine = types.coroutine +except AttributeError: + _types_coroutine = None + +try: + _inspect_iscoroutinefunction = inspect.iscoroutinefunction +except AttributeError: + _inspect_iscoroutinefunction = lambda func: False + +try: + from collections.abc import Coroutine as _CoroutineABC, \ + Awaitable as _AwaitableABC +except ImportError: + _CoroutineABC = _AwaitableABC = None + + +# Check for CPython issue #21209 +def has_yield_from_bug(): + class MyGen: + def __init__(self): + self.send_args = None + def __iter__(self): + return self + def __next__(self): + return 42 + def send(self, *what): + self.send_args = what + return None + def yield_from_gen(gen): + yield from gen + value = (1, 2, 3) + gen = MyGen() + coro = yield_from_gen(gen) + next(coro) + coro.send(value) + return gen.send_args != (value,) +_YIELD_FROM_BUG = has_yield_from_bug() +del has_yield_from_bug + + +def debug_wrapper(gen): + # This function is called from 'sys.set_coroutine_wrapper'. + # We only wrap here coroutines defined via 'async def' syntax. + # Generator-based coroutines are wrapped in @coroutine + # decorator. + return CoroWrapper(gen, None) + + +class CoroWrapper: + # Wrapper for coroutine object in _DEBUG mode. + + def __init__(self, gen, func=None): + assert inspect.isgenerator(gen) or inspect.iscoroutine(gen), gen + self.gen = gen + self.func = func # Used to unwrap @coroutine decorator + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + self.__name__ = getattr(gen, '__name__', None) + self.__qualname__ = getattr(gen, '__qualname__', None) + + def __repr__(self): + coro_repr = _format_coroutine(self) + if self._source_traceback: + frame = self._source_traceback[-1] + coro_repr += ', created at %s:%s' % (frame[0], frame[1]) + return '<%s %s>' % (self.__class__.__name__, coro_repr) + + def __iter__(self): + return self + + def __next__(self): + return self.gen.send(None) + + if _YIELD_FROM_BUG: + # For for CPython issue #21209: using "yield from" and a custom + # generator, generator.send(tuple) unpacks the tuple instead of passing + # the tuple unchanged. Check if the caller is a generator using "yield + # from" to decide if the parameter should be unpacked or not. + def send(self, *value): + frame = sys._getframe() + caller = frame.f_back + assert caller.f_lasti >= 0 + if caller.f_code.co_code[caller.f_lasti] != _YIELD_FROM: + value = value[0] + return self.gen.send(value) + else: + def send(self, value): + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + @property + def gi_frame(self): + return self.gen.gi_frame + + @property + def gi_running(self): + return self.gen.gi_running + + @property + def gi_code(self): + return self.gen.gi_code + + if _PY35: + + __await__ = __iter__ # make compatible with 'await' expression + + @property + def gi_yieldfrom(self): + return self.gen.gi_yieldfrom + + @property + def cr_await(self): + return self.gen.cr_await + + @property + def cr_running(self): + return self.gen.cr_running + + @property + def cr_code(self): + return self.gen.cr_code + + @property + def cr_frame(self): + return self.gen.cr_frame + + def __del__(self): + # Be careful accessing self.gen.frame -- self.gen might not exist. + gen = getattr(self, 'gen', None) + frame = getattr(gen, 'gi_frame', None) + if frame is None: + frame = getattr(gen, 'cr_frame', None) + if frame is not None and frame.f_lasti == -1: + msg = '%r was never yielded from' % self + tb = getattr(self, '_source_traceback', ()) + if tb: + tb = ''.join(traceback.format_list(tb)) + msg += ('\nCoroutine object created at ' + '(most recent call last):\n') + msg += tb.rstrip() + logger.error(msg) + + +def coroutine(func): + """Decorator to mark coroutines. + + If the coroutine is not yielded from before it is destroyed, + an error message is logged. + """ + if _inspect_iscoroutinefunction(func): + # In Python 3.5 that's all we need to do for coroutines + # defiend with "async def". + # Wrapping in CoroWrapper will happen via + # 'sys.set_coroutine_wrapper' function. + return func + + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + elif _AwaitableABC is not None: + # If 'func' returns an Awaitable (new in 3.5) we + # want to run it. + try: + await_meth = res.__await__ + except AttributeError: + pass + else: + if isinstance(res, _AwaitableABC): + res = yield from await_meth() + return res + + if not _DEBUG: + if _types_coroutine is None: + wrapper = coro + else: + wrapper = _types_coroutine(coro) + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func=func) + if w._source_traceback: + del w._source_traceback[-1] + # Python < 3.5 does not implement __qualname__ + # on generator objects, so we set it manually. + # We use getattr as some callables (such as + # functools.partial may lack __qualname__). + w.__name__ = getattr(func, '__name__', None) + w.__qualname__ = getattr(func, '__qualname__', None) + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper + + +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (getattr(func, '_is_coroutine', False) or + _inspect_iscoroutinefunction(func)) + + +_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) +if _CoroutineABC is not None: + _COROUTINE_TYPES += (_CoroutineABC,) + + +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return isinstance(obj, _COROUTINE_TYPES) + + +def _format_coroutine(coro): + assert iscoroutine(coro) + + coro_name = None + if isinstance(coro, CoroWrapper): + func = coro.func + coro_name = coro.__qualname__ + if coro_name is not None: + coro_name = '{}()'.format(coro_name) + else: + func = coro + + if coro_name is None: + coro_name = events._format_callback(func, ()) + + try: + coro_code = coro.gi_code + except AttributeError: + coro_code = coro.cr_code + + try: + coro_frame = coro.gi_frame + except AttributeError: + coro_frame = coro.cr_frame + + filename = coro_code.co_filename + if (isinstance(coro, CoroWrapper) + and not inspect.isgeneratorfunction(coro.func) + and coro.func is not None): + filename, lineno = events._get_function_source(coro.func) + if coro_frame is None: + coro_repr = ('%s done, defined at %s:%s' + % (coro_name, filename, lineno)) + else: + coro_repr = ('%s running, defined at %s:%s' + % (coro_name, filename, lineno)) + elif coro_frame is not None: + lineno = coro_frame.f_lineno + coro_repr = ('%s running at %s:%s' + % (coro_name, filename, lineno)) + else: + lineno = coro_code.co_firstlineno + coro_repr = ('%s done, defined at %s:%s' + % (coro_name, filename, lineno)) + + return coro_repr diff --git a/trollius/events.py b/trollius/events.py new file mode 100644 index 0000000..496075b --- /dev/null +++ b/trollius/events.py @@ -0,0 +1,611 @@ +"""Event loop and event loop policy.""" + +__all__ = ['AbstractEventLoopPolicy', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + 'get_child_watcher', 'set_child_watcher', + ] + +import functools +import inspect +import reprlib +import socket +import subprocess +import sys +import threading +import traceback + + +_PY34 = sys.version_info >= (3, 4) + + +def _get_function_source(func): + if _PY34: + func = inspect.unwrap(func) + elif hasattr(func, '__wrapped__'): + func = func.__wrapped__ + if inspect.isfunction(func): + code = func.__code__ + return (code.co_filename, code.co_firstlineno) + if isinstance(func, functools.partial): + return _get_function_source(func.func) + if _PY34 and isinstance(func, functools.partialmethod): + return _get_function_source(func.func) + return None + + +def _format_args(args): + """Format function arguments. + + Special case for a single parameter: ('hello',) is formatted as ('hello'). + """ + # use reprlib to limit the length of the output + args_repr = reprlib.repr(args) + if len(args) == 1 and args_repr.endswith(',)'): + args_repr = args_repr[:-2] + ')' + return args_repr + + +def _format_callback(func, args, suffix=''): + if isinstance(func, functools.partial): + if args is not None: + suffix = _format_args(args) + suffix + return _format_callback(func.func, func.args, suffix) + + if hasattr(func, '__qualname__'): + func_repr = getattr(func, '__qualname__') + elif hasattr(func, '__name__'): + func_repr = getattr(func, '__name__') + else: + func_repr = repr(func) + + if args is not None: + func_repr += _format_args(args) + if suffix: + func_repr += suffix + return func_repr + +def _format_callback_source(func, args): + func_repr = _format_callback(func, args) + source = _get_function_source(func) + if source: + func_repr += ' at %s:%s' % source + return func_repr + + +class Handle: + """Object returned by callback registration methods.""" + + __slots__ = ('_callback', '_args', '_cancelled', '_loop', + '_source_traceback', '_repr', '__weakref__') + + def __init__(self, callback, args, loop): + assert not isinstance(callback, Handle), 'A Handle is not a callback' + self._loop = loop + self._callback = callback + self._args = args + self._cancelled = False + self._repr = None + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + else: + self._source_traceback = None + + def _repr_info(self): + info = [self.__class__.__name__] + if self._cancelled: + info.append('cancelled') + if self._callback is not None: + info.append(_format_callback_source(self._callback, self._args)) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + if self._repr is not None: + return self._repr + info = self._repr_info() + return '<%s>' % ' '.join(info) + + def cancel(self): + if not self._cancelled: + self._cancelled = True + if self._loop.get_debug(): + # Keep a representation in debug mode to keep callback and + # parameters. For example, to log the warning + # "Executing <Handle...> took 2.5 second" + self._repr = repr(self) + self._callback = None + self._args = None + + def _run(self): + try: + self._callback(*self._args) + except Exception as exc: + cb = _format_callback_source(self._callback, self._args) + msg = 'Exception in callback {}'.format(cb) + context = { + 'message': msg, + 'exception': exc, + 'handle': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self = None # Needed to break cycles when an exception occurs. + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + __slots__ = ['_scheduled', '_when'] + + def __init__(self, when, callback, args, loop): + assert when is not None + super().__init__(callback, args, loop) + if self._source_traceback: + del self._source_traceback[-1] + self._when = when + self._scheduled = False + + def _repr_info(self): + info = super()._repr_info() + pos = 2 if self._cancelled else 1 + info.insert(pos, 'when=%s' % self._when) + return info + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + def cancel(self): + if not self._cancelled: + self._loop._timer_handle_cancelled(self) + super().cancel() + + +class AbstractServer: + """Abstract server returned by create_server().""" + + def close(self): + """Stop serving. This leaves existing connections open.""" + return NotImplemented + + def wait_closed(self): + """Coroutine to wait until service is closed.""" + return NotImplemented + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + def is_closed(self): + """Returns True if the event loop was closed.""" + raise NotImplementedError + + def close(self): + """Close the loop. + + The loop should not be running. + + This is idempotent and irreversible. + + No other methods should be called after this one. + """ + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Method scheduling a coroutine object: create a task. + + def create_task(self, coro): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, func, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + raise NotImplementedError + + def create_server(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. + + The return value is a Server object which can be used to stop + the service. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError + + def create_unix_server(self, protocol_factory, path, *, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. + + The return value is a Server object, which can be used to stop + the service. + + path is a str, representing a file systsem path to bind the + server socket to. + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + # Pipes and subprocesses. + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in event loop. Set the pipe to non-blocking mode. + + protocol_factory should instantiate object with Protocol interface. + pipe is a file-like object. + Return pair (transport, protocol), where transport supports the + ReadTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in event loop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + # Task factory. + + def set_task_factory(self, factory): + raise NotImplementedError + + def get_task_factory(self): + raise NotImplementedError + + # Error handlers. + + def set_exception_handler(self, handler): + raise NotImplementedError + + def default_exception_handler(self, context): + raise NotImplementedError + + def call_exception_handler(self, context): + raise NotImplementedError + + # Debug flag management. + + def get_debug(self): + raise NotImplementedError + + def set_debug(self, enabled): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """Get the event loop for the current context. + + Returns an event loop object implementing the BaseEventLoop interface, + or raises an exception in case no event loop has been set for the + current context and the current policy does not specify to create one. + + It should never return None.""" + raise NotImplementedError + + def set_event_loop(self, loop): + """Set the event loop for the current context to loop.""" + raise NotImplementedError + + def new_event_loop(self): + """Create and return a new event loop object according to this + policy's rules. If there's need to set this loop as the event loop for + the current context, set_event_loop must be called explicitly.""" + raise NotImplementedError + + # Child processes handling (Unix only). + + def get_child_watcher(self): + "Get the watcher for child processes." + raise NotImplementedError + + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + raise NotImplementedError + + +class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop_factory = None + + class _Local(threading.local): + _loop = None + _set_called = False + + def __init__(self): + self._local = self._Local() + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._local._loop is None and + not self._local._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self.set_event_loop(self.new_event_loop()) + if self._local._loop is None: + raise RuntimeError('There is no current event loop in thread %r.' + % threading.current_thread().name) + return self._local._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + self._local._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._local._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + return self._loop_factory() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + +# Lock for protecting the on-the-fly creation of the event loop policy. +_lock = threading.Lock() + + +def _init_event_loop_policy(): + global _event_loop_policy + with _lock: + if _event_loop_policy is None: # pragma: no branch + from . import DefaultEventLoopPolicy + _event_loop_policy = DefaultEventLoopPolicy() + + +def get_event_loop_policy(): + """Get the current event loop policy.""" + if _event_loop_policy is None: + _init_event_loop_policy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """Set the current event loop policy. + + If policy is None, the default policy is restored.""" + global _event_loop_policy + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """Equivalent to calling get_event_loop_policy().get_event_loop().""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """Equivalent to calling get_event_loop_policy().set_event_loop(loop).""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """Equivalent to calling get_event_loop_policy().new_event_loop().""" + return get_event_loop_policy().new_event_loop() + + +def get_child_watcher(): + """Equivalent to calling get_event_loop_policy().get_child_watcher().""" + return get_event_loop_policy().get_child_watcher() + + +def set_child_watcher(watcher): + """Equivalent to calling + get_event_loop_policy().set_child_watcher(watcher).""" + return get_event_loop_policy().set_child_watcher(watcher) diff --git a/trollius/futures.py b/trollius/futures.py new file mode 100644 index 0000000..d06828a --- /dev/null +++ b/trollius/futures.py @@ -0,0 +1,413 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import reprlib +import sys +import traceback + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +_PY34 = sys.version_info >= (3, 4) +_PY35 = sys.version_info >= (3, 5) + +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ('loop', 'source_traceback', 'exc', 'tb') + + def __init__(self, future, exc): + self.loop = future._loop + self.source_traceback = future._source_traceback + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + msg = 'Future/Task exception was never retrieved\n' + if self.source_traceback: + src = ''.join(traceback.format_list(self.source_traceback)) + msg += 'Future/Task created at (most recent call last):\n' + msg += '%s\n' % src.rstrip() + msg += ''.join(self.tb).rstrip() + self.loop.call_exception_handler({'message': msg}) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + _source_traceback = None + + _blocking = False # proper use of future (yield vs yield from) + + _log_traceback = False # Used for Python 3.4 and later + _tb_logger = None # Used for Python 3.3 only + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + def _format_callbacks(self): + cb = self._callbacks + size = len(cb) + if not size: + cb = '' + + def format_cb(callback): + return events._format_callback_source(callback, ()) + + if size == 1: + cb = format_cb(cb[0]) + elif size == 2: + cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + elif size > 2: + cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), + size-2, + format_cb(cb[-1])) + return 'cb=[%s]' % cb + + def _repr_info(self): + info = [self._state.lower()] + if self._state == _FINISHED: + if self._exception is not None: + info.append('exception={!r}'.format(self._exception)) + else: + # use reprlib to limit the length of the output, especially + # for very long strings + result = reprlib.repr(self._result) + info.append('result={}'.format(result)) + if self._callbacks: + info.append(self._format_callbacks()) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + info = self._repr_info() + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if _PY34: + def __del__(self): + if not self._log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + exc = self._exception + context = { + 'message': ('%s exception was never retrieved' + % self.__class__.__name__), + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def _set_result_unless_cancelled(self, result): + """Helper setting the result only if the future was not cancelled.""" + if self.cancelled(): + return + self.set_result(result) + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + if isinstance(exception, type): + exception = exception() + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + if _PY34: + self._log_traceback = True + else: + self._tb_logger = _TracebackLogger(self, exception) + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + if self.cancelled(): + return + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + if _PY35: + __await__ = __iter__ # make compatible with 'await' expression + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + if loop is None: + loop = events.get_event_loop() + new_future = Future(loop=loop) + + def _check_cancel_other(f): + if f.cancelled(): + fut.cancel() + + new_future.add_done_callback(_check_cancel_other) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, future)) + return new_future diff --git a/trollius/locks.py b/trollius/locks.py new file mode 100644 index 0000000..b2e516b --- /dev/null +++ b/trollius/locks.py @@ -0,0 +1,470 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore'] + +import collections +import sys + +from . import events +from . import futures +from .coroutines import coroutine + + +_PY35 = sys.version_info >= (3, 5) + + +class _ContextManager: + """Context manager. + + This enables the following idiom for acquiring and releasing a + lock around a block: + + with (yield from lock): + <block> + + while failing loudly when accidentally using: + + with lock: + <block> + """ + + def __init__(self, lock): + self._lock = lock + + def __enter__(self): + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + def __exit__(self, *args): + try: + self._lock.release() + finally: + self._lock = None # Crudely prevent reuse. + + +class _ContextManagerMixin: + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + # This must exist because __enter__ exists, even though that + # always raises; that's how the with-statement works. + pass + + @coroutine + def __iter__(self): + # This is not a coroutine. It is meant to enable the idiom: + # + # with (yield from lock): + # <block> + # + # as an alternative to: + # + # yield from lock.acquire() + # try: + # <block> + # finally: + # lock.release() + yield from self.acquire() + return _ContextManager(self) + + if _PY35: + + def __await__(self): + # To make "with await lock" work. + yield from self.acquire() + return _ContextManager(self) + + @coroutine + def __aenter__(self): + yield from self.acquire() + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + @coroutine + def __aexit__(self, exc_type, exc, tb): + self.release() + + +class Lock(_ContextManagerMixin): + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context management protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return True if lock is acquired.""" + return self._locked + + @coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + +class Event: + """Asynchronous equivalent to threading.Event. + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'set' if self._value else 'unset' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def is_set(self): + """Return True if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +class Condition(_ContextManagerMixin): + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock=None, *, loop=None): + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + if lock is None: + lock = Lock(loop=self._loop) + elif lock._loop is not self._loop: + raise ValueError("loop argument must agree with lock") + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters = collections.deque() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + @coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self.locked(): + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + try: + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + finally: + yield from self.acquire() + + @coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + +class Semaphore(_ContextManagerMixin): + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context management protocol. + + The optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + """ + + def __init__(self, value=1, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be >= 0") + self._value = value + self._waiters = collections.deque() + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( + self._value) + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._value == 0 + + @coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + """ + self._value += 1 + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + +class BoundedSemaphore(Semaphore): + """A bounded semaphore implementation. + + This raises ValueError in release() if it would increase the value + above the initial value. + """ + + def __init__(self, value=1, *, loop=None): + self._bound_value = value + super().__init__(value, loop=loop) + + def release(self): + if self._value >= self._bound_value: + raise ValueError('BoundedSemaphore released too many times') + super().release() diff --git a/trollius/log.py b/trollius/log.py new file mode 100644 index 0000000..23a7074 --- /dev/null +++ b/trollius/log.py @@ -0,0 +1,7 @@ +"""Logging configuration.""" + +import logging + + +# Name the logger after the package. +logger = logging.getLogger(__package__) diff --git a/trollius/proactor_events.py b/trollius/proactor_events.py new file mode 100644 index 0000000..9c2b8f1 --- /dev/null +++ b/trollius/proactor_events.py @@ -0,0 +1,547 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +__all__ = ['BaseProactorEventLoop'] + +import socket +import sys +import warnings + +from . import base_events +from . import constants +from . import futures +from . import sslproto +from . import transports +from .log import logger + + +class _ProactorBasePipeTransport(transports._FlowControlMixin, + transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(extra, loop) + self._set_extra(sock) + self._sock = sock + self._protocol = protocol + self._server = server + self._buffer = None # None or bytearray. + self._read_fut = None + self._write_fut = None + self._pending_write = 0 + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._eof_written = False + if self._server is not None: + self._server._attach() + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__] + if self._sock is None: + info.append('closed') + elif self._closing: + info.append('closing') + if self._sock is not None: + info.append('fd=%s' % self._sock.fileno()) + if self._read_fut is not None: + info.append('read=%s' % self._read_fut) + if self._write_fut is not None: + info.append("write=%r" % self._write_fut) + if self._buffer: + bufsize = len(self._buffer) + info.append('write_bufsize=%s' % bufsize) + if self._eof_written: + info.append('EOF written') + return '<%s>' % ' '.join(info) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + self._read_fut = None + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._sock is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self.close() + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + self._write_fut = None + if self._read_fut: + self._read_fut.cancel() + self._read_fut = None + self._pending_write = 0 + self._buffer = None + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + self._sock = None + server = self._server + if server is not None: + server._detach() + self._server = None + + def get_write_buffer_size(self): + size = self._pending_write + if self._buffer is not None: + size += len(self._buffer) + return size + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) + self._paused = False + self._loop.call_soon(self._loop_reading) + + def pause_reading(self): + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.call_soon(self._loop_reading, self._read_fut) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _loop_reading(self, fut=None): + if self._paused: + return + data = None + + try: + if fut is not None: + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc, 'Fatal read error on pipe transport') + elif self._loop.get_debug(): + logger.debug("Read error on pipe transport while closing", + exc_info=True) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc, 'Fatal read error on pipe transport') + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + + +class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if self._eof_written: + raise RuntimeError('write_eof() already called') + + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + # Observable states: + # 1. IDLE: _write_fut and _buffer both None + # 2. WRITING: _write_fut set; _buffer None + # 3. BACKED UP: _write_fut set; _buffer a bytearray + # We always copy the data, so the caller can't modify it + # while we're still waiting for the I/O to happen. + if self._write_fut is None: # IDLE -> WRITING + assert self._buffer is None + # Pass a copy, except if it's already immutable. + self._loop_writing(data=bytes(data)) + elif not self._buffer: # WRITING -> BACKED UP + # Make a mutable copy which we can extend. + self._buffer = bytearray(data) + self._maybe_pause_protocol() + else: # BACKED UP + # Append to buffer (also copies). + self._buffer.extend(data) + self._maybe_pause_protocol() + + def _loop_writing(self, f=None, data=None): + try: + assert f is self._write_fut + self._write_fut = None + self._pending_write = 0 + if f: + f.result() + if data is None: + data = self._buffer + self._buffer = None + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + if self._eof_written: + self._sock.shutdown(socket.SHUT_WR) + # Now that we've reduced the buffer size, tell the + # protocol to resume writing if it was paused. Note that + # we do this last since the callback is called immediately + # and it may add more data to the buffer (even causing the + # protocol to be paused again). + self._maybe_resume_protocol() + else: + self._write_fut = self._loop._proactor.send(self._sock, data) + if not self._write_fut.done(): + assert self._pending_write == 0 + self._pending_write = len(data) + self._write_fut.add_done_callback(self._loop_writing) + self._maybe_pause_protocol() + else: + self._write_fut.add_done_callback(self._loop_writing) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc, 'Fatal write error on pipe transport') + + def can_write_eof(self): + return True + + def write_eof(self): + self.close() + + def abort(self): + self._force_close(None) + + +class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._read_fut = self._loop._proactor.recv(self._sock, 16) + self._read_fut.add_done_callback(self._pipe_closed) + + def _pipe_closed(self, fut): + if fut.cancelled(): + # the transport has been closed + return + assert fut.result() == b'' + if self._closing: + assert self._read_fut is None + return + assert fut is self._read_fut, (fut, self._read_fut) + self._read_fut = None + if self._write_fut is not None: + self._force_close(BrokenPipeError()) + else: + self.close() + + +class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, + _ProactorBaseWritePipeTransport, + transports.Transport): + """Transport for duplex pipes.""" + + def can_write_eof(self): + return False + + def write_eof(self): + raise NotImplementedError + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorBaseWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + try: + self._extra['sockname'] = sock.getsockname() + except (socket.error, AttributeError): + if self._loop.get_debug(): + logger.warning("getsockname() failed on %r", + sock, exc_info=True) + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except (socket.error, AttributeError): + if self._loop.get_debug(): + logger.warning("getpeername() failed on %r", + sock, exc_info=True) + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing or self._eof_written: + return + self._eof_written = True + if self._write_fut is None: + self._sock.shutdown(socket.SHUT_WR) + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + logger.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._self_reading_future = None + self._accept_futures = {} # socket file descriptor => Future + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, + extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None): + if not sslproto._is_sslproto_available(): + raise NotImplementedError("Proactor event loop requires Python 3.5" + " or newer (ssl.MemoryBIO) to support " + "SSL") + + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _ProactorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + # We want connection_lost() to be called when other end closes + return _ProactorWritePipeTransport(self, + sock, protocol, waiter, extra) + + def close(self): + if self.is_running(): + raise RuntimeError("Cannot close a running event loop") + if self.is_closed(): + return + + # Call these methods before closing the event loop (before calling + # BaseEventLoop.close), because they can schedule callbacks with + # call_soon(), which is forbidden when the event loop is closed. + self._stop_accept_futures() + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + # Close the event loop + super().close() + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + try: + if self._debug: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut = futures.Future(loop=self) + fut.set_exception(err) + return fut + else: + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + if self._self_reading_future is not None: + self._self_reading_future.cancel() + self._self_reading_future = None + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except futures.CancelledError: + # _close_self_pipe() has been called, stop waiting for data + return + except Exception as exc: + self.call_exception_handler({ + 'message': 'Error on reading from the event loop self pipe', + 'exception': exc, + 'loop': self, + }) + else: + self._self_reading_future = f + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'\0') + + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): + + def loop(f=None): + try: + if f is not None: + conn, addr = f.result() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) + protocol = protocol_factory() + if sslcontext is not None: + self._make_ssl_transport( + conn, protocol, sslcontext, server_side=True, + extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) + if self.is_closed(): + return + f = self._proactor.accept(sock) + except OSError as exc: + if sock.fileno() != -1: + self.call_exception_handler({ + 'message': 'Accept failed on a socket', + 'exception': exc, + 'socket': sock, + }) + sock.close() + elif self._debug: + logger.debug("Accept failed on socket %r", + sock, exc_info=True) + except futures.CancelledError: + sock.close() + else: + self._accept_futures[sock.fileno()] = f + f.add_done_callback(loop) + + self.call_soon(loop) + + def _process_events(self, event_list): + # Events are processed in the IocpProactor._poll() method + pass + + def _stop_accept_futures(self): + for future in self._accept_futures.values(): + future.cancel() + self._accept_futures.clear() + + def _stop_serving(self, sock): + self._stop_accept_futures() + self._proactor._stop_serving(sock) + sock.close() diff --git a/trollius/protocols.py b/trollius/protocols.py new file mode 100644 index 0000000..80fcac9 --- /dev/null +++ b/trollius/protocols.py @@ -0,0 +1,134 @@ +"""Abstract Protocol class.""" + +__all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol', + 'SubprocessProtocol'] + + +class BaseProtocol: + """Common base class for protocol interfaces. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + def pause_writing(self): + """Called when the transport's buffer goes over the high-water mark. + + Pause and resume calls are paired -- pause_writing() is called + once when the buffer goes strictly over the high-water mark + (even if subsequent writes increases the buffer size even + more), and eventually resume_writing() is called once when the + buffer size reaches the low-water mark. + + Note that if the buffer size equals the high-water mark, + pause_writing() is not called -- it must go strictly over. + Conversely, resume_writing() is called when the buffer size is + equal or lower than the low-water mark. These end conditions + are important to ensure that things go as expected when either + mark is zero. + + NOTE: This is the only Protocol callback that is not called + through EventLoop.call_soon() -- if it were, it would have no + effect when it's most needed (when the app keeps writing + without yielding until pause_writing() is called). + """ + + def resume_writing(self): + """Called when the transport's buffer drains below the low-water mark. + + See pause_writing() for details. + """ + + +class Protocol(BaseProtocol): + """Interface for stream protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + + * CM: connection_made() + * DR: data_received() + * ER: eof_received() + * CL: connection_lost() + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + + +class DatagramProtocol(BaseProtocol): + """Interface for datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def error_received(self, exc): + """Called when a send or receive operation raises an OSError. + + (Other than BlockingIOError or InterruptedError.) + """ + + +class SubprocessProtocol(BaseProtocol): + """Interface for protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when the subprocess writes data into stdout/stderr pipe. + + fd is int file descriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited.""" diff --git a/trollius/queues.py b/trollius/queues.py new file mode 100644 index 0000000..ed11662 --- /dev/null +++ b/trollius/queues.py @@ -0,0 +1,293 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty', + 'JoinableQueue'] + +import collections +import heapq + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class QueueEmpty(Exception): + """Exception raised when Queue.get_nowait() is called on a Queue object + which is empty. + """ + pass + + +class QueueFull(Exception): + """Exception raised when the Queue.put_nowait() method is called on a Queue + object which is full. + """ + pass + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded asyncio application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() + self._init(maxsize) + + # These three are overridable in subclasses. + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + # End of the overridable methods. + + def __put_internal(self, item): + self._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() >= self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + Put an item into the queue. If the queue is full, wait until a free + slot is available before adding item. + + This method is a coroutine. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + self.__put_internal(item) + + # getter cannot be cancelled, we just removed done getters + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize <= self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self.__put_internal(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise QueueFull. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + self.__put_internal(item) + + # getter cannot be cancelled, we just removed done getters + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize <= self.qsize(): + raise QueueFull + else: + self.__put_internal(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If queue is empty, wait until an item is available. + + This method is a coroutine. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self.__put_internal(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter._set_result_unless_cancelled, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise QueueEmpty. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self.__put_internal(item) + # Wake putter on next tick. + + # getter cannot be cancelled, we just removed done putters + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise QueueEmpty + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer calls task_done() to + indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +JoinableQueue = Queue +"""Deprecated alias for Queue.""" diff --git a/trollius/selector_events.py b/trollius/selector_events.py new file mode 100644 index 0000000..7c5b9b5 --- /dev/null +++ b/trollius/selector_events.py @@ -0,0 +1,1068 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +__all__ = ['BaseSelectorEventLoop'] + +import collections +import errno +import functools +import socket +import sys +import warnings +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from . import sslproto +from .coroutines import coroutine +from .log import logger + + +def _test_selector_event(selector, fd, event): + # Test if the selector is monitoring 'event' events + # for the file descriptor 'fd'. + try: + key = selector.get_key(fd) + except KeyError: + return False + else: + return bool(key.events & event) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + logger.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None): + if not sslproto._is_sslproto_available(): + return self._make_legacy_ssl_transport( + rawsock, protocol, sslcontext, waiter, + server_side=server_side, server_hostname=server_hostname, + extra=extra, server=server) + + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _SelectorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + + def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, + waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used + # on Python 3.4 and older, when ssl.MemoryBIO is not available. + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, + server_side, server_hostname, extra, server) + + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, + address, waiter, extra) + + def close(self): + if self.is_running(): + raise RuntimeError("Cannot close a running event loop") + if self.is_closed(): + return + self._close_self_pipe() + super().close() + if self._selector is not None: + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _process_self_data(self, data): + pass + + def _read_from_self(self): + while True: + try: + data = self._ssock.recv(4096) + if not data: + break + self._process_self_data(data) + except InterruptedError: + continue + except BlockingIOError: + break + + def _write_to_self(self): + # This may be called from a different thread, possibly after + # _close_self_pipe() has been called or even while it is + # running. Guard for self._csock being None or closed. When + # a socket is closed, send() raises OSError (with errno set to + # EBADF, but let's not rely on the exact error code). + csock = self._csock + if csock is not None: + try: + csock.send(b'\0') + except OSError: + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) + + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, sslcontext, server) + + def _accept_connection(self, protocol_factory, sock, + sslcontext=None, server=None): + try: + conn, addr = sock.accept() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) + conn.setblocking(False) + except (BlockingIOError, InterruptedError, ConnectionAbortedError): + pass # False alarm. + except OSError as exc: + # There's nowhere to send the error, so just log it. + if exc.errno in (errno.EMFILE, errno.ENFILE, + errno.ENOBUFS, errno.ENOMEM): + # Some platforms (e.g. Linux keep reporting the FD as + # ready, so we remove the read handler temporarily. + # We'll try again in a while. + self.call_exception_handler({ + 'message': 'socket.accept() out of system resource', + 'exception': exc, + 'socket': sock, + }) + self.remove_reader(sock.fileno()) + self.call_later(constants.ACCEPT_RETRY_DELAY, + self._start_serving, + protocol_factory, sock, sslcontext, server) + else: + raise # The event loop will catch, log and ignore it. + else: + extra = {'peername': addr} + accept = self._accept_connection2(protocol_factory, conn, extra, + sslcontext, server) + self.create_task(accept) + + @coroutine + def _accept_connection2(self, protocol_factory, conn, extra, + sslcontext=None, server=None): + protocol = None + transport = None + try: + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if sslcontext: + transport = self._make_ssl_transport( + conn, protocol, sslcontext, waiter=waiter, + server_side=True, extra=extra, server=server) + else: + transport = self._make_socket_transport( + conn, protocol, waiter=waiter, extra=extra, + server=server) + + try: + yield from waiter + except: + transport.close() + raise + + # It's now up to the protocol to handle the connection. + except Exception as exc: + if self._debug: + context = { + 'message': ('Error on transport creation ' + 'for incoming connection'), + 'exception': exc, + } + if protocol is not None: + context['protocol'] = protocol + if transport is not None: + context['transport'] = transport + self.call_exception_handler(context) + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + self._check_closed() + handle = events.Handle(callback, args, self) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + if self.is_closed(): + return False + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + self._check_closed() + handle = events.Handle(callback, args, self) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + if self.is_closed(): + return False + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """Receive data from the socket. + + The return value is a bytes object representing the data received. + The maximum amount of data to be received at once is specified by + nbytes. + + This method is a coroutine. + """ + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + # _sock_recv() can add itself as an I/O callback if the operation can't + # be done immediately. Don't use it directly, call sock_recv(). + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """Send data to the socket. + + The socket must be connected to a remote socket. This method continues + to send data from data until either all data has been sent or an + error occurs. None is returned on success. On error, an exception is + raised, and there is no way to determine how much data, if any, was + successfully processed by the receiving end of the connection. + + This method is a coroutine. + """ + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """Connect to a remote socket at address. + + The address must be already resolved to avoid the trap of hanging the + entire event loop when the address requires doing a DNS lookup. For + example, it must be an IP address, not an hostname, for AF_INET and + AF_INET6 address families. Use getaddrinfo() to resolve the hostname + asynchronously. + + This method is a coroutine. + """ + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + try: + if self._debug: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut.set_exception(err) + else: + self._sock_connect(fut, sock, address) + return fut + + def _sock_connect(self, fut, sock, address): + fd = sock.fileno() + try: + sock.connect(address) + except (BlockingIOError, InterruptedError): + # Issue #23618: When the C function connect() fails with EINTR, the + # connection runs in background. We have to wait until the socket + # becomes writable to be notified when the connection succeed or + # fails. + fut.add_done_callback(functools.partial(self._sock_connect_done, + fd)) + self.add_writer(fd, self._sock_connect_cb, fut, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def _sock_connect_done(self, fd, fut): + self.remove_writer(fd) + + def _sock_connect_cb(self, fut, sock, address): + if fut.cancelled(): + return + + try: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to any except clause below. + raise OSError(err, 'Connect call failed %s' % (address,)) + except (BlockingIOError, InterruptedError): + # socket is still registered, the callback will be retried later + pass + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """Accept a connection. + + The socket must be bound to an address and listening for connections. + The return value is a pair (conn, address) where conn is a new socket + object usable to send and receive data on the connection, and address + is the address bound to the socket on the other end of the connection. + + This method is a coroutine. + """ + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def _stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports._FlowControlMixin, + transports.Transport): + + max_size = 256 * 1024 # Buffer size passed to recv(). + + _buffer_factory = bytearray # Constructs initial value for self._buffer. + + # Attribute used in the destructor: it must be set even if the constructor + # is not called (see _SelectorSslTransport which may start by raising an + # exception) + _sock = None + + def __init__(self, loop, sock, protocol, extra=None, server=None): + super().__init__(extra, loop) + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._protocol_connected = True + self._server = server + self._buffer = self._buffer_factory() + self._conn_lost = 0 # Set when call to connection_lost scheduled. + self._closing = False # Set when close() called. + if self._server is not None: + self._server._attach() + + def __repr__(self): + info = [self.__class__.__name__] + if self._sock is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._sock_fd) + # test if the transport was closed + if self._loop is not None and not self._loop.is_closed(): + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_READ) + if polling: + info.append('read=polling') + else: + info.append('read=idle') + + polling = _test_selector_event(self._loop._selector, + self._sock_fd, + selectors.EVENT_WRITE) + if polling: + state = 'polling' + else: + state = 'idle' + + bufsize = self.get_write_buffer_size() + info.append('write=<%s, bufsize=%s>' % (state, bufsize)) + return '<%s>' % ' '.join(info) + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, None) + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._sock is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self._sock.close() + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, + ConnectionResetError, ConnectionAbortedError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._force_close(exc) + + def _force_close(self, exc): + if self._conn_lost: + return + if self._buffer: + self._buffer.clear() + self._loop.remove_writer(self._sock_fd) + if not self._closing: + self._closing = True + self._loop.remove_reader(self._sock_fd) + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + if self._protocol_connected: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + server = self._server + if server is not None: + server._detach() + self._server = None + + def get_write_buffer_size(self): + return len(self._buffer) + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, extra, server) + self._eof = False + self._paused = False + + self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._sock_fd, self._read_ready) + if waiter is not None: + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def pause_reading(self): + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _read_ready(self): + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on socket transport') + else: + if data: + self._protocol.data_received(data) + else: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if keep_open: + # We're keeping the connection open so the + # protocol can write more, but we still can't + # receive more, so remove the reader callback. + self._loop.remove_reader(self._sock_fd) + else: + self.close() + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if self._eof: + raise RuntimeError('Cannot call write() after write_eof()') + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Optimization: try to send now. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc, 'Fatal write error on socket transport') + return + else: + data = data[n:] + if not data: + return + # Not all was written; register write handler. + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. + self._buffer.extend(data) + self._maybe_pause_protocol() + + def _write_ready(self): + assert self._buffer, 'Data should not be empty' + + try: + n = self._sock.send(self._buffer) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + self._fatal_error(exc, 'Fatal write error on socket transport') + else: + if n: + del self._buffer[:n] + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) + + def write_eof(self): + if self._eof: + return + self._eof = True + if not self._buffer: + self._sock.shutdown(socket.SHUT_WR) + + def can_write_eof(self): + return True + + +class _SelectorSslTransport(_SelectorTransport): + + _buffer_factory = bytearray + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, server_hostname=None, + extra=None, server=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if not sslcontext: + sslcontext = sslproto._create_transport_context(server_side, server_hostname) + + wrap_kwargs = { + 'server_side': server_side, + 'do_handshake_on_connect': False, + } + if server_hostname and not server_side: + wrap_kwargs['server_hostname'] = server_hostname + sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) + + super().__init__(loop, sslsock, protocol, extra, server) + # the protocol connection is only made after the SSL handshake + self._protocol_connected = False + + self._server_hostname = server_hostname + self._waiter = waiter + self._sslcontext = sslcontext + self._paused = False + + # SSL-specific extra info. (peercert is set later) + self._extra.update(sslcontext=sslcontext) + + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + start_time = self._loop.time() + else: + start_time = None + self._on_handshake(start_time) + + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + + def _on_handshake(self, start_time): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, + self._on_handshake, start_time) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, + self._on_handshake, start_time) + return + except BaseException as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._sock.close() + self._wakeup_waiter(exc) + if isinstance(exc, Exception): + return + else: + raise + + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + + peercert = self._sock.getpeercert() + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname and + self._sslcontext.verify_mode != ssl.CERT_NONE): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed " + "on matching the hostname", + self, exc_info=True) + self._sock.close() + self._wakeup_waiter(exc) + return + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=self._sock.cipher(), + compression=self._sock.compression(), + ) + + self._read_wants_write = False + self._write_wants_read = False + self._loop.add_reader(self._sock_fd, self._read_ready) + self._protocol_connected = True + self._loop.call_soon(self._protocol.connection_made, self) + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(self._wakeup_waiter) + + if self._loop.get_debug(): + dt = self._loop.time() - start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + def pause_reading(self): + # XXX This is a bit icky, given the comment at the top of + # _read_ready(). Is it possible to evoke a deadlock? I don't + # know, although it doesn't look like it; write() will still + # accept more data for the buffer and eventually the app will + # call resume_reading() again, and things will flow again. + + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _read_ready(self): + if self._write_wants_read: + self._write_wants_read = False + self._write_ready() + + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + pass + except ssl.SSLWantWriteError: + self._read_wants_write = True + self._loop.remove_reader(self._sock_fd) + self._loop.add_writer(self._sock_fd, self._write_ready) + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on SSL transport') + else: + if data: + self._protocol.data_received(data) + else: + try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self.close() + + def _write_ready(self): + if self._read_wants_write: + self._read_wants_write = False + self._read_ready() + + if not (self._paused or self._closing): + self._loop.add_reader(self._sock_fd, self._read_ready) + + if self._buffer: + try: + n = self._sock.send(self._buffer) + except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): + n = 0 + except ssl.SSLWantReadError: + n = 0 + self._loop.remove_writer(self._sock_fd) + self._write_wants_read = True + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + self._fatal_error(exc, 'Fatal write error on SSL transport') + return + + if n: + del self._buffer[:n] + + self._maybe_resume_protocol() # May append to buffer. + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. + self._buffer.extend(data) + self._maybe_pause_protocol() + + def can_write_eof(self): + return False + + +class _SelectorDatagramTransport(_SelectorTransport): + + _buffer_factory = collections.deque + + def __init__(self, loop, sock, protocol, address=None, + waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + self._address = address + self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._sock_fd, self._read_ready) + if waiter is not None: + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def get_write_buffer_size(self): + return sum(len(data) for data, _ in self._buffer) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._protocol.error_received(exc) + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on datagram transport') + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if not data: + return + + if self._address and addr not in (None, self._address): + raise ValueError('Invalid address: must be None or %s' % + (self._address,)) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except OSError as exc: + self._protocol.error_received(exc) + return + except Exception as exc: + self._fatal_error(exc, + 'Fatal write error on datagram transport') + return + + # Ensure that what we buffer is immutable. + self._buffer.append((bytes(data), addr)) + self._maybe_pause_protocol() + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except OSError as exc: + self._protocol.error_received(exc) + return + except Exception as exc: + self._fatal_error(exc, + 'Fatal write error on datagram transport') + return + + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) diff --git a/trollius/selectors.py b/trollius/selectors.py new file mode 100644 index 0000000..6d569c3 --- /dev/null +++ b/trollius/selectors.py @@ -0,0 +1,594 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple, Mapping +import math +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + + Raises: + ValueError if the object is invalid + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(metaclass=ABCMeta): + """Selector abstract base class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + efficient implementation on the current platform. + """ + + @abstractmethod + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised + """ + raise NotImplementedError + + @abstractmethod + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) + """ + raise NotImplementedError + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + Anything that unregister() or register() raises + """ + self.unregister(fileobj) + return self.register(fileobj, events, data) + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + pass + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + mapping = self.get_map() + if mapping is None: + raise RuntimeError('Selector is closed') + try: + return mapping[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + @abstractmethod + def get_map(self): + """Return a mapping of file objects to selector keys.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. + + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + self._map = None + + def get_map(self): + return self._map + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(_BaseSelectorImpl): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(_BaseSelectorImpl): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(_BaseSelectorImpl): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + try: + self._epoll.unregister(key.fd) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + return key + + def select(self, timeout=None): + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 + + # epoll_wait() expects `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._epoll.close() + super().close() + + +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super().__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super().close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(_BaseSelectorImpl): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._kqueue.close() + super().close() + + +# Choose the best implementation, roughly: +# epoll|kqueue|devpoll > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/trollius/sslproto.py b/trollius/sslproto.py new file mode 100644 index 0000000..235855e --- /dev/null +++ b/trollius/sslproto.py @@ -0,0 +1,668 @@ +import collections +import sys +import warnings +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import protocols +from . import transports +from .log import logger + + +def _create_transport_context(server_side, server_hostname): + if server_side: + raise ValueError('Server side SSL needs a valid SSLContext') + + # Client side may pass ssl=True to use a default + # context; in that case the sslcontext passed is None. + # The default is secure for client connections. + if hasattr(ssl, 'create_default_context'): + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False + else: + # Fallback for Python 3.3. + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + return sslcontext + + +def _is_sslproto_available(): + return hasattr(ssl, "MemoryBIO") + + +# States of an _SSLPipe. +_UNWRAPPED = "UNWRAPPED" +_DO_HANDSHAKE = "DO_HANDSHAKE" +_WRAPPED = "WRAPPED" +_SHUTDOWN = "SHUTDOWN" + + +class _SSLPipe(object): + """An SSL "Pipe". + + An SSL pipe allows you to communicate with an SSL/TLS protocol instance + through memory buffers. It can be used to implement a security layer for an + existing connection where you don't have access to the connection's file + descriptor, or for some reason you don't want to use it. + + An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, + data is passed through untransformed. In wrapped mode, application level + data is encrypted to SSL record level data and vice versa. The SSL record + level is the lowest level in the SSL protocol suite and is what travels + as-is over the wire. + + An SslPipe initially is in "unwrapped" mode. To start SSL, call + do_handshake(). To shutdown SSL again, call unwrap(). + """ + + max_size = 256 * 1024 # Buffer size passed to read() + + def __init__(self, context, server_side, server_hostname=None): + """ + The *context* argument specifies the ssl.SSLContext to use. + + The *server_side* argument indicates whether this is a server side or + client side transport. + + The optional *server_hostname* argument can be used to specify the + hostname you are connecting to. You may only specify this parameter if + the _ssl module supports Server Name Indication (SNI). + """ + self._context = context + self._server_side = server_side + self._server_hostname = server_hostname + self._state = _UNWRAPPED + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = None + self._need_ssldata = False + self._handshake_cb = None + self._shutdown_cb = None + + @property + def context(self): + """The SSL context passed to the constructor.""" + return self._context + + @property + def ssl_object(self): + """The internal ssl.SSLObject instance. + + Return None if the pipe is not wrapped. + """ + return self._sslobj + + @property + def need_ssldata(self): + """Whether more record level data is needed to complete a handshake + that is currently in progress.""" + return self._need_ssldata + + @property + def wrapped(self): + """ + Whether a security layer is currently in effect. + + Return False during handshake. + """ + return self._state == _WRAPPED + + def do_handshake(self, callback=None): + """Start the SSL handshake. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the handshake is complete. The callback will be + called with None if successful, else an exception instance. + """ + if self._state != _UNWRAPPED: + raise RuntimeError('handshake in progress or completed') + self._sslobj = self._context.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._state = _DO_HANDSHAKE + self._handshake_cb = callback + ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) + assert len(appdata) == 0 + return ssldata + + def shutdown(self, callback=None): + """Start the SSL shutdown sequence. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the shutdown is complete. The callback will be + called without arguments. + """ + if self._state == _UNWRAPPED: + raise RuntimeError('no security layer present') + if self._state == _SHUTDOWN: + raise RuntimeError('shutdown in progress') + assert self._state in (_WRAPPED, _DO_HANDSHAKE) + self._state = _SHUTDOWN + self._shutdown_cb = callback + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + return ssldata + + def feed_eof(self): + """Send a potentially "ragged" EOF. + + This method will raise an SSL_ERROR_EOF exception if the EOF is + unexpected. + """ + self._incoming.write_eof() + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + + def feed_ssldata(self, data, only_handshake=False): + """Feed SSL record level data into the pipe. + + The data must be a bytes instance. It is OK to send an empty bytes + instance. This can be used to get ssldata for a handshake initiated by + this endpoint. + + Return a (ssldata, appdata) tuple. The ssldata element is a list of + buffers containing SSL data that needs to be sent to the remote SSL. + + The appdata element is a list of buffers containing plaintext data that + needs to be forwarded to the application. The appdata list may contain + an empty buffer indicating an SSL "close_notify" alert. This alert must + be acknowledged by calling shutdown(). + """ + if self._state == _UNWRAPPED: + # If unwrapped, pass plaintext data straight through. + if data: + appdata = [data] + else: + appdata = [] + return ([], appdata) + + self._need_ssldata = False + if data: + self._incoming.write(data) + + ssldata = [] + appdata = [] + try: + if self._state == _DO_HANDSHAKE: + # Call do_handshake() until it doesn't raise anymore. + self._sslobj.do_handshake() + self._state = _WRAPPED + if self._handshake_cb: + self._handshake_cb(None) + if only_handshake: + return (ssldata, appdata) + # Handshake done: execute the wrapped block + + if self._state == _WRAPPED: + # Main state: read data from SSL until close_notify + while True: + chunk = self._sslobj.read(self.max_size) + appdata.append(chunk) + if not chunk: # close_notify + break + + elif self._state == _SHUTDOWN: + # Call shutdown() until it doesn't raise anymore. + self._sslobj.unwrap() + self._sslobj = None + self._state = _UNWRAPPED + if self._shutdown_cb: + self._shutdown_cb() + + elif self._state == _UNWRAPPED: + # Drain possible plaintext data after close_notify. + appdata.append(self._incoming.read()) + except (ssl.SSLError, ssl.CertificateError) as exc: + if getattr(exc, 'errno', None) not in ( + ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + if self._state == _DO_HANDSHAKE and self._handshake_cb: + self._handshake_cb(exc) + raise + self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + + # Check for record level data that needs to be sent back. + # Happens for the initial handshake and renegotiations. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + return (ssldata, appdata) + + def feed_appdata(self, data, offset=0): + """Feed plaintext data into the pipe. + + Return an (ssldata, offset) tuple. The ssldata element is a list of + buffers containing record level data that needs to be sent to the + remote SSL instance. The offset is the number of plaintext bytes that + were processed, which may be less than the length of data. + + NOTE: In case of short writes, this call MUST be retried with the SAME + buffer passed into the *data* argument (i.e. the id() must be the + same). This is an OpenSSL requirement. A further particularity is that + a short write will always have offset == 0, because the _ssl module + does not enable partial writes. And even though the offset is zero, + there will still be encrypted data in ssldata. + """ + assert 0 <= offset <= len(data) + if self._state == _UNWRAPPED: + # pass through data in unwrapped mode + if offset < len(data): + ssldata = [data[offset:]] + else: + ssldata = [] + return (ssldata, len(data)) + + ssldata = [] + view = memoryview(data) + while True: + self._need_ssldata = False + try: + if offset < len(view): + offset += self._sslobj.write(view[offset:]) + except ssl.SSLError as exc: + # It is not allowed to call write() after unwrap() until the + # close_notify is acknowledged. We return the condition to the + # caller as a short write. + if exc.reason == 'PROTOCOL_IS_SHUTDOWN': + exc.errno = ssl.SSL_ERROR_WANT_READ + if exc.errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + + # See if there's any record level data back for us. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + if offset == len(view) or self._need_ssldata: + break + return (ssldata, offset) + + +class _SSLProtocolTransport(transports._FlowControlMixin, + transports.Transport): + + def __init__(self, loop, ssl_protocol, app_protocol): + self._loop = loop + self._ssl_protocol = ssl_protocol + self._app_protocol = app_protocol + self._closed = False + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._ssl_protocol._get_extra_info(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + self._closed = True + self._ssl_protocol._start_shutdown() + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if not self._closed: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self.close() + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + self._ssl_protocol._transport.pause_reading() + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + self._ssl_protocol._transport.resume_reading() + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + self._ssl_protocol._transport.set_write_buffer_limits(high, low) + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + return self._ssl_protocol._transport.get_write_buffer_size() + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("data: expecting a bytes-like instance, got {!r}" + .format(type(data).__name__)) + if not data: + return + self._ssl_protocol._write_appdata(data) + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + return False + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + self._ssl_protocol._abort() + + +class SSLProtocol(protocols.Protocol): + """SSL protocol. + + Implementation of SSL on top of a socket using incoming and outgoing + buffers which are ssl.MemoryBIO objects. + """ + + def __init__(self, loop, app_protocol, sslcontext, waiter, + server_side=False, server_hostname=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if not sslcontext: + sslcontext = _create_transport_context(server_side, server_hostname) + + self._server_side = server_side + if server_hostname and not server_side: + self._server_hostname = server_hostname + else: + self._server_hostname = None + self._sslcontext = sslcontext + # SSL-specific extra info. More info are set when the handshake + # completes. + self._extra = dict(sslcontext=sslcontext) + + # App data write buffering + self._write_backlog = collections.deque() + self._write_buffer_size = 0 + + self._waiter = waiter + self._loop = loop + self._app_protocol = app_protocol + self._app_transport = _SSLProtocolTransport(self._loop, + self, self._app_protocol) + self._sslpipe = None + self._session_established = False + self._in_handshake = False + self._in_shutdown = False + self._transport = None + + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + + def connection_made(self, transport): + """Called when the low-level connection is made. + + Start the SSL handshake. + """ + self._transport = transport + self._sslpipe = _SSLPipe(self._sslcontext, + self._server_side, + self._server_hostname) + self._start_handshake() + + def connection_lost(self, exc): + """Called when the low-level connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + if self._session_established: + self._session_established = False + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._transport = None + self._app_transport = None + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + self._app_protocol.pause_writing() + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + self._app_protocol.resume_writing() + + def data_received(self, data): + """Called when some SSL data is received. + + The argument is a bytes object. + """ + try: + ssldata, appdata = self._sslpipe.feed_ssldata(data) + except ssl.SSLError as e: + if self._loop.get_debug(): + logger.warning('%r: SSL error %s (reason %s)', + self, e.errno, e.reason) + self._abort() + return + + for chunk in ssldata: + self._transport.write(chunk) + + for chunk in appdata: + if chunk: + self._app_protocol.data_received(chunk) + else: + self._start_shutdown() + break + + def eof_received(self): + """Called when the other end of the low-level stream + is half-closed. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + + self._wakeup_waiter(ConnectionResetError) + + if not self._in_handshake: + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self._transport.close() + + def _get_extra_info(self, name, default=None): + if name in self._extra: + return self._extra[name] + else: + return self._transport.get_extra_info(name, default) + + def _start_shutdown(self): + if self._in_shutdown: + return + self._in_shutdown = True + self._write_appdata(b'') + + def _write_appdata(self, data): + self._write_backlog.append((data, 0)) + self._write_buffer_size += len(data) + self._process_write_backlog() + + def _start_handshake(self): + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + self._handshake_start_time = self._loop.time() + else: + self._handshake_start_time = None + self._in_handshake = True + # (b'', 1) is a special value in _process_write_backlog() to do + # the SSL handshake + self._write_backlog.append((b'', 1)) + self._loop.call_soon(self._process_write_backlog) + + def _on_handshake_complete(self, handshake_exc): + self._in_handshake = False + + sslobj = self._sslpipe.ssl_object + try: + if handshake_exc is not None: + raise handshake_exc + + peercert = sslobj.getpeercert() + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname + and self._sslcontext.verify_mode != ssl.CERT_NONE): + ssl.match_hostname(peercert, self._server_hostname) + except BaseException as exc: + if self._loop.get_debug(): + if isinstance(exc, ssl.CertificateError): + logger.warning("%r: SSL handshake failed " + "on verifying the certificate", + self, exc_info=True) + else: + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._transport.close() + if isinstance(exc, Exception): + self._wakeup_waiter(exc) + return + else: + raise + + if self._loop.get_debug(): + dt = self._loop.time() - self._handshake_start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=sslobj.cipher(), + compression=sslobj.compression(), + ) + self._app_protocol.connection_made(self._app_transport) + self._wakeup_waiter() + self._session_established = True + # In case transport.write() was already called. Don't call + # immediatly _process_write_backlog(), but schedule it: + # _on_handshake_complete() can be called indirectly from + # _process_write_backlog(), and _process_write_backlog() is not + # reentrant. + self._loop.call_soon(self._process_write_backlog) + + def _process_write_backlog(self): + # Try to make progress on the write backlog. + if self._transport is None: + return + + try: + for i in range(len(self._write_backlog)): + data, offset = self._write_backlog[0] + if data: + ssldata, offset = self._sslpipe.feed_appdata(data, offset) + elif offset: + ssldata = self._sslpipe.do_handshake(self._on_handshake_complete) + offset = 1 + else: + ssldata = self._sslpipe.shutdown(self._finalize) + offset = 1 + + for chunk in ssldata: + self._transport.write(chunk) + + if offset < len(data): + self._write_backlog[0] = (data, offset) + # A short write means that a write is blocked on a read + # We need to enable reading if it is paused! + assert self._sslpipe.need_ssldata + if self._transport._paused: + self._transport.resume_reading() + break + + # An entire chunk from the backlog was processed. We can + # delete it and reduce the outstanding buffer size. + del self._write_backlog[0] + self._write_buffer_size -= len(data) + except BaseException as exc: + if self._in_handshake: + self._on_handshake_complete(exc) + else: + self._fatal_error(exc, 'Fatal error on SSL transport') + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self._transport, + 'protocol': self, + }) + if self._transport: + self._transport._force_close(exc) + + def _finalize(self): + if self._transport is not None: + self._transport.close() + + def _abort(self): + if self._transport is not None: + try: + self._transport.abort() + finally: + self._finalize() diff --git a/trollius/streams.py b/trollius/streams.py new file mode 100644 index 0000000..176c65e --- /dev/null +++ b/trollius/streams.py @@ -0,0 +1,501 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', + 'open_connection', 'start_server', + 'IncompleteReadError', + ] + +import socket +import sys + +if hasattr(socket, 'AF_UNIX'): + __all__.extend(['open_unix_connection', 'start_unix_server']) + +from . import coroutines +from . import events +from . import futures +from . import protocols +from .coroutines import coroutine +from .log import logger + + +_DEFAULT_LIMIT = 2**16 +_PY35 = sys.version_info >= (3, 5) + + +class IncompleteReadError(EOFError): + """ + Incomplete read error. Attributes: + + - partial: read bytes string before the end of stream was reached + - expected: total number of expected bytes + """ + def __init__(self, partial, expected): + EOFError.__init__(self, "%s bytes read on a total of %s expected bytes" + % (len(partial), expected)) + self.partial = partial + self.expected = expected + + +@coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + StreamWriter instance. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + +@coroutine +def start_server(client_connected_cb, host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Start a socket server, call back for each client connected. + + The first parameter, `client_connected_cb`, takes two parameters: + client_reader, client_writer. client_reader is a StreamReader + object, while client_writer is a StreamWriter object. This + parameter can either be a plain callback function or a coroutine; + if it is a coroutine, it will be automatically converted into a + Task. + + The rest of the arguments are all the usual arguments to + loop.create_server() except protocol_factory; most common are + positional host and port, with various optional keyword arguments + following. The return value is the same as loop.create_server(). + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + The return value is the same as loop.create_server(), i.e. a + Server object which can be used to stop the service. + """ + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_server(factory, host, port, **kwds)) + + +if hasattr(socket, 'AF_UNIX'): + # UNIX Domain Sockets are supported on this platform + + @coroutine + def open_unix_connection(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `open_connection` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + + @coroutine + def start_unix_server(client_connected_cb, path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `start_server` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_unix_server(factory, path, **kwds)) + + +class FlowControlMixin(protocols.Protocol): + """Reusable flow control logic for StreamWriter.drain(). + + This implements the protocol methods pause_writing(), + resume_reading() and connection_lost(). If the subclass overrides + these it must call the super methods. + + StreamWriter.drain() must wait for _drain_helper() coroutine. + """ + + def __init__(self, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._paused = False + self._drain_waiter = None + self._connection_lost = False + + def pause_writing(self): + assert not self._paused + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses writing", self) + + def resume_writing(self): + assert self._paused + self._paused = False + if self._loop.get_debug(): + logger.debug("%r resumes writing", self) + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + @coroutine + def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError('Connection lost') + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._drain_waiter = waiter + yield from waiter + + +class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + if self._client_connected_cb is not None: + self._stream_writer = StreamWriter(transport, self, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, + self._stream_writer) + if coroutines.iscoroutine(res): + self._loop.create_task(res) + + def connection_lost(self, exc): + if exc is None: + self._stream_reader.feed_eof() + else: + self._stream_reader.set_exception(exc) + super().connection_lost(exc) + + def data_received(self, data): + self._stream_reader.feed_data(data) + + def eof_received(self): + self._stream_reader.feed_eof() + + +class StreamWriter: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport property which references the Transport + directly. + """ + + def __init__(self, transport, protocol, reader, loop): + self._transport = transport + self._protocol = protocol + # drain() expects that the reader has a exception() method + assert reader is None or isinstance(reader, StreamReader) + self._reader = reader + self._loop = loop + + def __repr__(self): + info = [self.__class__.__name__, 'transport=%r' % self._transport] + if self._reader is not None: + info.append('reader=%r' % self._reader) + return '<%s>' % ' '.join(info) + + @property + def transport(self): + return self._transport + + def write(self, data): + self._transport.write(data) + + def writelines(self, data): + self._transport.writelines(data) + + def write_eof(self): + return self._transport.write_eof() + + def can_write_eof(self): + return self._transport.can_write_eof() + + def close(self): + return self._transport.close() + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + @coroutine + def drain(self): + """Flush the write buffer. + + The intended use is to write + + w.write(data) + yield from w.drain() + """ + if self._reader is not None: + exc = self._reader.exception() + if exc is not None: + raise exc + yield from self._protocol._drain_helper() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + self._limit = limit + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._buffer = bytearray() + self._eof = False # Whether we're done. + self._waiter = None # A future used by _wait_for_data() + self._exception = None + self._transport = None + self._paused = False + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def _wakeup_waiter(self): + """Wakeup read() or readline() function waiting for data or EOF.""" + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(None) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and len(self._buffer) <= self._limit: + self._paused = False + self._transport.resume_reading() + + def feed_eof(self): + self._eof = True + self._wakeup_waiter() + + def at_eof(self): + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + + def feed_data(self, data): + assert not self._eof, 'feed_data after feed_eof' + + if not data: + return + + self._buffer.extend(data) + self._wakeup_waiter() + + if (self._transport is not None and + not self._paused and + len(self._buffer) > 2*self._limit): + try: + self._transport.pause_reading() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + @coroutine + def _wait_for_data(self, func_name): + """Wait until feed_data() or feed_eof() is called.""" + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError('%s() called while another coroutine is ' + 'already waiting for incoming data' % func_name) + + self._waiter = futures.Future(loop=self._loop) + try: + yield from self._waiter + finally: + self._waiter = None + + @coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + line = bytearray() + not_enough = True + + while not_enough: + while self._buffer and not_enough: + ichar = self._buffer.find(b'\n') + if ichar < 0: + line.extend(self._buffer) + self._buffer.clear() + else: + ichar += 1 + line.extend(self._buffer[:ichar]) + del self._buffer[:ichar] + not_enough = False + + if len(line) > self._limit: + self._maybe_resume_transport() + raise ValueError('Line is too long') + + if self._eof: + break + + if not_enough: + yield from self._wait_for_data('readline') + + self._maybe_resume_transport() + return bytes(line) + + @coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.read(self._limit) until EOF. + blocks = [] + while True: + block = yield from self.read(self._limit) + if not block: + break + blocks.append(block) + return b''.join(blocks) + else: + if not self._buffer and not self._eof: + yield from self._wait_for_data('read') + + if n < 0 or len(self._buffer) <= n: + data = bytes(self._buffer) + self._buffer.clear() + else: + # n > 0 and len(self._buffer) > n + data = bytes(self._buffer[:n]) + del self._buffer[:n] + + self._maybe_resume_transport() + return data + + @coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + # There used to be "optimized" code here. It created its own + # Future and waited until self._buffer had at least the n + # bytes, then called read(n). Unfortunately, this could pause + # the transport if the argument was larger than the pause + # limit (which is twice self._limit). So now we just read() + # into a local buffer. + + blocks = [] + while n > 0: + block = yield from self.read(n) + if not block: + partial = b''.join(blocks) + raise IncompleteReadError(partial, len(partial) + n) + blocks.append(block) + n -= len(block) + + return b''.join(blocks) + + if _PY35: + @coroutine + def __aiter__(self): + return self + + @coroutine + def __anext__(self): + val = yield from self.readline() + if val == b'': + raise StopAsyncIteration + return val diff --git a/trollius/subprocess.py b/trollius/subprocess.py new file mode 100644 index 0000000..4600a9f --- /dev/null +++ b/trollius/subprocess.py @@ -0,0 +1,215 @@ +__all__ = ['create_subprocess_exec', 'create_subprocess_shell'] + +import collections +import subprocess + +from . import events +from . import futures +from . import protocols +from . import streams +from . import tasks +from .coroutines import coroutine +from .log import logger + + +PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT +DEVNULL = subprocess.DEVNULL + + +class SubprocessStreamProtocol(streams.FlowControlMixin, + protocols.SubprocessProtocol): + """Like StreamReaderProtocol, but for a subprocess.""" + + def __init__(self, limit, loop): + super().__init__(loop=loop) + self._limit = limit + self.stdin = self.stdout = self.stderr = None + self._transport = None + + def __repr__(self): + info = [self.__class__.__name__] + if self.stdin is not None: + info.append('stdin=%r' % self.stdin) + if self.stdout is not None: + info.append('stdout=%r' % self.stdout) + if self.stderr is not None: + info.append('stderr=%r' % self.stderr) + return '<%s>' % ' '.join(info) + + def connection_made(self, transport): + self._transport = transport + + stdout_transport = transport.get_pipe_transport(1) + if stdout_transport is not None: + self.stdout = streams.StreamReader(limit=self._limit, + loop=self._loop) + self.stdout.set_transport(stdout_transport) + + stderr_transport = transport.get_pipe_transport(2) + if stderr_transport is not None: + self.stderr = streams.StreamReader(limit=self._limit, + loop=self._loop) + self.stderr.set_transport(stderr_transport) + + stdin_transport = transport.get_pipe_transport(0) + if stdin_transport is not None: + self.stdin = streams.StreamWriter(stdin_transport, + protocol=self, + reader=None, + loop=self._loop) + + def pipe_data_received(self, fd, data): + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader is not None: + reader.feed_data(data) + + def pipe_connection_lost(self, fd, exc): + if fd == 0: + pipe = self.stdin + if pipe is not None: + pipe.close() + self.connection_lost(exc) + return + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader != None: + if exc is None: + reader.feed_eof() + else: + reader.set_exception(exc) + + def process_exited(self): + self._transport.close() + self._transport = None + + +class Process: + def __init__(self, transport, protocol, loop): + self._transport = transport + self._protocol = protocol + self._loop = loop + self.stdin = protocol.stdin + self.stdout = protocol.stdout + self.stderr = protocol.stderr + self.pid = transport.get_pid() + + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.pid) + + @property + def returncode(self): + return self._transport.get_returncode() + + @coroutine + def wait(self): + """Wait until the process exit and return the process return code. + + This method is a coroutine.""" + return (yield from self._transport._wait()) + + def send_signal(self, signal): + self._transport.send_signal(signal) + + def terminate(self): + self._transport.terminate() + + def kill(self): + self._transport.kill() + + @coroutine + def _feed_stdin(self, input): + debug = self._loop.get_debug() + self.stdin.write(input) + if debug: + logger.debug('%r communicate: feed stdin (%s bytes)', + self, len(input)) + try: + yield from self.stdin.drain() + except (BrokenPipeError, ConnectionResetError) as exc: + # communicate() ignores BrokenPipeError and ConnectionResetError + if debug: + logger.debug('%r communicate: stdin got %r', self, exc) + + if debug: + logger.debug('%r communicate: close stdin', self) + self.stdin.close() + + @coroutine + def _noop(self): + return None + + @coroutine + def _read_stream(self, fd): + transport = self._transport.get_pipe_transport(fd) + if fd == 2: + stream = self.stderr + else: + assert fd == 1 + stream = self.stdout + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: read %s', self, name) + output = yield from stream.read() + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: close %s', self, name) + transport.close() + return output + + @coroutine + def communicate(self, input=None): + if input: + stdin = self._feed_stdin(input) + else: + stdin = self._noop() + if self.stdout is not None: + stdout = self._read_stream(1) + else: + stdout = self._noop() + if self.stderr is not None: + stderr = self._read_stream(2) + else: + stderr = self._noop() + stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, + loop=self._loop) + yield from self.wait() + return (stdout, stderr) + + +@coroutine +def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, + loop=None, limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_shell( + protocol_factory, + cmd, stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + return Process(transport, protocol, loop) + +@coroutine +def create_subprocess_exec(program, *args, stdin=None, stdout=None, + stderr=None, loop=None, + limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_exec( + protocol_factory, + program, *args, + stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + return Process(transport, protocol, loop) diff --git a/trollius/tasks.py b/trollius/tasks.py new file mode 100644 index 0000000..d8193ba --- /dev/null +++ b/trollius/tasks.py @@ -0,0 +1,681 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', 'shield', 'ensure_future', + ] + +import concurrent.futures +import functools +import inspect +import linecache +import sys +import types +import traceback +import warnings +import weakref + +from . import coroutines +from . import events +from . import futures +from .coroutines import coroutine + +_PY34 = (sys.version_info >= (3, 4)) + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + # An important invariant maintained while a Task not done: + # + # - Either _fut_waiter is None, and _step() is scheduled; + # - or _fut_waiter is some Future, and _step() is *not* scheduled. + # + # The only transition from the latter to the former is through + # _wakeup(). When _fut_waiter is not None, one of its callbacks + # must be _wakeup(). + + # Weak set containing all tasks alive. + _all_tasks = weakref.WeakSet() + + # Dictionary containing tasks that are currently active in + # all running event loops. {EventLoop: Task} + _current_tasks = {} + + # If False, don't log a message if the task is destroyed whereas its + # status is still pending + _log_destroy_pending = True + + @classmethod + def current_task(cls, loop=None): + """Return the currently running task in an event loop or None. + + By default the current task for the current event loop is returned. + + None is returned when called not in the context of a Task. + """ + if loop is None: + loop = events.get_event_loop() + return cls._current_tasks.get(loop) + + @classmethod + def all_tasks(cls, loop=None): + """Return a set of all tasks for an event loop. + + By default all tasks for the current event loop are returned. + """ + if loop is None: + loop = events.get_event_loop() + return {t for t in cls._all_tasks if t._loop is loop} + + def __init__(self, coro, *, loop=None): + assert coroutines.iscoroutine(coro), repr(coro) + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + self.__class__._all_tasks.add(self) + + # On Python 3.3 or older, objects with a destructor that are part of a + # reference cycle are never destroyed. That's not the case any more on + # Python 3.4 thanks to the PEP 442. + if _PY34: + def __del__(self): + if self._state == futures._PENDING and self._log_destroy_pending: + context = { + 'task': self, + 'message': 'Task was destroyed but it is pending!', + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + futures.Future.__del__(self) + + def _repr_info(self): + info = super()._repr_info() + + if self._must_cancel: + # replace status + info[0] = 'cancelling' + + coro = coroutines._format_coroutine(self._coro) + info.insert(1, 'coro=<%s>' % coro) + + if self._fut_waiter is not None: + info.insert(2, 'wait_for=%r' % self._fut_waiter) + return info + + def get_stack(self, *, limit=None): + """Return the list of stack frames for this task's coroutine. + + If the coroutine is not done, this returns the stack where it is + suspended. If the coroutine has completed successfully or was + cancelled, this returns an empty list. If the coroutine was + terminated by an exception, this returns the list of traceback + frames. + + The frames are always ordered from oldest to newest. + + The optional limit gives the maximum number of frames to + return; by default all available frames are returned. Its + meaning differs depending on whether a stack or a traceback is + returned: the newest frames of a stack are returned, but the + oldest frames of a traceback are returned. (This matches the + behavior of the traceback module.) + + For reasons beyond our control, only one stack frame is + returned for a suspended coroutine. + """ + frames = [] + f = self._coro.gi_frame + if f is not None: + while f is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(f) + f = f.f_back + frames.reverse() + elif self._exception is not None: + tb = self._exception.__traceback__ + while tb is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(tb.tb_frame) + tb = tb.tb_next + return frames + + def print_stack(self, *, limit=None, file=None): + """Print the stack or traceback for this task's coroutine. + + This produces output similar to that of the traceback module, + for the frames retrieved by get_stack(). The limit argument + is passed to get_stack(). The file argument is an I/O stream + to which the output is written; by default output is written + to sys.stderr. + """ + extracted_list = [] + checked = set() + for f in self.get_stack(limit=limit): + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + if filename not in checked: + checked.add(filename) + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, f.f_globals) + extracted_list.append((filename, lineno, name, line)) + exc = self._exception + if not extracted_list: + print('No stack for %r' % self, file=file) + elif exc is not None: + print('Traceback for %r (most recent call last):' % self, + file=file) + else: + print('Stack for %r (most recent call last):' % self, + file=file) + traceback.print_list(extracted_list, file=file) + if exc is not None: + for line in traceback.format_exception_only(exc.__class__, exc): + print(line, file=file, end='') + + def cancel(self): + """Request that this task cancel itself. + + This arranges for a CancelledError to be thrown into the + wrapped coroutine on the next cycle through the event loop. + The coroutine then has a chance to clean up or even deny + the request using try/except/finally. + + Unlike Future.cancel, this does not guarantee that the + task will be cancelled: the exception might be caught and + acted upon, delaying cancellation of the task or preventing + cancellation completely. The task may also return a value or + raise a different exception. + + Immediately after this method is called, Task.cancelled() will + not return True (unless the task was already cancelled). A + task will be marked as cancelled when the wrapped coroutine + terminates with a CancelledError exception (even if cancel() + was not called). + """ + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False + coro = self._coro + self._fut_waiter = None + + self.__class__._current_tasks[self._loop] = self + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + else: + result = coro.send(value) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + finally: + self.__class__._current_tasks.pop(self._loop) + self = None # Needed to break cycles when an exception occurs. + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + The sequence futures must not be empty. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from asyncio.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + + if loop is None: + loop = events.get_event_loop() + + fs = {ensure_future(f, loop=loop) for f in set(fs)} + + return (yield from _wait(fs, timeout, return_when, loop)) + + +def _release_waiter(waiter, *args): + if not waiter.done(): + waiter.set_result(None) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. When a timeout occurs, + it cancels the task and raises TimeoutError. To avoid the task + cancellation, wrap it in shield(). + + If the wait is cancelled, the task is also cancelled. + + This function is a coroutine. + """ + if loop is None: + loop = events.get_event_loop() + + if timeout is None: + return (yield from fut) + + waiter = futures.Future(loop=loop) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + cb = functools.partial(_release_waiter, waiter) + + fut = ensure_future(fut, loop=loop) + fut.add_done_callback(cb) + + try: + # wait until the future completes or the timeout + try: + yield from waiter + except futures.CancelledError: + fut.remove_done_callback(cb) + fut.cancel() + raise + + if fut.done(): + return fut.result() + else: + fut.remove_done_callback(cb) + fut.cancel() + raise futures.TimeoutError() + finally: + timeout_handle.cancel() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait() and _wait_for(). + + The fs argument must be a collection of Futures. + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + if not waiter.done(): + waiter.set_result(None) + + for f in fs: + f.add_done_callback(_on_completion) + + try: + yield from waiter + finally: + if timeout_handle is not None: + timeout_handle.cancel() + + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values are coroutines. + + When waiting for the yielded coroutines you'll get the results (or + exceptions!) of the original Futures (or coroutines), in the order + in which and as soon as they complete. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + If a timeout is specified, the 'yield from' will raise + TimeoutError when the timeout occurs before all Futures are done. + + Note: The futures 'f' are not necessarily members of fs. + """ + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + loop = loop if loop is not None else events.get_event_loop() + todo = {ensure_future(f, loop=loop) for f in set(fs)} + from .queues import Queue # Import here to avoid circular import problem. + done = Queue(loop=loop) + timeout_handle = None + + def _on_timeout(): + for f in todo: + f.remove_done_callback(_on_completion) + done.put_nowait(None) # Queue a dummy value for _wait_for_one(). + todo.clear() # Can't do todo.remove(f) in the loop. + + def _on_completion(f): + if not todo: + return # _on_timeout() was here first. + todo.remove(f) + done.put_nowait(f) + if not todo and timeout_handle is not None: + timeout_handle.cancel() + + @coroutine + def _wait_for_one(): + f = yield from done.get() + if f is None: + # Dummy value from _on_timeout(). + raise futures.TimeoutError + return f.result() # May raise f.exception(). + + for f in todo: + f.add_done_callback(_on_completion) + if todo and timeout is not None: + timeout_handle = loop.call_later(timeout, _on_timeout) + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, + future._set_result_unless_cancelled, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + + This function is deprecated in 3.5. Use asyncio.ensure_future() instead. + """ + + warnings.warn("asyncio.async() function is deprecated, use ensure_future()", + DeprecationWarning) + + return ensure_future(coro_or_future, loop=loop) + + +def ensure_future(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif coroutines.iscoroutine(coro_or_future): + if loop is None: + loop = events.get_event_loop() + task = loop.create_task(coro_or_future) + if task._source_traceback: + del task._source_traceback[-1] + return task + else: + raise TypeError('A Future or coroutine is required') + + +class _GatheringFuture(futures.Future): + """Helper for gather(). + + This overrides cancel() to cancel all the children and act more + like Task.cancel(), which doesn't immediately mark itself as + cancelled. + """ + + def __init__(self, children, *, loop=None): + super().__init__(loop=loop) + self._children = children + + def cancel(self): + if self.done(): + return False + for child in self._children: + child.cancel() + return True + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks are + done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily + the order of results arrival). If *return_exceptions* is True, + exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first + raised exception will be immediately propagated to the returned + future. + + Cancellation: if the outer Future is cancelled, all children (that + have not completed yet) are also cancelled. If any child is + cancelled, this is treated as if it raised CancelledError -- + the outer Future is *not* cancelled in this case. (This is to + prevent the cancellation of one child to cause other children to + be cancelled.) + """ + if not coros_or_futures: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + + arg_to_fut = {} + for arg in set(coros_or_futures): + if not isinstance(arg, futures.Future): + fut = ensure_future(arg, loop=loop) + if loop is None: + loop = fut._loop + # The caller cannot control this future, the "destroy pending task" + # warning should not be emitted. + fut._log_destroy_pending = False + else: + fut = arg + if loop is None: + loop = fut._loop + elif fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + arg_to_fut[arg] = fut + + children = [arg_to_fut[arg] for arg in coros_or_futures] + nchildren = len(children) + outer = _GatheringFuture(children, loop=loop) + nfinished = 0 + results = [None] * nchildren + + def _done_callback(i, fut): + nonlocal nfinished + if outer.done(): + if not fut.cancelled(): + # Mark exception retrieved. + fut.exception() + return + + if fut.cancelled(): + res = futures.CancelledError() + if not return_exceptions: + outer.set_exception(res) + return + elif fut._exception is not None: + res = fut.exception() # Mark exception retrieved. + if not return_exceptions: + outer.set_exception(res) + return + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == nchildren: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer + + +def shield(arg, *, loop=None): + """Wait for a future, shielding it from cancellation. + + The statement + + res = yield from shield(something()) + + is exactly equivalent to the statement + + res = yield from something() + + *except* that if the coroutine containing it is cancelled, the + task running in something() is not cancelled. From the POV of + something(), the cancellation did not happen. But its caller is + still cancelled, so the yield-from expression still raises + CancelledError. Note: If something() is cancelled by other means + this will still cancel shield(). + + If you want to completely ignore cancellation (not recommended) + you can combine shield() with a try/except clause, as follows: + + try: + res = yield from shield(something()) + except CancelledError: + res = None + """ + inner = ensure_future(arg, loop=loop) + if inner.done(): + # Shortcut. + return inner + loop = inner._loop + outer = futures.Future(loop=loop) + + def _done_callback(inner): + if outer.cancelled(): + if not inner.cancelled(): + # Mark inner's result as retrieved. + inner.exception() + return + + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + inner.add_done_callback(_done_callback) + return outer diff --git a/trollius/test_support.py b/trollius/test_support.py new file mode 100644 index 0000000..0fadfad --- /dev/null +++ b/trollius/test_support.py @@ -0,0 +1,308 @@ +# Subset of test.support from CPython 3.5, just what we need to run asyncio +# test suite. The code is copied from CPython 3.5 to not depend on the test +# module because it is rarely installed. + +# Ignore symbol TEST_HOME_DIR: test_events works without it + +import functools +import gc +import os +import platform +import re +import socket +import subprocess +import sys +import time + + +# A constant likely larger than the underlying OS pipe buffer size, to +# make writes blocking. +# Windows limit seems to be around 512 B, and many Unix kernels have a +# 64 KiB pipe buffer size or 16 * PAGE_SIZE: take a few megs to be sure. +# (see issue #17835 for a discussion of this number). +PIPE_MAX_SIZE = 4 * 1024 * 1024 + 1 + +def strip_python_stderr(stderr): + """Strip the stderr of a Python process from potential debug output + emitted by the interpreter. + + This will typically be run on the result of the communicate() method + of a subprocess.Popen object. + """ + stderr = re.sub(br"\[\d+ refs, \d+ blocks\]\r?\n?", b"", stderr).strip() + return stderr + + +# Executing the interpreter in a subprocess +def _assert_python(expected_success, *args, **env_vars): + if '__isolated' in env_vars: + isolated = env_vars.pop('__isolated') + else: + isolated = not env_vars + cmd_line = [sys.executable, '-X', 'faulthandler'] + if isolated and sys.version_info >= (3, 4): + # isolated mode: ignore Python environment variables, ignore user + # site-packages, and don't add the current directory to sys.path + cmd_line.append('-I') + elif not env_vars: + # ignore Python environment variables + cmd_line.append('-E') + # Need to preserve the original environment, for in-place testing of + # shared library builds. + env = os.environ.copy() + # But a special flag that can be set to override -- in this case, the + # caller is responsible to pass the full environment. + if env_vars.pop('__cleanenv', None): + env = {} + env.update(env_vars) + cmd_line.extend(args) + p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + try: + out, err = p.communicate() + finally: + subprocess._cleanup() + p.stdout.close() + p.stderr.close() + rc = p.returncode + err = strip_python_stderr(err) + if (rc and expected_success) or (not rc and not expected_success): + raise AssertionError( + "Process return code is %d, " + "stderr follows:\n%s" % (rc, err.decode('ascii', 'ignore'))) + return rc, out, err + + +def assert_python_ok(*args, **env_vars): + """ + Assert that running the interpreter with `args` and optional environment + variables `env_vars` succeeds (rc == 0) and return a (return code, stdout, + stderr) tuple. + + If the __cleanenv keyword is set, env_vars is used a fresh environment. + + Python is started in isolated mode (command line option -I), + except if the __isolated keyword is set to False. + """ + return _assert_python(True, *args, **env_vars) + + +is_jython = sys.platform.startswith('java') + +def gc_collect(): + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + if is_jython: + time.sleep(0.1) + gc.collect() + gc.collect() + + +HOST = "127.0.0.1" +HOSTv6 = "::1" + + +def _is_ipv6_enabled(): + """Check whether IPv6 is enabled on this host.""" + if socket.has_ipv6: + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind((HOSTv6, 0)) + return True + except OSError: + pass + finally: + if sock: + sock.close() + return False + +IPV6_ENABLED = _is_ipv6_enabled() + + +def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): + """Returns an unused port that should be suitable for binding. This is + achieved by creating a temporary socket with the same family and type as + the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to + the specified host address (defaults to 0.0.0.0) with the port set to 0, + eliciting an unused ephemeral port from the OS. The temporary socket is + then closed and deleted, and the ephemeral port is returned. + + Either this method or bind_port() should be used for any tests where a + server socket needs to be bound to a particular port for the duration of + the test. Which one to use depends on whether the calling code is creating + a python socket, or if an unused port needs to be provided in a constructor + or passed to an external program (i.e. the -accept argument to openssl's + s_server mode). Always prefer bind_port() over find_unused_port() where + possible. Hard coded ports should *NEVER* be used. As soon as a server + socket is bound to a hard coded port, the ability to run multiple instances + of the test simultaneously on the same host is compromised, which makes the + test a ticking time bomb in a buildbot environment. On Unix buildbots, this + may simply manifest as a failed test, which can be recovered from without + intervention in most cases, but on Windows, the entire python process can + completely and utterly wedge, requiring someone to log in to the buildbot + and manually kill the affected process. + + (This is easy to reproduce on Windows, unfortunately, and can be traced to + the SO_REUSEADDR socket option having different semantics on Windows versus + Unix/Linux. On Unix, you can't have two AF_INET SOCK_STREAM sockets bind, + listen and then accept connections on identical host/ports. An EADDRINUSE + OSError will be raised at some point (depending on the platform and + the order bind and listen were called on each socket). + + However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE + will ever be raised when attempting to bind two identical host/ports. When + accept() is called on each socket, the second caller's process will steal + the port from the first caller, leaving them both in an awkwardly wedged + state where they'll no longer respond to any signals or graceful kills, and + must be forcibly killed via OpenProcess()/TerminateProcess(). + + The solution on Windows is to use the SO_EXCLUSIVEADDRUSE socket option + instead of SO_REUSEADDR, which effectively affords the same semantics as + SO_REUSEADDR on Unix. Given the propensity of Unix developers in the Open + Source world compared to Windows ones, this is a common mistake. A quick + look over OpenSSL's 0.9.8g source shows that they use SO_REUSEADDR when + openssl.exe is called with the 's_server' option, for example. See + http://bugs.python.org/issue2550 for more info. The following site also + has a very thorough description about the implications of both REUSEADDR + and EXCLUSIVEADDRUSE on Windows: + http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx) + + XXX: although this approach is a vast improvement on previous attempts to + elicit unused ports, it rests heavily on the assumption that the ephemeral + port returned to us by the OS won't immediately be dished back out to some + other process when we close and delete our temporary socket but before our + calling code has a chance to bind the returned port. We can deal with this + issue if/when we come across it. + """ + + tempsock = socket.socket(family, socktype) + port = bind_port(tempsock) + tempsock.close() + del tempsock + return port + +def bind_port(sock, host=HOST): + """Bind the socket to a free port and return the port number. Relies on + ephemeral ports in order to ensure we are using an unbound port. This is + important as many tests may be running simultaneously, especially in a + buildbot environment. This method raises an exception if the sock.family + is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + or SO_REUSEPORT set on it. Tests should *never* set these socket options + for TCP/IP sockets. The only case for setting these options is testing + multicasting via multiple UDP sockets. + + Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. + on Windows), it will be set on the socket. This will prevent anyone else + from bind()'ing to our host/port for the duration of the test. + """ + + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if hasattr(socket, 'SO_REUSEADDR'): + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: + raise TestFailed("tests should never set the SO_REUSEADDR " + "socket option on TCP/IP sockets!") + if hasattr(socket, 'SO_REUSEPORT'): + try: + reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) + if reuse == 1: + raise TestFailed("tests should never set the SO_REUSEPORT " + "socket option on TCP/IP sockets!") + except OSError: + # Python's socket module was compiled using modern headers + # thus defining SO_REUSEPORT but this process is running + # under an older kernel that does not support SO_REUSEPORT. + pass + if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + + sock.bind((host, 0)) + port = sock.getsockname()[1] + return port + +def requires_mac_ver(*min_version): + """Decorator raising SkipTest if the OS is Mac OS X and the OS X + version if less than min_version. + + For example, @requires_mac_ver(10, 5) raises SkipTest if the OS X version + is lesser than 10.5. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kw): + if sys.platform == 'darwin': + version_txt = platform.mac_ver()[0] + try: + version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + else: + if version < min_version: + min_version_txt = '.'.join(map(str, min_version)) + raise unittest.SkipTest( + "Mac OS X %s or higher required, not %s" + % (min_version_txt, version_txt)) + return func(*args, **kw) + wrapper.min_version = min_version + return wrapper + return decorator + +def _requires_unix_version(sysname, min_version): + """Decorator raising SkipTest if the OS is `sysname` and the version is + less than `min_version`. + + For example, @_requires_unix_version('FreeBSD', (7, 2)) raises SkipTest if + the FreeBSD version is less than 7.2. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kw): + if platform.system() == sysname: + version_txt = platform.release().split('-', 1)[0] + try: + version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + else: + if version < min_version: + min_version_txt = '.'.join(map(str, min_version)) + raise unittest.SkipTest( + "%s version %s or higher required, not %s" + % (sysname, min_version_txt, version_txt)) + return func(*args, **kw) + wrapper.min_version = min_version + return wrapper + return decorator + +def requires_freebsd_version(*min_version): + """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version + is less than `min_version`. + + For example, @requires_freebsd_version(7, 2) raises SkipTest if the FreeBSD + version is less than 7.2. + """ + return _requires_unix_version('FreeBSD', min_version) + +# Use test.support if available +try: + from test.support import * +except ImportError: + pass + +# Use test.script_helper if available +try: + from test.support.script_helper import assert_python_ok +except ImportError: + try: + from test.script_helper import assert_python_ok + except ImportError: + pass diff --git a/trollius/test_utils.py b/trollius/test_utils.py new file mode 100644 index 0000000..8cee95b --- /dev/null +++ b/trollius/test_utils.py @@ -0,0 +1,446 @@ +"""Utilities shared by tests.""" + +import collections +import contextlib +import io +import logging +import os +import re +import socket +import socketserver +import sys +import tempfile +import threading +import time +import unittest +from unittest import mock + +from http.server import HTTPServer +from wsgiref.simple_server import WSGIRequestHandler, WSGIServer + +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import tasks +from .coroutines import coroutine +from .log import logger + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def dummy_ssl_context(): + if ssl is None: + return None + else: + return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + +def run_briefly(loop): + @coroutine + def once(): + pass + gen = once() + t = loop.create_task(gen) + # Don't log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False + try: + loop.run_until_complete(t) + finally: + gen.close() + + +def run_until(loop, pred, timeout=30): + deadline = time.time() + timeout + while not pred(): + if timeout is not None: + timeout = deadline - time.time() + if timeout <= 0: + raise futures.TimeoutError() + loop.run_until_complete(tasks.sleep(0.001, loop=loop)) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +class SilentWSGIRequestHandler(WSGIRequestHandler): + + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + +class SilentWSGIServer(WSGIServer): + + request_timeout = 2 + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + return request, client_addr + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + +class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): + pass + + +def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): + + def app(environ, start_response): + status = '200 OK' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = server_ssl_cls if use_ssl else server_cls + httpd = server_class(address, SilentWSGIRequestHandler) + httpd.set_app(app) + httpd.address = httpd.server_address + server_thread = threading.Thread( + target=lambda: httpd.serve_forever(poll_interval=0.05)) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + httpd.server_close() + server_thread.join() + + +if hasattr(socket, 'AF_UNIX'): + + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + + def server_bind(self): + socketserver.UnixStreamServer.server_bind(self) + self.server_name = '127.0.0.1' + self.server_port = 80 + + + class UnixWSGIServer(UnixHTTPServer, WSGIServer): + + request_timeout = 2 + + def server_bind(self): + UnixHTTPServer.server_bind(self) + self.setup_environ() + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + # Code in the stdlib expects that get_request + # will return a socket and a tuple (host, port). + # However, this isn't true for UNIX sockets, + # as the second return value will be a path; + # hence we return some fake data sufficient + # to get the tests going + return request, ('127.0.0.1', '') + + + class SilentUnixWSGIServer(UnixWSGIServer): + + def handle_error(self, request, client_address): + pass + + + class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): + pass + + + def gen_unix_socket_path(): + with tempfile.NamedTemporaryFile() as file: + return file.name + + + @contextlib.contextmanager + def unix_socket_path(): + path = gen_unix_socket_path() + try: + yield path + finally: + try: + os.unlink(path) + except OSError: + pass + + + @contextlib.contextmanager + def run_test_unix_server(*, use_ssl=False): + with unix_socket_path() as path: + yield from _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer) + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + yield from _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer) + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = MockCallback(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def __init__(self): + self.keys = {} + + def register(self, fileobj, events, data=None): + key = selectors.SelectorKey(fileobj, 0, events, data) + self.keys[fileobj] = key + return key + + def unregister(self, fileobj): + return self.keys.pop(fileobj) + + def select(self, timeout): + return [] + + def get_map(self): + return self.keys + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value returned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + def gen(): + yield + self._check_on_close = False + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._clock_resolution = 1e-9 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + super().close() + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.Handle(callback, args, self) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.Handle(callback, args, self) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass + + +def MockCallback(**kwargs): + return mock.Mock(spec=['__call__'], **kwargs) + + +class MockPattern(str): + """A regex based str with a fuzzy __eq__. + + Use this helper with 'mock.assert_called_with', or anywhere + where a regex comparison between strings is needed. + + For instance: + mock_call.assert_called_with(MockPattern('spam.*ham')) + """ + def __eq__(self, other): + return bool(re.search(str(self), other, re.S)) + + +def get_function_source(func): + source = events._get_function_source(func) + if source is None: + raise ValueError("unable to get the source of %r" % (func,)) + return source + + +class TestCase(unittest.TestCase): + def set_event_loop(self, loop, *, cleanup=True): + assert loop is not None + # ensure that the event loop is passed explicitly in asyncio + events.set_event_loop(None) + if cleanup: + self.addCleanup(loop.close) + + def new_test_loop(self, gen=None): + loop = TestLoop(gen) + self.set_event_loop(loop) + return loop + + def tearDown(self): + events.set_event_loop(None) + + # Detect CPython bug #23353: ensure that yield/yield-from is not used + # in an except block of a generator + self.assertEqual(sys.exc_info(), (None, None, None)) + + +@contextlib.contextmanager +def disable_logger(): + """Context manager to disable asyncio logger. + + For example, it can be used to ignore warnings in debug mode. + """ + old_level = logger.level + try: + logger.setLevel(logging.CRITICAL+1) + yield + finally: + logger.setLevel(old_level) + +def mock_nonblocking_socket(): + """Create a mock of a non-blocking socket.""" + sock = mock.Mock(socket.socket) + sock.gettimeout.return_value = 0.0 + return sock + + +def force_legacy_ssl_support(): + return mock.patch('asyncio.sslproto._is_sslproto_available', + return_value=False) diff --git a/trollius/transports.py b/trollius/transports.py new file mode 100644 index 0000000..22df3c7 --- /dev/null +++ b/trollius/transports.py @@ -0,0 +1,300 @@ +"""Abstract Transport class.""" + +import sys + +_PY34 = sys.version_info >= (3, 4) + +__all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', + 'Transport', 'DatagramTransport', 'SubprocessTransport', + ] + + +class BaseTransport: + """Base class for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """Interface for read-only transports.""" + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + raise NotImplementedError + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """Interface for write-only transports.""" + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + raise NotImplementedError + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + raise NotImplementedError + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation concatenates the arguments and + calls write() on the result. + """ + if not _PY34: + # In Python 3.3, bytes.join() doesn't handle memoryview. + list_of_data = ( + bytes(data) if isinstance(data, memoryview) else data + for data in list_of_data) + self.write(b''.join(list_of_data)) + + def write_eof(self): + """Close the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """Interface representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.create_server().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """Interface for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError + + +class _FlowControlMixin(Transport): + """All the logic for (write) flow control in a mix-in base class. + + The subclass must implement get_write_buffer_size(). It must call + _maybe_pause_protocol() whenever the write buffer size increases, + and _maybe_resume_protocol() whenever it decreases. It may also + override set_write_buffer_limits() (e.g. to specify different + defaults). + + The subclass constructor must call super().__init__(extra). This + will call set_write_buffer_limits(). + + The user may call set_write_buffer_limits() and + get_write_buffer_size(), and their protocol's pause_writing() and + resume_writing() may be called. + """ + + def __init__(self, extra=None, loop=None): + super().__init__(extra) + assert loop is not None + self._loop = loop + self._protocol_paused = False + self._set_write_buffer_limits() + + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def _maybe_resume_protocol(self): + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def get_write_buffer_limits(self): + return (self._low_water, self._high_water) + + def _set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) + self._high_water = high + self._low_water = low + + def set_write_buffer_limits(self, high=None, low=None): + self._set_write_buffer_limits(high=high, low=low) + self._maybe_pause_protocol() + + def get_write_buffer_size(self): + raise NotImplementedError diff --git a/trollius/unix_events.py b/trollius/unix_events.py new file mode 100644 index 0000000..75e7c9c --- /dev/null +++ b/trollius/unix_events.py @@ -0,0 +1,998 @@ +"""Selector event loop for Unix with signal handling.""" + +import errno +import os +import signal +import socket +import stat +import subprocess +import sys +import threading +import warnings + + +from . import base_events +from . import base_subprocess +from . import constants +from . import coroutines +from . import events +from . import futures +from . import selector_events +from . import selectors +from . import transports +from .coroutines import coroutine +from .log import logger + + +__all__ = ['SelectorEventLoop', + 'AbstractChildWatcher', 'SafeChildWatcher', + 'FastChildWatcher', 'DefaultEventLoopPolicy', + ] + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +def _sighandler_noop(signum, frame): + """Dummy signal handler.""" + pass + + +class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop. + + Adds signal handling and UNIX Domain Socket support to SelectorEventLoop. + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + super().close() + for sig in list(self._signal_handlers): + self.remove_signal_handler(sig) + + def _process_self_data(self, data): + for signum in data: + if not signum: + # ignore null bytes written by _write_to_self() + continue + self._handle_signal(signum) + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used " + "with add_signal_handler()") + self._check_signal(sig) + self._check_closed() + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except (ValueError, OSError) as exc: + raise RuntimeError(str(exc)) + + handle = events.Handle(callback, args, self) + self._signal_handlers[sig] = handle + + try: + # Register a dummy signal handler to ask Python to write the signal + # number in the wakup file descriptor. _process_self_data() will + # read signal numbers from this file descriptor to handle signals. + signal.signal(sig, _sighandler_noop) + + # Set SA_RESTART to limit EINTR occurrences. + signal.siginterrupt(sig, False) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except (ValueError, OSError) as nexc: + logger.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except (ValueError, OSError) as exc: + logger.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + with events.get_child_watcher() as watcher: + waiter = futures.Future(loop=self) + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + waiter=waiter, extra=extra, + **kwargs) + + watcher.add_child_handler(transp.get_pid(), + self._child_watcher_callback, transp) + try: + yield from waiter + except Exception as exc: + # Workaround CPython bug #23353: using yield/yield-from in an + # except block of a generator doesn't clear properly + # sys.exc_info() + err = exc + else: + err = None + + if err is not None: + transp.close() + yield from transp._wait() + raise err + + return transp + + def _child_watcher_callback(self, pid, returncode, transp): + self.call_soon_threadsafe(transp._process_exited, returncode) + + @coroutine + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + assert server_hostname is None or isinstance(server_hostname, str) + if ssl: + if server_hostname is None: + raise ValueError( + 'you have to pass server_hostname when using ssl') + else: + if server_hostname is not None: + raise ValueError('server_hostname is only meaningful with ssl') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) + try: + sock.setblocking(False) + yield from self.sock_connect(sock, path) + except: + sock.close() + raise + + else: + if sock is None: + raise ValueError('no path and sock were specified') + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @coroutine + def create_unix_server(self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None): + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + try: + sock.bind(path) + except OSError as exc: + sock.close() + if exc.errno == errno.EADDRINUSE: + # Let's improve the error message by adding + # with what exact address it occurs. + msg = 'Address {!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) from None + else: + raise + except: + sock.close() + raise + else: + if sock is None: + raise ValueError( + 'path was not specified, and no sock specified') + + if sock.family != socket.AF_UNIX: + raise ValueError( + 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + + server = base_events.Server(self, [sock]) + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + + +if hasattr(os, 'set_blocking'): + def _set_nonblocking(fd): + os.set_blocking(fd, False) +else: + import fcntl + + def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one event loop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + mode = os.fstat(self._fileno).st_mode + if not (stat.S_ISFIFO(mode) or + stat.S_ISSOCK(mode) or + stat.S_ISCHR(mode)): + raise ValueError("Pipe transport is for pipes/sockets only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._fileno, self._read_ready) + if waiter is not None: + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__] + if self._pipe is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._fileno) + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_READ) + if polling: + info.append('polling') + else: + info.append('idle') + else: + info.append('closed') + return '<%s>' % ' '.join(info) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc, 'Fatal read error on pipe transport') + else: + if data: + self._protocol.data_received(data) + else: + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause_reading(self): + self._loop.remove_reader(self._fileno) + + def resume_reading(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._pipe is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self._pipe.close() + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + # should be called by exception handler only + if (isinstance(exc, OSError) and exc.errno == errno.EIO): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports._FlowControlMixin, + transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra, loop) + self._extra['pipe'] = pipe + self._pipe = pipe + self._fileno = pipe.fileno() + mode = os.fstat(self._fileno).st_mode + is_socket = stat.S_ISSOCK(mode) + if not (is_socket or + stat.S_ISFIFO(mode) or + stat.S_ISCHR(mode)): + raise ValueError("Pipe transport is only for " + "pipes, sockets and character devices") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + + self._loop.call_soon(self._protocol.connection_made, self) + + # On AIX, the reader trick (to be notified when the read end of the + # socket is closed) only works for sockets. On other platforms it + # works for pipes and sockets. (Exception: OS X 10.4? Issue #19294.) + if is_socket or not sys.platform.startswith("aix"): + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._fileno, self._read_ready) + + if waiter is not None: + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__] + if self._pipe is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._fileno) + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_WRITE) + if polling: + info.append('polling') + else: + info.append('idle') + + bufsize = self.get_write_buffer_size() + info.append('bufsize=%s' % bufsize) + else: + info.append('closed') + return '<%s>' % ' '.join(info) + + def get_write_buffer_size(self): + return sum(len(data) for data in self._buffer) + + def _read_ready(self): + # Pipe was closed by peer. + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) + if self._buffer: + self._close(BrokenPipeError()) + else: + self._close() + + def write(self, data): + assert isinstance(data, (bytes, bytearray, memoryview)), repr(data) + if isinstance(data, bytearray): + data = memoryview(data) + if not data: + return + + if self._conn_lost or self._closing: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc, 'Fatal write error on pipe transport') + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + self._maybe_pause_protocol() + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc, 'Fatal write error on pipe transport') + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer and self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing: + return + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if self._pipe is not None and not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._pipe is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self._pipe.close() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + # should be called by exception handler only + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +if hasattr(os, 'set_inheritable'): + # Python 3.4 and newer + _set_inheritable = os.set_inheritable +else: + import fcntl + + def _set_inheritable(fd, inheritable): + cloexec_flag = getattr(fcntl, 'FD_CLOEXEC', 1) + + old = fcntl.fcntl(fd, fcntl.F_GETFD) + if not inheritable: + fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) + else: + fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag) + + +class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + stdin_w = None + if stdin == subprocess.PIPE: + # Use a socket pair for stdin, since not all platforms + # support selecting read events on the write end of a + # socket (which we use in order to detect closing of the + # other end). Notably this is needed on AIX, and works + # just fine on other platforms. + stdin, stdin_w = self._loop._socketpair() + + # Mark the write end of the stdin pipe as non-inheritable, + # needed by close_fds=False on Python 3.3 and older + # (Python 3.4 implements the PEP 446, socketpair returns + # non-inheritable sockets) + _set_inheritable(stdin_w.fileno(), False) + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + if stdin_w is not None: + stdin.close() + self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + + +class AbstractChildWatcher: + """Abstract base class for monitoring child processes. + + Objects derived from this class monitor a collection of subprocesses and + report their termination or interruption by a signal. + + New callbacks are registered with .add_child_handler(). Starting a new + process must be done within a 'with' block to allow the watcher to suspend + its activity until the new process if fully registered (this is needed to + prevent a race condition in some implementations). + + Example: + with watcher: + proc = subprocess.Popen("sleep 1") + watcher.add_child_handler(proc.pid, callback) + + Notes: + Implementations of this class must be thread-safe. + + Since child watcher objects may catch the SIGCHLD signal and call + waitpid(-1), there should be only one active object per process. + """ + + def add_child_handler(self, pid, callback, *args): + """Register a new child handler. + + Arrange for callback(pid, returncode, *args) to be called when + process 'pid' terminates. Specifying another callback for the same + process replaces the previous handler. + + Note: callback() must be thread-safe. + """ + raise NotImplementedError() + + def remove_child_handler(self, pid): + """Removes the handler for process 'pid'. + + The function returns True if the handler was successfully removed, + False if there was nothing to remove.""" + + raise NotImplementedError() + + def attach_loop(self, loop): + """Attach the watcher to an event loop. + + If the watcher was previously attached to an event loop, then it is + first detached before attaching to the new loop. + + Note: loop may be None. + """ + raise NotImplementedError() + + def close(self): + """Close the watcher. + + This must be called to make sure that any underlying resource is freed. + """ + raise NotImplementedError() + + def __enter__(self): + """Enter the watcher's context and allow starting new processes + + This function must return self""" + raise NotImplementedError() + + def __exit__(self, a, b, c): + """Exit the watcher's context""" + raise NotImplementedError() + + +class BaseChildWatcher(AbstractChildWatcher): + + def __init__(self): + self._loop = None + + def close(self): + self.attach_loop(None) + + def _do_waitpid(self, expected_pid): + raise NotImplementedError() + + def _do_waitpid_all(self): + raise NotImplementedError() + + def attach_loop(self, loop): + assert loop is None or isinstance(loop, events.AbstractEventLoop) + + if self._loop is not None: + self._loop.remove_signal_handler(signal.SIGCHLD) + + self._loop = loop + if loop is not None: + loop.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + # Prevent a race condition in case a child terminated + # during the switch. + self._do_waitpid_all() + + def _sig_chld(self): + try: + self._do_waitpid_all() + except Exception as exc: + # self._loop should always be available here + # as '_sig_chld' is added as a signal handler + # in 'attach_loop' + self._loop.call_exception_handler({ + 'message': 'Unknown exception in SIGCHLD handler', + 'exception': exc, + }) + + def _compute_returncode(self, status): + if os.WIFSIGNALED(status): + # The child process died because of a signal. + return -os.WTERMSIG(status) + elif os.WIFEXITED(status): + # The child process exited (e.g sys.exit()). + return os.WEXITSTATUS(status) + else: + # The child exited, but we don't understand its status. + # This shouldn't happen, but if it does, let's just + # return that status; perhaps that helps debug it. + return status + + +class SafeChildWatcher(BaseChildWatcher): + """'Safe' child watcher implementation. + + This implementation avoids disrupting other code spawning processes by + polling explicitly each process in the SIGCHLD handler instead of calling + os.waitpid(-1). + + This is a safe solution but it has a significant overhead when handling a + big number of children (O(n) each time SIGCHLD is raised) + """ + + def __init__(self): + super().__init__() + self._callbacks = {} + + def close(self): + self._callbacks.clear() + super().close() + + def __enter__(self): + return self + + def __exit__(self, a, b, c): + pass + + def add_child_handler(self, pid, callback, *args): + self._callbacks[pid] = (callback, args) + + # Prevent a race condition in case the child is already terminated. + self._do_waitpid(pid) + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def _do_waitpid_all(self): + + for pid in list(self._callbacks): + self._do_waitpid(pid) + + def _do_waitpid(self, expected_pid): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, os.WNOHANG) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + else: + if pid == 0: + # The child process is still alive. + return + + returncode = self._compute_returncode(status) + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) + + try: + callback, args = self._callbacks.pop(pid) + except KeyError: # pragma: no cover + # May happen if .remove_child_handler() is called + # after os.waitpid() returns. + if self._loop.get_debug(): + logger.warning("Child watcher got an unexpected pid: %r", + pid, exc_info=True) + else: + callback(pid, returncode, *args) + + +class FastChildWatcher(BaseChildWatcher): + """'Fast' child watcher implementation. + + This implementation reaps every terminated processes by calling + os.waitpid(-1) directly, possibly breaking other code spawning processes + and waiting for their termination. + + There is no noticeable overhead when handling a big number of children + (O(1) each time a child terminates). + """ + def __init__(self): + super().__init__() + self._callbacks = {} + self._lock = threading.Lock() + self._zombies = {} + self._forks = 0 + + def close(self): + self._callbacks.clear() + self._zombies.clear() + super().close() + + def __enter__(self): + with self._lock: + self._forks += 1 + + return self + + def __exit__(self, a, b, c): + with self._lock: + self._forks -= 1 + + if self._forks or not self._zombies: + return + + collateral_victims = str(self._zombies) + self._zombies.clear() + + logger.warning( + "Caught subprocesses termination from unknown pids: %s", + collateral_victims) + + def add_child_handler(self, pid, callback, *args): + assert self._forks, "Must use the context manager" + with self._lock: + try: + returncode = self._zombies.pop(pid) + except KeyError: + # The child is running. + self._callbacks[pid] = callback, args + return + + # The child is dead already. We can fire the callback. + callback(pid, returncode, *args) + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def _do_waitpid_all(self): + # Because of signal coalescing, we must keep calling waitpid() as + # long as we're able to reap a child. + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + # No more child processes exist. + return + else: + if pid == 0: + # A child process is still alive. + return + + returncode = self._compute_returncode(status) + + with self._lock: + try: + callback, args = self._callbacks.pop(pid) + except KeyError: + # unknown child + if self._forks: + # It may not be registered yet. + self._zombies[pid] = returncode + if self._loop.get_debug(): + logger.debug('unknown process %s exited ' + 'with returncode %s', + pid, returncode) + continue + callback = None + else: + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + pid, returncode) + + if callback is None: + logger.warning( + "Caught subprocess termination from unknown pid: " + "%d -> %d", pid, returncode) + else: + callback(pid, returncode, *args) + + +class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + """UNIX event loop policy with a watcher for child processes.""" + _loop_factory = _UnixSelectorEventLoop + + def __init__(self): + super().__init__() + self._watcher = None + + def _init_watcher(self): + with events._lock: + if self._watcher is None: # pragma: no branch + self._watcher = SafeChildWatcher() + if isinstance(threading.current_thread(), + threading._MainThread): + self._watcher.attach_loop(self._local._loop) + + def set_event_loop(self, loop): + """Set the event loop. + + As a side effect, if a child watcher was set before, then calling + .set_event_loop() from the main thread will call .attach_loop(loop) on + the child watcher. + """ + + super().set_event_loop(loop) + + if self._watcher is not None and \ + isinstance(threading.current_thread(), threading._MainThread): + self._watcher.attach_loop(loop) + + def get_child_watcher(self): + """Get the watcher for child processes. + + If not yet set, a SafeChildWatcher object is automatically created. + """ + if self._watcher is None: + self._init_watcher() + + return self._watcher + + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + + assert watcher is None or isinstance(watcher, AbstractChildWatcher) + + if self._watcher is not None: + self._watcher.close() + + self._watcher = watcher + +SelectorEventLoop = _UnixSelectorEventLoop +DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy diff --git a/trollius/windows_events.py b/trollius/windows_events.py new file mode 100644 index 0000000..922594f --- /dev/null +++ b/trollius/windows_events.py @@ -0,0 +1,774 @@ +"""Selector and proactor event loops for Windows.""" + +import _winapi +import errno +import math +import socket +import struct +import weakref + +from . import events +from . import base_subprocess +from . import futures +from . import proactor_events +from . import selector_events +from . import tasks +from . import windows_utils +from . import _overlapped +from .coroutines import coroutine +from .log import logger + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', + 'DefaultEventLoopPolicy', + ] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + +# Initial delay in seconds for connect_pipe() before retrying to connect +CONNECT_PIPE_INIT_DELAY = 0.001 + +# Maximum delay in seconds for connect_pipe() before retrying to connect +CONNECT_PIPE_MAX_DELAY = 0.100 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + self._ov = ov + + def _repr_info(self): + info = super()._repr_info() + if self._ov is not None: + state = 'pending' if self._ov.pending else 'completed' + info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) + return info + + def _cancel_overlapped(self): + if self._ov is None: + return + try: + self._ov.cancel() + except OSError as exc: + context = { + 'message': 'Cancelling an overlapped future failed', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self._ov = None + + def cancel(self): + self._cancel_overlapped() + return super().cancel() + + def set_exception(self, exception): + super().set_exception(exception) + self._cancel_overlapped() + + def set_result(self, result): + super().set_result(result) + self._ov = None + + +class _BaseWaitHandleFuture(futures.Future): + """Subclass of Future which represents a wait handle.""" + + def __init__(self, ov, handle, wait_handle, *, loop=None): + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + # Keep a reference to the Overlapped object to keep it alive until the + # wait is unregistered + self._ov = ov + self._handle = handle + self._wait_handle = wait_handle + + # Should we call UnregisterWaitEx() if the wait completes + # or is cancelled? + self._registered = True + + def _poll(self): + # non-blocking wait: use a timeout of 0 millisecond + return (_winapi.WaitForSingleObject(self._handle, 0) == + _winapi.WAIT_OBJECT_0) + + def _repr_info(self): + info = super()._repr_info() + info.append('handle=%#x' % self._handle) + if self._handle is not None: + state = 'signaled' if self._poll() else 'waiting' + info.append(state) + if self._wait_handle is not None: + info.append('wait_handle=%#x' % self._wait_handle) + return info + + def _unregister_wait_cb(self, fut): + # The wait was unregistered: it's not safe to destroy the Overlapped + # object + self._ov = None + + def _unregister_wait(self): + if not self._registered: + return + self._registered = False + + wait_handle = self._wait_handle + self._wait_handle = None + try: + _overlapped.UnregisterWait(wait_handle) + except OSError as exc: + if exc.winerror != _overlapped.ERROR_IO_PENDING: + context = { + 'message': 'Failed to unregister the wait handle', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + return + # ERROR_IO_PENDING means that the unregister is pending + + self._unregister_wait_cb(None) + + def cancel(self): + self._unregister_wait() + return super().cancel() + + def set_exception(self, exception): + self._unregister_wait() + super().set_exception(exception) + + def set_result(self, result): + self._unregister_wait() + super().set_result(result) + + +class _WaitCancelFuture(_BaseWaitHandleFuture): + """Subclass of Future which represents a wait for the cancellation of a + _WaitHandleFuture using an event. + """ + + def __init__(self, ov, event, wait_handle, *, loop=None): + super().__init__(ov, event, wait_handle, loop=loop) + + self._done_callback = None + + def cancel(self): + raise RuntimeError("_WaitCancelFuture must not be cancelled") + + def _schedule_callbacks(self): + super(_WaitCancelFuture, self)._schedule_callbacks() + if self._done_callback is not None: + self._done_callback(self) + + +class _WaitHandleFuture(_BaseWaitHandleFuture): + def __init__(self, ov, handle, wait_handle, proactor, *, loop=None): + super().__init__(ov, handle, wait_handle, loop=loop) + self._proactor = proactor + self._unregister_proactor = True + self._event = _overlapped.CreateEvent(None, True, False, None) + self._event_fut = None + + def _unregister_wait_cb(self, fut): + if self._event is not None: + _winapi.CloseHandle(self._event) + self._event = None + self._event_fut = None + + # If the wait was cancelled, the wait may never be signalled, so + # it's required to unregister it. Otherwise, IocpProactor.close() will + # wait forever for an event which will never come. + # + # If the IocpProactor already received the event, it's safe to call + # _unregister() because we kept a reference to the Overlapped object + # which is used as an unique key. + self._proactor._unregister(self._ov) + self._proactor = None + + super()._unregister_wait_cb(fut) + + def _unregister_wait(self): + if not self._registered: + return + self._registered = False + + wait_handle = self._wait_handle + self._wait_handle = None + try: + _overlapped.UnregisterWaitEx(wait_handle, self._event) + except OSError as exc: + if exc.winerror != _overlapped.ERROR_IO_PENDING: + context = { + 'message': 'Failed to unregister the wait handle', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + return + # ERROR_IO_PENDING is not an error, the wait was unregistered + + self._event_fut = self._proactor._wait_cancel(self._event, + self._unregister_wait_cb) + + +class PipeServer(object): + """Class representing a pipe server. + + This is much like a bound, listening socket. + """ + def __init__(self, address): + self._address = address + self._free_instances = weakref.WeakSet() + # initialize the pipe attribute before calling _server_pipe_handle() + # because this function can raise an exception and the destructor calls + # the close() method + self._pipe = None + self._accept_pipe_future = None + self._pipe = self._server_pipe_handle(True) + + def _get_unconnected_pipe(self): + # Create new instance and return previous one. This ensures + # that (until the server is closed) there is always at least + # one pipe handle for address. Therefore if a client attempt + # to connect it will not fail with FileNotFoundError. + tmp, self._pipe = self._pipe, self._server_pipe_handle(False) + return tmp + + def _server_pipe_handle(self, first): + # Return a wrapper for a new pipe handle. + if self.closed(): + return None + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + h = _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, + windows_utils.BUFSIZE, windows_utils.BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + pipe = windows_utils.PipeHandle(h) + self._free_instances.add(pipe) + return pipe + + def closed(self): + return (self._address is None) + + def close(self): + if self._accept_pipe_future is not None: + self._accept_pipe_future.cancel() + self._accept_pipe_future = None + # Close all instances which have not been connected to by a client. + if self._address is not None: + for pipe in self._free_instances: + pipe.close() + self._pipe = None + self._address = None + self._free_instances.clear() + + __del__ = close + + +class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Windows version of selector event loop.""" + + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + """Windows version of proactor event loop using IOCP.""" + + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + @coroutine + def create_pipe_connection(self, protocol_factory, address): + f = self._proactor.connect_pipe(address) + pipe = yield from f + protocol = protocol_factory() + trans = self._make_duplex_pipe_transport(pipe, protocol, + extra={'addr': address}) + return trans, protocol + + @coroutine + def start_serving_pipe(self, protocol_factory, address): + server = PipeServer(address) + + def loop_accept_pipe(f=None): + pipe = None + try: + if f: + pipe = f.result() + server._free_instances.discard(pipe) + + if server.closed(): + # A client connected before the server was closed: + # drop the client (close the pipe) and exit + pipe.close() + return + + protocol = protocol_factory() + self._make_duplex_pipe_transport( + pipe, protocol, extra={'addr': address}) + + pipe = server._get_unconnected_pipe() + if pipe is None: + return + + f = self._proactor.accept_pipe(pipe) + except OSError as exc: + if pipe and pipe.fileno() != -1: + self.call_exception_handler({ + 'message': 'Pipe accept failed', + 'exception': exc, + 'pipe': pipe, + }) + pipe.close() + elif self._debug: + logger.warning("Accept pipe failed on pipe %r", + pipe, exc_info=True) + except futures.CancelledError: + if pipe: + pipe.close() + else: + server._accept_pipe_future = f + f.add_done_callback(loop_accept_pipe) + + self.call_soon(loop_accept_pipe) + return [server] + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + waiter = futures.Future(loop=self) + transp = _WindowsSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + waiter=waiter, extra=extra, + **kwargs) + try: + yield from waiter + except Exception as exc: + # Workaround CPython bug #23353: using yield/yield-from in an + # except block of a generator doesn't clear properly sys.exc_info() + err = exc + else: + err = None + + if err is not None: + transp.close() + yield from transp._wait() + raise err + + return transp + + +class IocpProactor: + """Proactor implementation using IOCP.""" + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._unregistered = [] + self._stopped_serving = weakref.WeakSet() + + def __repr__(self): + return ('<%s overlapped#=%s result#=%s>' + % (self.__class__.__name__, len(self._cache), + len(self._results))) + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def _result(self, value): + fut = futures.Future(loop=self._loop) + fut.set_result(value) + return fut + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + if isinstance(conn, socket.socket): + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(conn.fileno(), nbytes) + except BrokenPipeError: + return self._result(b'') + + def finish_recv(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_recv) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(conn.fileno(), buf) + + def finish_send(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_send) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket(listener.family) + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + @coroutine + def accept_coro(future, conn): + # Coroutine closing the accept socket if the future is cancelled + try: + yield from future + except futures.CancelledError: + conn.close() + raise + + future = self._register(ov, listener, finish_accept) + coro = accept_coro(future, conn) + tasks.ensure_future(coro, loop=self._loop) + return future + + def connect(self, conn, address): + self._register_with_iocp(conn) + # The socket needs to be locally bound before we call ConnectEx(). + try: + _overlapped.BindLocal(conn.fileno(), conn.family) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # Probably already locally bound; check using getsockname(). + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) + return conn + + return self._register(ov, conn, finish_connect) + + def accept_pipe(self, pipe): + self._register_with_iocp(pipe) + ov = _overlapped.Overlapped(NULL) + connected = ov.ConnectNamedPipe(pipe.fileno()) + + if connected: + # ConnectNamePipe() failed with ERROR_PIPE_CONNECTED which means + # that the pipe is connected. There is no need to wait for the + # completion of the connection. + return self._result(pipe) + + def finish_accept_pipe(trans, key, ov): + ov.getresult() + return pipe + + return self._register(ov, pipe, finish_accept_pipe) + + @coroutine + def connect_pipe(self, address): + delay = CONNECT_PIPE_INIT_DELAY + while True: + # Unfortunately there is no way to do an overlapped connect to a pipe. + # Call CreateFile() in a loop until it doesn't fail with + # ERROR_PIPE_BUSY + try: + handle = _overlapped.ConnectPipe(address) + break + except OSError as exc: + if exc.winerror != _overlapped.ERROR_PIPE_BUSY: + raise + + # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later + delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) + yield from tasks.sleep(delay, loop=self._loop) + + return windows_utils.PipeHandle(handle) + + def wait_for_handle(self, handle, timeout=None): + """Wait for a handle. + + Return a Future object. The result of the future is True if the wait + completed, or False if the wait did not complete (on timeout). + """ + return self._wait_for_handle(handle, timeout, False) + + def _wait_cancel(self, event, done_callback): + fut = self._wait_for_handle(event, None, True) + # add_done_callback() cannot be used because the wait may only complete + # in IocpProactor.close(), while the event loop is not running. + fut._done_callback = done_callback + return fut + + def _wait_for_handle(self, handle, timeout, _is_cancel): + if timeout is None: + ms = _winapi.INFINITE + else: + # RegisterWaitForSingleObject() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) + + # We only create ov so we can use ov.address as a key for the cache. + ov = _overlapped.Overlapped(NULL) + wait_handle = _overlapped.RegisterWaitWithQueue( + handle, self._iocp, ov.address, ms) + if _is_cancel: + f = _WaitCancelFuture(ov, handle, wait_handle, loop=self._loop) + else: + f = _WaitHandleFuture(ov, handle, wait_handle, self, + loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] + + def finish_wait_for_handle(trans, key, ov): + # Note that this second wait means that we should only use + # this with handles types where a successful wait has no + # effect. So events or processes are all right, but locks + # or semaphores are not. Also note if the handle is + # signalled and then quickly reset, then we may return + # False even though we have not timed out. + return f._poll() + + self._cache[ov.address] = (f, ov, 0, finish_wait_for_handle) + return f + + def _register_with_iocp(self, obj): + # To get notifications of finished ops on this objects sent to the + # completion port, were must register the handle. + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + # XXX We could also use SetFileCompletionNotificationModes() + # to avoid sending notifications to completion port of ops + # that succeed immediately. + + def _register(self, ov, obj, callback): + # Return a future which will be set with the result of the + # operation when it completes. The future's value is actually + # the value returned by callback(). + f = _OverlappedFuture(ov, loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] + if not ov.pending: + # The operation has completed, so no need to postpone the + # work. We cannot take this short cut if we need the + # NumberOfBytes, CompletionKey values returned by + # PostQueuedCompletionStatus(). + try: + value = callback(None, None, ov) + except OSError as e: + f.set_exception(e) + else: + f.set_result(value) + # Even if GetOverlappedResult() was called, we have to wait for the + # notification of the completion in GetQueuedCompletionStatus(). + # Register the overlapped operation to keep a reference to the + # OVERLAPPED object, otherwise the memory is freed and Windows may + # read uninitialized memory. + + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _unregister(self, ov): + """Unregister an overlapped object. + + Call this method when its future has been cancelled. The event can + already be signalled (pending in the proactor event queue). It is also + safe if the event is never signalled (because it was cancelled). + """ + self._unregistered.append(ov) + + def _get_accept_socket(self, family): + s = socket.socket(family) + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + # GetQueuedCompletionStatus() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) + if ms >= INFINITE: + raise ValueError("timeout too big") + + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + break + ms = 0 + + err, transferred, key, address = status + try: + f, ov, obj, callback = self._cache.pop(address) + except KeyError: + if self._loop.get_debug(): + self._loop.call_exception_handler({ + 'message': ('GetQueuedCompletionStatus() returned an ' + 'unexpected event'), + 'status': ('err=%s transferred=%s key=%#x address=%#x' + % (err, transferred, key, address)), + }) + + # key is either zero, or it is used to return a pipe + # handle which should be closed to avoid a leak. + if key not in (0, _overlapped.INVALID_HANDLE_VALUE): + _winapi.CloseHandle(key) + continue + + if obj in self._stopped_serving: + f.cancel() + # Don't call the callback if _register() already read the result or + # if the overlapped has been cancelled + elif not f.done(): + try: + value = callback(transferred, key, ov) + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + + # Remove unregisted futures + for ov in self._unregistered: + self._cache.pop(ov.address, None) + self._unregistered.clear() + + def _stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop._stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + # Cancel remaining registered operations. + for address, (fut, ov, obj, callback) in list(self._cache.items()): + if fut.cancelled(): + # Nothing to do with cancelled futures + pass + elif isinstance(fut, _WaitCancelFuture): + # _WaitCancelFuture must not be cancelled + pass + else: + try: + fut.cancel() + except OSError as exc: + if self._loop is not None: + context = { + 'message': 'Cancelling a future failed', + 'exception': exc, + 'future': fut, + } + if fut._source_traceback: + context['source_traceback'] = fut._source_traceback + self._loop.call_exception_handler(context) + + while self._cache: + if not self._poll(1): + logger.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None + + def __del__(self): + self.close() + + +class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport): + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + self._proc = windows_utils.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + bufsize=bufsize, **kwargs) + + def callback(f): + returncode = self._proc.poll() + self._process_exited(returncode) + + f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) + f.add_done_callback(callback) + + +SelectorEventLoop = _WindowsSelectorEventLoop + + +class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + _loop_factory = SelectorEventLoop + + +DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py new file mode 100644 index 0000000..870cd13 --- /dev/null +++ b/trollius/windows_utils.py @@ -0,0 +1,223 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import _winapi +import itertools +import msvcrt +import os +import socket +import subprocess +import tempfile +import warnings + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + + +# Constants/globals + + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT +_mmap_counter = itertools.count() + + +if hasattr(socket, 'socketpair'): + # Since Python 3.5, socket.socketpair() is now also available on Windows + socketpair = socket.socketpair +else: + # Replacement for socket.socketpair() + def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. + Public domain. + """ + if family == socket.AF_INET: + host = '127.0.0.1' + elif family == socket.AF_INET6: + host = '::1' + else: + raise ValueError("Only AF_INET and AF_INET6 socket address " + "families are supported") + if type != socket.SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + try: + lsock.bind((host, 0)) + lsock.listen(1) + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket.socket(family, type, proto) + try: + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + csock.setblocking(True) + ssock, _ = lsock.accept() + except: + csock.close() + raise + finally: + lsock.close() + return (ssock, csock) + + +# Replacement for os.pipe() using handles instead of fds + + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + + +# Wrapper for a pipe handle + + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + def __repr__(self): + if self._handle is not None: + handle = 'handle=%r' % self._handle + else: + handle = 'closed' + return '<%s %s>' % (self.__class__.__name__, handle) + + @property + def handle(self): + return self._handle + + def fileno(self): + if self._handle is None: + raise ValueError("I/O operatioon on closed pipe") + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle is not None: + CloseHandle(self._handle) + self._handle = None + + def __del__(self): + if self._handle is not None: + warnings.warn("unclosed %r" % self, ResourceWarning) + self.close() + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + + +# Replacement for subprocess.Popen using overlapped pipe handles + + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + assert not kwds.get('universal_newlines') + assert kwds.get('bufsize', 0) == 0 + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True), duplex=True) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + else: + stdin_rfd = stdin + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + else: + stdout_wfd = stdout + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + elif stderr == STDOUT: + stderr_wfd = stdout_wfd + else: + stderr_wfd = stderr + try: + super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + if h is not None: + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) |