diff options
Diffstat (limited to 'futures/thread.py')
-rw-r--r-- | futures/thread.py | 346 |
1 files changed, 346 insertions, 0 deletions
diff --git a/futures/thread.py b/futures/thread.py new file mode 100644 index 0000000..a4c779d --- /dev/null +++ b/futures/thread.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python + +import queue +import threading + +FIRST_COMPLETED = 0 +FIRST_EXCEPTION = 1 +ALL_COMPLETED = 2 +RETURN_IMMEDIATELY = 3 + +_PENDING = 0 +_RUNNING = 1 +_CANCELLED = 2 +_FINISHED = 3 + +_STATE_TO_DESCRIPTION_MAP = { + _PENDING: "pending", + _RUNNING: "running", + _CANCELLED: "cancelled", + _FINISHED: "finished" +} + +class CancelledException(Exception): + pass + +class TimeoutException(Exception): + pass + +class Future(object): + def __init__(self): + self._condition = threading.Condition() + self._state = _PENDING + self._result = None + self._exception = None + + def __repr__(self): + with self._condition: + if self._state == _FINISHED: + if self._exception: + return '<Future state=%s raised %s>' % ( + _STATE_TO_DESCRIPTION_MAP[self._state], + self._exception.__class__.__name__) + else: + return '<Future state=%s returned %s>' % ( + _STATE_TO_DESCRIPTION_MAP[self._state], + self._result.__class__.__name__) + return '<Future state=%s>' % _STATE_TO_DESCRIPTION_MAP[self._state] + + def cancel(self): + with self._condition: + if self._state in [_RUNNING, _FINISHED]: + return False + + self._state = _CANCELLED + return True + + def cancelled(self): + with self._condition: + return self._state == _CANCELLED + + def done(self): + with self._condition: + return self._state in [_CANCELLED, _FINISHED] + + def __get_result(self): + if self._exception: + raise self._exception + else: + return self._result + + def result(self, timeout=None): + with self._condition: + if self._state == _CANCELLED: + raise CancelledException() + elif self._state == _FINISHED: + return self.__get_result() + + print('Waiting...') + self._condition.wait(timeout) + print('Post Waiting...') + + if self._state == _CANCELLED: + raise CancelledException() + elif self._state == _FINISHED: + return self.__get_result() + else: + raise TimeoutException() + + def exception(self, timeout=None): + with self._condition: + if self._state == _CANCELLED: + raise CancelledException() + elif self._state == _FINISHED: + return self._exception + + self._condition.wait(timeout) + + if self._state == _CANCELLED: + raise CancelledException() + elif self._state == _FINISHED: + return self._exception + else: + raise TimeoutException() + +class _NullWaitTracker(object): + def add_result(self): + pass + + def add_exception(self): + pass + + def add_cancelled(self): + pass + +class _FirstCompletedWaitTracker(object): + def __init__(self): + self.event = threading.Event() + + def add_result(self): + self.event.set() + + def add_exception(self): + self.event.set() + + def add_cancelled(self): + self.event.set() + +class _AllCompletedWaitTracker(object): + def __init__(self, pending_calls, stop_on_exception): + self.event = threading.Event() + self.pending_calls = pending_calls + self.stop_on_exception = stop_on_exception + + def add_result(self): + self.pending_calls -= 1 + if not self.pending_calls: + self.event.set() + + def add_exception(self): + self.add_result() + if self.stop_on_exception: + self.event.set() + + def add_cancelled(self): + self.add_result() + +class _ThreadEventSink(object): + def __init__(self): + self._condition = threading.Lock() + self._waiters = [] + + def add(self, e): + self._waiters.append(e) + + def add_result(self): + with self._condition: + for waiter in self._waiters: + waiter.add_result() + + def add_exception(self): + with self._condition: + for waiter in self._waiters: + waiter.add_exception() + + def add_cancelled(self): + with self._condition: + for waiter in self._waiters: + waiter.add_cancelled() + +class FutureList(object): + def __init__(self, futures, event_sink): + self._futures = futures + self._event_sink = event_sink + + def wait(self, timeout=None, run_until=ALL_COMPLETED): + with self._event_sink._condition: + print('WAIT 123') + if all(f.done() for f in self): + return + print('WAIT 1234') + + if run_until == FIRST_COMPLETED: + m = _FirstCompletedWaitTracker() + elif run_until == FIRST_EXCEPTION: + m = _AllCompletedWaitTracker(len(self), stop_on_exception=True) + elif run_until == ALL_COMPLETED: + m = _AllCompletedWaitTracker(len(self), stop_on_exception=False) + elif run_until == RETURN_IMMEDIATELY: + m = _NullWaitTracker() + else: + raise ValueError() + + self._event_sink.add(m) + + if run_until != RETURN_IMMEDIATELY: + print('WAIT 12345', timeout) + m.event.wait(timeout) + + def cancel(self, timeout=None): + for f in self: + f.cancel() + self.wait(timeout=timeout, run_until=ALL_COMPLETED) + if any(not f.done() for f in self): + raise TimeoutException() + + def has_running_futures(self): + return bool(self.running_futures()) + + def has_cancelled_futures(self): + return bool(self.cancelled_futures()) + + def has_done_futures(self): + return bool(self.done_futures()) + + def has_successful_futures(self): + return bool(self.successful_futures()) + + def has_exception_futures(self): + return bool(self.exception_futures()) + + def running_futures(self): + return [f for f in self if not f.done() and not f.cancelled()] + + def cancelled_futures(self): + return [f for f in self if f.cancelled()] + + def done_futures(self): + return [f for f in self if f.done()] + + def successful_futures(self): + return [f for f in self + if f.done() and not f.cancelled() and f.exception() is None] + + def exception_futures(self): + return [f for f in self if f.done() and f.exception() is not None] + + def __getitem__(self, i): + return self._futures[i] + + def __len__(self): + return len(self._futures) + + def __iter__(self): + return iter(self._futures) + + def __contains__(self, f): + return f in self._futures + + def __repr__(self): + return ('<FutureList #futures=%d ' + '[#success=%d #exception=%d #cancelled=%d]>' % ( + len(self), + len(self.successful_futures()), + len(self.exception_futures()), + len(self.cancelled_futures()))) + +class _WorkItem(object): + def __init__(self, call, future, completion_tracker): + self.call = call + self.future = future + self.completion_tracker = completion_tracker + + def run(self): + if self.future.cancelled(): + with self.future._condition: + self.future._condition.notify_all() + self.completion_tracker.add_cancelled() + return + + self.future._state = _RUNNING + try: + r = self.call() + except BaseException as e: + with self.future._condition: + self.future._exception = e + self.future._state = _FINISHED + self.future._condition.notify_all() + self.completion_tracker.add_exception() + else: + with self.future._condition: + self.future._result = r + self.future._state = _FINISHED + self.future._condition.notify_all() + self.completion_tracker.add_result() + +class ThreadPoolExecutor(object): + def __init__(self, max_threads): + self._max_threads = max_threads + self._work_queue = queue.Queue() + self._threads = set() + self._shutdown = False + self._lock = threading.Lock() + + def _worker(self): + try: + while True: + try: + work_item = self._work_queue.get(block=True, + timeout=0.1) + except queue.Empty: + if self._shutdown: + return + else: + work_item.run() + except BaseException as e: + print('Out e:', e) + + def _adjust_thread_count(self): + for _ in range(len(self._threads), + min(self._max_threads, self._work_queue.qsize())): + print('Creating a thread') + t = threading.Thread(target=self._worker) + t.daemon = True + t.start() + self._threads.add(t) + + def run(self, calls, timeout=None, run_until=ALL_COMPLETED): + with self._lock: + if self._shutdown: + raise RuntimeError() + + futures = [] + event_sink = _ThreadEventSink() + for call in calls: + f = Future() + w = _WorkItem(call, f, event_sink) + self._work_queue.put(w) + futures.append(f) + + print('futures:', futures) + self._adjust_thread_count() + fl = FutureList(futures, event_sink) + fl.wait(timeout=timeout, run_until=run_until) + return fl + + def runXXX(self, calls, timeout=None): + fs = self.run(calls, timeout, run_util=FIRST_EXCEPTION) + + if fs.has_exception_futures(): + raise fs.exception_futures()[0].exception() + else: + return [f.result() for f in fs] + + def shutdown(self): + with self._lock: + self._shutdown = True |