summaryrefslogtreecommitdiff
path: root/scheduling.py
diff options
context:
space:
mode:
Diffstat (limited to 'scheduling.py')
-rw-r--r--scheduling.py64
1 files changed, 33 insertions, 31 deletions
diff --git a/scheduling.py b/scheduling.py
index 9c876fb..82840b6 100644
--- a/scheduling.py
+++ b/scheduling.py
@@ -194,8 +194,9 @@ def run():
def sleep(secs):
"""COROUTINE: Sleep for some time (a float in seconds)."""
- context.current_task.block()
- context.eventloop.call_later(secs, self.unblock)
+ current_task = context.current_task
+ current_task.block()
+ context.eventloop.call_later(secs, current_task.unblock)
yield
@@ -221,44 +222,45 @@ def call_in_thread(func, *args, executor=None):
return future.result()
-def wait_any(tasks):
- """COROUTINE: Wait for the first of a set of tasks to complete."""
- assert tasks
- current_task = context.current_task
- assert all(task is not current_task for task in tasks)
- for task in tasks:
- if not task.alive:
- return task
- winner = None
- def wait_any_callback(task):
- nonlocal winner, current_task
- if not winner:
- winner = task
- current_task.unblock()
- # TODO: Avoid adding N callbacks.
- for task in tasks:
- task.add_done_callback(wait_any_callback)
- current_task.block()
- yield
- return winner
+def wait_for(count, tasks):
+ """COROUTINE: Wait for the first N of a set of tasks to complete.
+ May return more than N if more than N are immediately ready.
-def wait_all(tasks):
- """COROUTINE: Wait for all of a set of tasks to complete."""
+ NOTE: Tasks that were cancelled or raised are also considered ready.
+ """
assert tasks
+ tasks = set(tasks)
+ assert 1 <= count <= len(tasks)
current_task = context.current_task
assert all(task is not current_task for task in tasks)
todo = set()
- def wait_all_callback(task):
- nonlocal todo, current_task
+ done = set()
+ def wait_for_callback(task):
+ nonlocal todo, done, current_task, count
todo.remove(task)
- if not todo:
- current_task.unblock()
+ if len(done) < count:
+ done.add(task)
+ if len(done) == count:
+ current_task.unblock()
for task in tasks:
if task.alive:
todo.add(task)
- task.add_done_callback(wait_all_callback)
- if todo:
+ else:
+ done.add(task)
+ if len(done) < count:
+ for task in todo:
+ task.add_done_callback(wait_for_callback)
current_task.block()
yield
- return tasks # Not redundant: handy if called with a comprehension.
+ return done
+
+
+def wait_any(tasks):
+ """COROUTINE: Wait for the first of a set of tasks to complete."""
+ return wait_for(1, tasks)
+
+
+def wait_all(tasks):
+ """COROUTINE: Wait for all of a set of tasks to complete."""
+ return wait_for(len(tasks), tasks)