From 245aacf40c19560264e3bc311d2af2105b232765 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Oct 2012 11:22:47 -0700 Subject: Radical refactoring: no more Scheduler class; use thread-local Context. --- main.py | 65 +++++++++--------- polling.py | 20 ++++-- scheduling.py | 208 +++++++++++++++++++++++++++++++++------------------------- sockets.py | 24 ++++--- 4 files changed, 178 insertions(+), 139 deletions(-) diff --git a/main.py b/main.py index f4e64d0..829e1f4 100644 --- a/main.py +++ b/main.py @@ -30,29 +30,10 @@ import time import socket import sys -# Initialize logging before we import polling. -# TODO: Change polling.py so we can do this in main(). -if '-d' in sys.argv: - level = logging.DEBUG -elif '-v' in sys.argv: - level = logging.INFO -elif '-q' in sys.argv: - level = logging.ERROR -else: - level = logging.WARN -logging.basicConfig(level=level) - # Local imports (keep in alphabetic order). -import polling import scheduling import sockets -eventloop = polling.EventLoop() -threadrunner = polling.ThreadRunner(eventloop) -scheduler = scheduling.Scheduler(eventloop, threadrunner) - -sockets.scheduler = scheduler # TODO: Find a better way. - def urlfetch(host, port=80, method='GET', path='/', body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): @@ -114,23 +95,27 @@ def urlfetch(host, port=80, method='GET', path='/', def doit(): + TIMEOUT = 2 + tasks = set() + # This references NDB's default test service. # (Sadly the service is single-threaded.) - task1 = scheduler.newtask(urlfetch('localhost', 8080, path='/'), - 'root', timeout=2) - task2 = scheduler.newtask(urlfetch('localhost', 8080, path='/home'), - 'home', timeout=2) + task1 = scheduling.Task(urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(urlfetch('localhost', 8080, path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) # Fetch python.org home page. - task3 = scheduler.newtask(urlfetch('python.org', 80, path='/'), - 'python', timeout=2) - - tasks = {task1, task2, task3} + task3 = scheduling.Task(urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) # Fetch XKCD home page using SSL. (Doesn't like IPv6.) - task4 = scheduler.newtask(urlfetch('xkcd.com', 443, path='/', - af=socket.AF_INET), - 'xkcd', timeout=2) + task4 = scheduling.Task(urlfetch('xkcd.com', 443, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) tasks.add(task4) ## # Fetch many links from python.org (/x.y.z). @@ -139,14 +124,14 @@ def doit(): ## path = '/{}.{}'.format(x, y) ## g = urlfetch('82.94.164.162', 80, ## path=path, hdrs={'host': 'python.org'}) -## t = scheduler.newtask(g, path, timeout=2) +## t = scheduling.Task(g, path, timeout=2) ## tasks.add(t) -## print(tasks) +## print(tasks) for t in tasks: t.start() - scheduler.run() -## print(tasks) + scheduling.run() +## print(tasks) for t in tasks: print(t.name + ':', t.exception or t.result) @@ -160,6 +145,18 @@ def logtimes(real): def main(): t0 = time.time() + + # Initialize logging. 
+ if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + doit() t1 = time.time() logtimes(t1-t0) diff --git a/polling.py b/polling.py index c822034..96a3f11 100644 --- a/polling.py +++ b/polling.py @@ -421,12 +421,17 @@ elif hasattr(select, 'poll'): else: poll_base = SelectMixin -logging.info('Using Pollster base class %r', poll_base.__name__) - class EventLoop(EventLoopMixin, poll_base): """Event loop implementation using the optimal pollster mixin.""" + def __init__(self): + super().__init__() + logging.info('Using Pollster base class %r', poll_base.__name__) + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + class ThreadRunner: """Helper to submit work to a thread pool and wait for it. @@ -438,9 +443,9 @@ class ThreadRunner: The only public API is submit(). """ - def __init__(self, eventloop, max_workers=5): + def __init__(self, eventloop, executor=None): self.eventloop = eventloop - self.threadpool = concurrent.futures.ThreadPoolExecutor(max_workers) + self.executor = executor # Will be constructed lazily. self.pipe_read_fd, self.pipe_write_fd = os.pipe() self.active_count = 0 @@ -460,7 +465,12 @@ class ThreadRunner: should not wait for that, but rather add a callback to it. """ if executor is None: - executor = self.threadpool + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor assert self.active_count >= 0, self.active_count future = executor.submit(func, *args) if self.active_count == 0: diff --git a/scheduling.py b/scheduling.py index d4c3cb6..27339d0 100644 --- a/scheduling.py +++ b/scheduling.py @@ -14,7 +14,8 @@ PATTERNS TO TRY: - Wait for first N that are ready. - Wait until some predicate becomes true. - Run with timeout. - +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). """ __author__ = 'Guido van Rossum ' @@ -22,11 +23,47 @@ __author__ = 'Guido van Rossum ' # Standard library imports (keep in alphabetic order). from concurrent.futures import TimeoutError import logging +import threading import time +import polling + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = polling.EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = polling.ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! + class Task: - """Lightweight wrapper around a generator. + """Wrapper around a stack of generators. This is a bit like a Future, but with a different interface. @@ -35,8 +72,7 @@ class Task: - wait for result. 
""" - def __init__(self, sched, gen, name=None, *, timeout=None): - self.sched = sched + def __init__(self, gen, name=None, *, timeout=None): self.gen = gen self.name = name or gen.__name__ if timeout is not None and timeout < 1000000: @@ -53,7 +89,7 @@ class Task: def run(self): if not self.alive: return - self.sched.current = self + context.current_task = self try: if self.timeout is not None and self.timeout < time.time(): self.gen.throw(TimeoutError) @@ -71,99 +107,91 @@ class Task: self.exception = exc raise else: - if self.sched.current is not None: + if context.current_task is not None: self.start() finally: - self.sched.current = None + context.current_task = None def start(self): if self.alive: - self.sched.eventloop.call_soon(self.run) + context.eventloop.call_soon(self.run) -class Scheduler: +def run(): + context.eventloop.run() - def __init__(self, eventloop, threadrunner): - self.eventloop = eventloop # polling.EventLoop instance. - self.threadrunner = threadrunner # polling.Threadrunner instance. - self.current = None # Current Task. - def run(self): - self.eventloop.run() - - def newtask(self, gen, name=None, *, timeout=None): - return Task(self, gen, name, timeout=timeout) - - def start(self, gen, name=None, *, timeout=None): - task = self.newtask(gen, name, timeout=timeout) - task.start() - return task - - def block_r(self, fd): - self.block_io(fd, 'r') - - def block_w(self, fd): - self.block_io(fd, 'w') - - def block_io(self, fd, flag): - assert isinstance(fd, int), repr(fd) - assert flag in ('r', 'w'), repr(flag) - task = self.block() - dcall = None - if task.timeout: - dcall = self.eventloop.call_later(task.timeout, - self.unblock_timeout, - fd, flag, task) - if flag == 'r': - self.eventloop.add_reader(fd, self.unblock_io, - fd, flag, task, dcall) - else: - self.eventloop.add_writer(fd, self.unblock_io, - fd, flag, task, dcall) - - def block(self): - assert self.current - task = self.current - self.current = None - return task - - def unblock_io(self, fd, flag, task, dcall): - if dcall is not None: - dcall.cancel() - if flag == 'r': - self.eventloop.remove_reader(fd) - else: - self.eventloop.remove_writer(fd) - task.start() - - def unblock_timeout(self, fd, flag, task): - # NOTE: Due to the call_soon() semantics, we can't guarantee - # that unblock_timeout() isn't called *after* unblock_io() has - # already been called. So we must write this defensively. - # TODO: Analyse this further for race conditions etc. - if flag == 'r': - if fd in self.eventloop.readers: - self.eventloop.remove_reader(fd) - else: - if fd in self.eventloop.writers: - self.eventloop.remove_writer(fd) - task.timeout = 0 # Force it to cancel. - task.start() - - def call_in_thread(self, func, *args, executor=None): - # TODO: Prove there is no race condition here. 
- task = self.block() - future = self.threadrunner.submit(func, *args, executor=executor) - future.add_done_callback(lambda _: task.start()) - try: - yield - except TimeoutError: - future.cancel() - raise - assert future.done() - return future.result() +def sleep(secs): + task = block() + context.eventloop.call_later(secs, task.start) + yield + + +def block_r(fd): + block_io(fd, 'r') + + +def block_w(fd): + block_io(fd, 'w') + + +def block_io(fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + task = block() + dcall = None + if task.timeout: + dcall = context.eventloop.call_later(task.timeout, unblock_timeout, + fd, flag, task) + if flag == 'r': + context.eventloop.add_reader(fd, unblock_io, fd, flag, task, dcall) + else: + context.eventloop.add_writer(fd, unblock_io, fd, flag, task, dcall) + + +def block(): + assert context.current_task + task = context.current_task + context.current_task = None + return task + + +def unblock_io(fd, flag, task, dcall): + if dcall is not None: + dcall.cancel() + if flag == 'r': + context.eventloop.remove_reader(fd) + else: + context.eventloop.remove_writer(fd) + task.start() + + +def unblock_timeout(fd, flag, task): + # NOTE: Due to the call_soon() semantics, we can't guarantee + # that unblock_timeout() isn't called *after* unblock_io() has + # already been called. So we must write this defensively. + # TODO: Analyse this further for race conditions etc. + if flag == 'r': + if fd in context.eventloop.readers: + context.eventloop.remove_reader(fd) + else: + if fd in context.eventloop.writers: + context.eventloop.remove_writer(fd) + task.timeout = 0 # Force it to cancel. + task.start() + - def sleep(self, secs): - task = self.block() - self.eventloop.call_later(secs, task.start) +def call_in_thread(func, *args, executor=None): + # TODO: Prove there is no race condition here. + task = block() + future = context.threadrunner.submit(func, *args, executor=executor) + # Don't reference context in the lambda! It is called in another thread. + this_eventloop = context.eventloop + future.add_done_callback(lambda _: this_eventloop.call_soon(task.run)) + try: yield + except TimeoutError: + future.cancel() + raise + assert future.done() + return future.result() diff --git a/sockets.py b/sockets.py index 85b7005..9f569e4 100644 --- a/sockets.py +++ b/sockets.py @@ -23,11 +23,15 @@ TODO: __author__ = 'Guido van Rossum ' +# Stdlib imports. import errno import re import socket import ssl +# Local imports. +import scheduling + class SocketTransport: """Transport wrapping a socket. 
@@ -41,14 +45,14 @@ class SocketTransport: def recv(self, n): """COROUTINE: Read up to n bytes, blocking at most once.""" assert n >= 0, n - scheduler.block_r(self.sock.fileno()) + scheduling.block_r(self.sock.fileno()) yield return self.sock.recv(n) def send(self, data): """COROUTINE; Send data to the socket, blocking until all written.""" while data: - scheduler.block_w(self.sock.fileno()) + scheduling.block_w(self.sock.fileno()) yield n = self.sock.send(data) assert 0 <= n <= len(data), (n, len(data)) @@ -80,10 +84,10 @@ class SslTransport: try: self.sslsock.do_handshake() except ssl.SSLWantReadError: - scheduler.block_r(self.sslsock.fileno()) + scheduling.block_r(self.sslsock.fileno()) yield except ssl.SSLWantWriteError: - scheduler.block_w(self.sslsock.fileno()) + scheduling.block_w(self.sslsock.fileno()) yield else: break @@ -97,7 +101,7 @@ class SslTransport: try: return self.sslsock.recv(n) except socket.error as err: - scheduler.block_r(self.sslsock.fileno()) + scheduling.block_r(self.sslsock.fileno()) yield def send(self, data): @@ -106,7 +110,7 @@ class SslTransport: try: n = self.sslsock.send(data) except socket.error as err: - scheduler.block_w(self.sslsock.fileno()) + scheduling.block_w(self.sslsock.fileno()) yield if n == len(data): break @@ -186,7 +190,7 @@ def connect(sock, address): except socket.error as err: if err.errno != errno.EINPROGRESS: raise - scheduler.block_w(sock.fileno()) + scheduling.block_w(sock.fileno()) yield err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: @@ -198,9 +202,9 @@ def getaddrinfo(host, port, af=0, socktype=0, proto=0): Each info is a tuple (af, socktype, protocol, canonname, address). """ - infos = yield from scheduler.call_in_thread(socket.getaddrinfo, - host, port, af, - socktype, proto) + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) return infos -- cgit v1.2.1
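
A minimal usage sketch of the module-level API this patch introduces (scheduling.Task, scheduling.run() and scheduling.sleep(), all driven by the thread-local context instead of a Scheduler instance). It mirrors the pattern doit() uses in main.py; the ticker coroutine, demo() and their arguments are illustrative names, not part of the patch, and the sketch assumes scheduling.py and polling.py from this tree are importable.

    import time

    import scheduling


    def ticker(name, naps):
        """COROUTINE: Sleep a few times, then report how long it took."""
        t0 = time.time()
        for _ in range(naps):
            # scheduling.sleep() blocks the current task on a call_later()
            # timer taken from the thread-local context's event loop.
            yield from scheduling.sleep(0.1)
        return '{} woke {} times in {:.2f} sec'.format(name, naps,
                                                       time.time() - t0)


    def demo():
        # No Scheduler object any more: Task and run() find the event loop
        # through scheduling.context (thread-local, lazily constructed).
        tasks = {scheduling.Task(ticker('a', 3), 'a', timeout=2),
                 scheduling.Task(ticker('b', 5), 'b', timeout=2)}
        for t in tasks:
            t.start()
        scheduling.run()
        for t in tasks:
            print(t.name + ':', t.exception or t.result)


    if __name__ == '__main__':
        demo()

The same pattern is what sockets.py now relies on: rather than an injected scheduler attribute, transports call scheduling.block_r()/block_w() and call_in_thread() directly, and each thread gets its own event loop and thread runner the first time the context is touched.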