summaryrefslogtreecommitdiff
path: root/proactor.py
diff options
context:
space:
mode:
Diffstat (limited to 'proactor.py')
-rw-r--r--proactor.py432
1 files changed, 432 insertions, 0 deletions
diff --git a/proactor.py b/proactor.py
new file mode 100644
index 0000000..ffe5b17
--- /dev/null
+++ b/proactor.py
@@ -0,0 +1,432 @@
+#
+# Module implementing the Proactor pattern
+#
+# A proactor is used to initiate asynchronous I/O, and to wait for
+# completion of previously initiated operations.
+#
+
+import os
+import sys
+import errno
+import socket
+import select
+import time
+import warnings
+
+
+__all__ = ['SelectProactor']
+
+#
+# Future class
+#
+
+class Future(Exception):
+
+ def __init__(self):
+ self._callbacks = []
+
+ def result(self):
+ # does not block for operation to complete
+ assert self.done()
+ if self.success:
+ return self.value
+ else:
+ raise self.value
+
+ def set_result(self, value):
+ assert not self.done()
+ self.success = True
+ self.value = value
+ self._invoke_callbacks()
+
+ def set_exception(self, value):
+ assert not self.done()
+ self.success = False
+ self.value = value
+ self._invoke_callbacks()
+
+ def done(self):
+ return hasattr(self, 'success')
+
+ def add_done_callback(self, func):
+ if self.done():
+ func(self)
+ else:
+ self._callbacks.append(func)
+
+ def _invoke_callbacks(self):
+ for func in self._callbacks:
+ try:
+ func(self)
+ except Exception:
+ sys.excepthook(*sys.exc_info())
+ del self._callbacks
+
+#
+# Base class for all proactors
+#
+
+class BaseProactor:
+ _Future = Future
+
+ def __init__(self):
+ self._results = []
+
+ def poll(self, timeout=None):
+ if not self._results:
+ self._poll(timeout)
+ tmp, self._results = self._results, []
+ return tmp
+
+ def filteredpoll(self, penders, timeout=None):
+ if timeout is None:
+ deadline = None
+ elif timeout < 0:
+ raise ValueError('negative timeout')
+ else:
+ deadline = time.monotonic() + timeout
+ S = set(penders)
+ while True:
+ filtered = [x for x in self._results if x[0] in S]
+ if filtered:
+ self._results = [x for x in self._results if x[0] not in S]
+ return filtered
+ self._poll(timeout)
+ if deadline is not None:
+ timeout = deadline - time.monotonic()
+ if timeout <= 0:
+ break
+
+ def close(self):
+ pass
+
+#
+# Initiator methods for proactors based on select()/poll()/epoll()/kqueue()
+#
+
+READABLE = 0
+WRITABLE = 1
+
+class ReadyBaseProactor(BaseProactor):
+ def __init__(self):
+ super().__init__()
+ self._queue = [{}, {}]
+
+ def pollable(self):
+ return any(self._queue)
+
+ def recv(self, sock, nbytes, flags=0):
+ try:
+ return sock.recv(nbytes, flags)
+ except BlockingIOError:
+ raise self._register(sock.fileno(), READABLE,
+ sock.recv, nbytes, flags)
+
+ def send(self, sock, buf, flags=0):
+ try:
+ return sock.send(buf, flags)
+ except BlockingIOError:
+ raise self._register(sock.fileno(), WRITABLE,
+ sock.send, buf, flags)
+
+ def accept(self, sock):
+ def _accept():
+ conn, addr = sock.accept()
+ conn.settimeout(0)
+ return conn, addr
+ try:
+ return _accept()
+ except BlockingIOError:
+ raise self._register(sock.fileno(), READABLE, _accept)
+
+ def connect(self, sock, addr):
+ assert sock.gettimeout() == 0
+ err = sock.connect_ex(addr)
+ if err not in self._connection_errors:
+ raise OSError(err, os.strerror(err))
+ def _connect():
+ err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ if err != 0:
+ raise OSError(err, os.strerror(err))
+ raise self._register(sock.fileno(), WRITABLE, _connect)
+
+ # hacks to support SSL
+ def _readable(self, sock):
+ return self._register(sock.fileno(), READABLE, lambda:None)
+
+ def _writable(self, sock):
+ return self._register(sock.fileno(), WRITABLE, lambda:None)
+
+#
+# Proactor using select()
+#
+
+class SelectProactor(ReadyBaseProactor):
+ _connection_errors = {0, errno.EINPROGRESS}
+ _select = select.select
+
+ def _poll(self, timeout=None):
+ rfds, wfds, xfds = self._select(self._queue[READABLE].keys(),
+ self._queue[WRITABLE].keys(),
+ (), timeout)
+ for fd in rfds:
+ self._handle(fd, READABLE)
+ for fd in wfds:
+ self._handle(fd, WRITABLE)
+
+ def _handle(self, fd, kind):
+ Q = self._queue[kind][fd]
+ f, callback, args = Q.pop(0)
+ try:
+ f.set_result(callback(*args))
+ except OSError as e:
+ f.set_exception(e)
+ self._results.append(f)
+ if not Q:
+ del self._queue[kind][fd]
+
+ def _register(self, fd, kind, callback, *args):
+ f = self._Future()
+ queue = self._queue[kind]
+ if fd not in queue:
+ queue[fd] = []
+ queue[fd].append((f, callback, args))
+ return f
+
+ if sys.platform == 'win32':
+ # Windows insists on being awkward...
+ _connection_errors = {0, errno.WSAEWOULDBLOCK}
+
+ def _select(self, rfds, wfds, _, timeout=None):
+ if not (rfds or wfds):
+ time.sleep(timeout)
+ return [], [], []
+ else:
+ rfds, wfds, xfds = select.select(rfds, wfds, wfds, timeout)
+ return rfds, wfds + xfds, []
+
+
+#
+# Proactor using poll()
+#
+
+if hasattr(select, 'poll'):
+ __all__.append('PollProactor')
+
+ from select import POLLIN, POLLPRI, POLLOUT, POLLHUP, POLLERR, POLLNVAL
+
+ FLAG = [POLLIN, POLLOUT]
+ READ_EXTRA_FLAGS = POLLIN | POLLHUP | POLLNVAL | POLLERR
+ WRITE_EXTRA_FLAGS = POLLOUT | POLLHUP | POLLNVAL | POLLERR
+
+ class PollProactor(ReadyBaseProactor):
+ _connection_errors = {0, errno.EINPROGRESS}
+ _make_poller = select.poll
+ _uses_msecs = True
+
+ def __init__(self):
+ super().__init__()
+ self._poller = self._make_poller()
+ self._flag = {}
+
+ def _poll(self, timeout=None):
+ if timeout is None:
+ timeout = -1
+ elif timeout < 0:
+ raise ValueError('negative timeout')
+ elif self._uses_msecs:
+ timeout = int(timeout*1000 + 0.5)
+ ready = self._poller.poll(timeout)
+ for fd, flags in ready:
+ if fd in self._queue[READABLE] and flags & READ_EXTRA_FLAGS:
+ self._handle(fd, READABLE)
+ if fd in self._queue[WRITABLE] and flags & WRITE_EXTRA_FLAGS:
+ self._handle(fd, WRITABLE)
+
+ def _handle(self, fd, kind):
+ Q = self._queue[kind][fd]
+ f, callback, args = Q.pop(0)
+ try:
+ f.set_result(callback(*args))
+ except OSError as e:
+ f.set_exception(e)
+ self._results.append(f)
+ if not Q:
+ del self._queue[kind][fd]
+ flag = self._flag[fd] = self._flag[fd] & ~FLAG[kind]
+ if flag == 0:
+ del self._flag[fd]
+ self._poller.unregister(fd)
+ else:
+ self._poller.modify(fd, flag)
+
+ def _register(self, fd, kind, callback, *args):
+ f = self._Future()
+ queue = self._queue[kind]
+ if fd not in queue:
+ queue[fd] = []
+ old_flag = self._flag.get(fd, 0)
+ flag = self._flag[fd] = old_flag | FLAG[kind]
+ if old_flag == 0:
+ self._poller.register(fd, flag)
+ else:
+ self._poller.modify(fd, flag)
+ queue[fd].append((f, callback, args))
+ return f
+
+#
+# Proactor using epoll()
+#
+
+if hasattr(select, 'epoll'):
+ assert (select.EPOLLIN, select.EPOLLOUT) == (POLLIN, POLLOUT)
+
+ __all__.append('EpollProactor')
+
+ class EpollProactor(PollProactor):
+ _make_poller = select.epoll
+ _uses_msecs = False
+
+
+#
+# Proactor using overlapped IO and a completion port
+#
+
+try:
+ from _overlapped import *
+except ImportError:
+ if sys.platform == 'win32':
+ warnings.warn('IOCP support not compiled')
+else:
+ __all__.append('IocpProactor')
+
+ from _winapi import CloseHandle
+ import weakref
+
+ class IocpProactor(BaseProactor):
+ def __init__(self, concurrency=0xffffffff):
+ super().__init__()
+ self._iocp = CreateIoCompletionPort(
+ INVALID_HANDLE_VALUE, NULL, 0, concurrency)
+ self._cache = {}
+ self._registered = weakref.WeakSet()
+
+ def pollable(self):
+ return bool(self._cache)
+
+ def recv(self, conn, nbytes, flags=0):
+ self._register_obj(conn)
+ ov = Overlapped(NULL)
+ ov.WSARecv(conn.fileno(), nbytes, flags)
+ if ov.pending:
+ raise self._register(ov, conn, ov.getresult)
+ return ov.getresult()
+
+ def send(self, conn, buf, flags=0):
+ self._register_obj(conn)
+ ov = Overlapped(NULL)
+ ov.WSASend(conn.fileno(), buf, flags)
+ if ov.pending:
+ raise self._register(ov, conn, ov.getresult)
+ return ov.getresult()
+
+ def accept(self, listener):
+ self._register_obj(listener)
+ conn = self._get_accept_socket()
+ ov = Overlapped(NULL)
+ ov.AcceptEx(listener.fileno(), conn.fileno())
+ def finish_accept():
+ addr = ov.getresult()
+ conn.setsockopt(socket.SOL_SOCKET,
+ SO_UPDATE_ACCEPT_CONTEXT, listener.fileno())
+ conn.settimeout(listener.gettimeout())
+ return conn, conn.getpeername()
+ if ov.pending:
+ raise self._register(ov, listener, finish_accept)
+ return ov.getresult()
+
+ def connect(self, conn, address):
+ self._register_obj(conn)
+ BindLocal(conn.fileno(), len(address))
+ ov = Overlapped(NULL)
+ ov.ConnectEx(conn.fileno(), address)
+ def finish_connect():
+ ov.getresult()
+ conn.setsockopt(socket.SOL_SOCKET,
+ SO_UPDATE_CONNECT_CONTEXT, 0)
+ return conn
+ if ov.pending:
+ raise self._register(ov, conn, finish_connect)
+ return ov.getresult()
+
+ def _readable(self, sock):
+ raise NotImplementedError('IocpProactor._readable()')
+
+ def _writable(self, sock):
+ raise NotImplementedError('IocpProactor._writable()')
+
+ def _register_obj(self, obj):
+ if obj not in self._registered:
+ self._registered.add(obj)
+ CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0)
+ SetFileCompletionNotificationModes(obj.fileno(),
+ FILE_SKIP_COMPLETION_PORT_ON_SUCCESS);
+
+ def _register(self, ov, obj, callback, discard=False):
+ # we prevent ov and obj from being garbage collected
+ f = None if discard else self._Future()
+ self._cache[ov.address] = (f, ov, obj, callback)
+ return f
+
+ def _get_accept_socket(self):
+ s = socket.socket()
+ s.settimeout(0)
+ return s
+
+ def _poll(self, timeout=None):
+ if timeout is None:
+ ms = INFINITE
+ elif timeout < 0:
+ raise ValueError("negative timeout")
+ else:
+ ms = int(timeout * 1000 + 0.5)
+ if ms >= INFINITE:
+ raise ValueError("timeout too big")
+ while True:
+ status = GetQueuedCompletionStatus(self._iocp, ms)
+ if status is None:
+ return
+ f, ov, obj, callback = self._cache.pop(status[3])
+ try:
+ value = callback()
+ except OSError as e:
+ if f is None:
+ sys.excepthook(*sys.exc_info())
+ continue
+ f.set_exception(e)
+ self._results.append(f)
+ else:
+ if f is None:
+ continue
+ f.set_result(value)
+ self._results.append(f)
+ ms = 0
+
+ def close(self, *, CloseHandle=CloseHandle):
+ if self._iocp is not None:
+ CloseHandle(self._iocp)
+ self._iocp = None
+
+ __del__ = close
+
+#
+# Select default proactor (IocpReactor does not support SSL)
+#
+
+for _ in ('EpollProactor', 'IocpProactor', 'PollProactor', 'SelectProactor'):
+ if _ in globals():
+ Proactor = globals()[_]
+ break
+del _
+
+# Proactor = SelectProactor