summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
authorItamar Ostricher <itamarost@gmail.com>2023-05-01 14:10:13 -0700
committerGitHub <noreply@github.com>2023-05-01 15:10:13 -0600
commita474e04388c2ef6aca75c26cb70a1b6200235feb (patch)
tree43520d5ad16016620f149dc1e84d4d57e45051d5 /Lib
parent59bc36aacddd5a3acd32c80c0dfd0726135a7817 (diff)
downloadcpython-git-a474e04388c2ef6aca75c26cb70a1b6200235feb.tar.gz
gh-97696: asyncio eager tasks factory (#102853)
Co-authored-by: Jacob Bower <jbower@meta.com> Co-authored-by: Carol Willing <carolcode@willingconsulting.com>
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/base_tasks.py10
-rw-r--r--Lib/asyncio/tasks.py122
-rw-r--r--Lib/test/test_asyncio/test_eager_task_factory.py344
3 files changed, 450 insertions, 26 deletions
diff --git a/Lib/asyncio/base_tasks.py b/Lib/asyncio/base_tasks.py
index 26298e638c..c907b68341 100644
--- a/Lib/asyncio/base_tasks.py
+++ b/Lib/asyncio/base_tasks.py
@@ -15,11 +15,13 @@ def _task_repr_info(task):
info.insert(1, 'name=%r' % task.get_name())
- coro = coroutines._format_coroutine(task._coro)
- info.insert(2, f'coro=<{coro}>')
-
if task._fut_waiter is not None:
- info.insert(3, f'wait_for={task._fut_waiter!r}')
+ info.insert(2, f'wait_for={task._fut_waiter!r}')
+
+ if task._coro:
+ coro = coroutines._format_coroutine(task._coro)
+ info.insert(2, f'coro=<{coro}>')
+
return info
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index c90d32c97a..aa5269ade1 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -6,6 +6,7 @@ __all__ = (
'wait', 'wait_for', 'as_completed', 'sleep',
'gather', 'shield', 'ensure_future', 'run_coroutine_threadsafe',
'current_task', 'all_tasks',
+ 'create_eager_task_factory', 'eager_task_factory',
'_register_task', '_unregister_task', '_enter_task', '_leave_task',
)
@@ -43,22 +44,26 @@ def all_tasks(loop=None):
"""Return a set of all tasks for the loop."""
if loop is None:
loop = events.get_running_loop()
- # Looping over a WeakSet (_all_tasks) isn't safe as it can be updated from another
- # thread while we do so. Therefore we cast it to list prior to filtering. The list
- # cast itself requires iteration, so we repeat it several times ignoring
- # RuntimeErrors (which are not very likely to occur). See issues 34970 and 36607 for
- # details.
+ # capturing the set of eager tasks first, so if an eager task "graduates"
+ # to a regular task in another thread, we don't risk missing it.
+ eager_tasks = list(_eager_tasks)
+ # Looping over the WeakSet isn't safe as it can be updated from another
+ # thread, therefore we cast it to list prior to filtering. The list cast
+ # itself requires iteration, so we repeat it several times ignoring
+ # RuntimeErrors (which are not very likely to occur).
+ # See issues 34970 and 36607 for details.
+ scheduled_tasks = None
i = 0
while True:
try:
- tasks = list(_all_tasks)
+ scheduled_tasks = list(_scheduled_tasks)
except RuntimeError:
i += 1
if i >= 1000:
raise
else:
break
- return {t for t in tasks
+ return {t for t in itertools.chain(scheduled_tasks, eager_tasks)
if futures._get_loop(t) is loop and not t.done()}
@@ -93,7 +98,8 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
# status is still pending
_log_destroy_pending = True
- def __init__(self, coro, *, loop=None, name=None, context=None):
+ def __init__(self, coro, *, loop=None, name=None, context=None,
+ eager_start=False):
super().__init__(loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
@@ -117,8 +123,11 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
else:
self._context = context
- self._loop.call_soon(self.__step, context=self._context)
- _register_task(self)
+ if eager_start and self._loop.is_running():
+ self.__eager_start()
+ else:
+ self._loop.call_soon(self.__step, context=self._context)
+ _register_task(self)
def __del__(self):
if self._state == futures._PENDING and self._log_destroy_pending:
@@ -250,6 +259,25 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._num_cancels_requested -= 1
return self._num_cancels_requested
+ def __eager_start(self):
+ prev_task = _swap_current_task(self._loop, self)
+ try:
+ _register_eager_task(self)
+ try:
+ self._context.run(self.__step_run_and_handle_result, None)
+ finally:
+ _unregister_eager_task(self)
+ finally:
+ try:
+ curtask = _swap_current_task(self._loop, prev_task)
+ assert curtask is self
+ finally:
+ if self.done():
+ self._coro = None
+ self = None # Needed to break cycles when an exception occurs.
+ else:
+ _register_task(self)
+
def __step(self, exc=None):
if self.done():
raise exceptions.InvalidStateError(
@@ -258,11 +286,17 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
if not isinstance(exc, exceptions.CancelledError):
exc = self._make_cancelled_error()
self._must_cancel = False
- coro = self._coro
self._fut_waiter = None
_enter_task(self._loop, self)
- # Call either coro.throw(exc) or coro.send(None).
+ try:
+ self.__step_run_and_handle_result(exc)
+ finally:
+ _leave_task(self._loop, self)
+ self = None # Needed to break cycles when an exception occurs.
+
+ def __step_run_and_handle_result(self, exc):
+ coro = self._coro
try:
if exc is None:
# We use the `send` method directly, because coroutines
@@ -334,7 +368,6 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._loop.call_soon(
self.__step, new_exc, context=self._context)
finally:
- _leave_task(self._loop, self)
self = None # Needed to break cycles when an exception occurs.
def __wakeup(self, future):
@@ -897,8 +930,27 @@ def run_coroutine_threadsafe(coro, loop):
return future
-# WeakSet containing all alive tasks.
-_all_tasks = weakref.WeakSet()
+def create_eager_task_factory(custom_task_constructor):
+
+ if "eager_start" not in inspect.signature(custom_task_constructor).parameters:
+ raise TypeError(
+ "Provided constructor does not support eager task execution")
+
+ def factory(loop, coro, *, name=None, context=None):
+ return custom_task_constructor(
+ coro, loop=loop, name=name, context=context, eager_start=True)
+
+
+ return factory
+
+eager_task_factory = create_eager_task_factory(Task)
+
+
+# Collectively these two sets hold references to the complete set of active
+# tasks. Eagerly executed tasks use a faster regular set as an optimization
+# but may graduate to a WeakSet if the task blocks on IO.
+_scheduled_tasks = weakref.WeakSet()
+_eager_tasks = set()
# Dictionary containing tasks that are currently active in
# all running event loops. {EventLoop: Task}
@@ -906,8 +958,13 @@ _current_tasks = {}
def _register_task(task):
- """Register a new task in asyncio as executed by loop."""
- _all_tasks.add(task)
+ """Register an asyncio Task scheduled to run on an event loop."""
+ _scheduled_tasks.add(task)
+
+
+def _register_eager_task(task):
+ """Register an asyncio Task about to be eagerly executed."""
+ _eager_tasks.add(task)
def _enter_task(loop, task):
@@ -926,28 +983,49 @@ def _leave_task(loop, task):
del _current_tasks[loop]
+def _swap_current_task(loop, task):
+ prev_task = _current_tasks.get(loop)
+ if task is None:
+ del _current_tasks[loop]
+ else:
+ _current_tasks[loop] = task
+ return prev_task
+
+
def _unregister_task(task):
- """Unregister a task."""
- _all_tasks.discard(task)
+ """Unregister a completed, scheduled Task."""
+ _scheduled_tasks.discard(task)
+
+
+def _unregister_eager_task(task):
+ """Unregister a task which finished its first eager step."""
+ _eager_tasks.discard(task)
_py_current_task = current_task
_py_register_task = _register_task
+_py_register_eager_task = _register_eager_task
_py_unregister_task = _unregister_task
+_py_unregister_eager_task = _unregister_eager_task
_py_enter_task = _enter_task
_py_leave_task = _leave_task
+_py_swap_current_task = _swap_current_task
try:
- from _asyncio import (_register_task, _unregister_task,
- _enter_task, _leave_task,
- _all_tasks, _current_tasks,
+ from _asyncio import (_register_task, _register_eager_task,
+ _unregister_task, _unregister_eager_task,
+ _enter_task, _leave_task, _swap_current_task,
+ _scheduled_tasks, _eager_tasks, _current_tasks,
current_task)
except ImportError:
pass
else:
_c_current_task = current_task
_c_register_task = _register_task
+ _c_register_eager_task = _register_eager_task
_c_unregister_task = _unregister_task
+ _c_unregister_eager_task = _unregister_eager_task
_c_enter_task = _enter_task
_c_leave_task = _leave_task
+ _c_swap_current_task = _swap_current_task
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
new file mode 100644
index 0000000000..fe69093429
--- /dev/null
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -0,0 +1,344 @@
+"""Tests for base_events.py"""
+
+import asyncio
+import contextvars
+import gc
+import time
+import unittest
+
+from types import GenericAlias
+from unittest import mock
+from asyncio import base_events
+from asyncio import tasks
+from test.test_asyncio import utils as test_utils
+from test.test_asyncio.test_tasks import get_innermost_context
+from test import support
+
+MOCK_ANY = mock.ANY
+
+
+def tearDownModule():
+ asyncio.set_event_loop_policy(None)
+
+
+class EagerTaskFactoryLoopTests:
+
+ Task = None
+
+ def run_coro(self, coro):
+ """
+ Helper method to run the `coro` coroutine in the test event loop.
+ It helps with making sure the event loop is running before starting
+ to execute `coro`. This is important for testing the eager step
+ functionality, since an eager step is taken only if the event loop
+ is already running.
+ """
+
+ async def coro_runner():
+ self.assertTrue(asyncio.get_event_loop().is_running())
+ return await coro
+
+ return self.loop.run_until_complete(coro)
+
+ def setUp(self):
+ super().setUp()
+ self.loop = asyncio.new_event_loop()
+ self.eager_task_factory = asyncio.create_eager_task_factory(self.Task)
+ self.loop.set_task_factory(self.eager_task_factory)
+ self.set_event_loop(self.loop)
+
+ def test_eager_task_factory_set(self):
+ self.assertIsNotNone(self.eager_task_factory)
+ self.assertIs(self.loop.get_task_factory(), self.eager_task_factory)
+
+ async def noop(): pass
+
+ async def run():
+ t = self.loop.create_task(noop())
+ self.assertIsInstance(t, self.Task)
+ await t
+
+ self.run_coro(run())
+
+ def test_await_future_during_eager_step(self):
+
+ async def set_result(fut, val):
+ fut.set_result(val)
+
+ async def run():
+ fut = self.loop.create_future()
+ t = self.loop.create_task(set_result(fut, 'my message'))
+ # assert the eager step completed the task
+ self.assertTrue(t.done())
+ return await fut
+
+ self.assertEqual(self.run_coro(run()), 'my message')
+
+ def test_eager_completion(self):
+
+ async def coro():
+ return 'hello'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ # assert the eager step completed the task
+ self.assertTrue(t.done())
+ return await t
+
+ self.assertEqual(self.run_coro(run()), 'hello')
+
+ def test_block_after_eager_step(self):
+
+ async def coro():
+ await asyncio.sleep(0.1)
+ return 'finished after blocking'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ self.assertFalse(t.done())
+ result = await t
+ self.assertTrue(t.done())
+ return result
+
+ self.assertEqual(self.run_coro(run()), 'finished after blocking')
+
+ def test_cancellation_after_eager_completion(self):
+
+ async def coro():
+ return 'finished without blocking'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ t.cancel()
+ result = await t
+ # finished task can't be cancelled
+ self.assertFalse(t.cancelled())
+ return result
+
+ self.assertEqual(self.run_coro(run()), 'finished without blocking')
+
+ def test_cancellation_after_eager_step_blocks(self):
+
+ async def coro():
+ await asyncio.sleep(0.1)
+ return 'finished after blocking'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ t.cancel('cancellation message')
+ self.assertGreater(t.cancelling(), 0)
+ result = await t
+
+ with self.assertRaises(asyncio.CancelledError) as cm:
+ self.run_coro(run())
+
+ self.assertEqual('cancellation message', cm.exception.args[0])
+
+ def test_current_task(self):
+ captured_current_task = None
+
+ async def coro():
+ nonlocal captured_current_task
+ captured_current_task = asyncio.current_task()
+ # verify the task before and after blocking is identical
+ await asyncio.sleep(0.1)
+ self.assertIs(asyncio.current_task(), captured_current_task)
+
+ async def run():
+ t = self.loop.create_task(coro())
+ self.assertIs(captured_current_task, t)
+ await t
+
+ self.run_coro(run())
+ captured_current_task = None
+
+ def test_all_tasks_with_eager_completion(self):
+ captured_all_tasks = None
+
+ async def coro():
+ nonlocal captured_all_tasks
+ captured_all_tasks = asyncio.all_tasks()
+
+ async def run():
+ t = self.loop.create_task(coro())
+ self.assertIn(t, captured_all_tasks)
+ self.assertNotIn(t, asyncio.all_tasks())
+
+ self.run_coro(run())
+
+ def test_all_tasks_with_blocking(self):
+ captured_eager_all_tasks = None
+
+ async def coro(fut1, fut2):
+ nonlocal captured_eager_all_tasks
+ captured_eager_all_tasks = asyncio.all_tasks()
+ await fut1
+ fut2.set_result(None)
+
+ async def run():
+ fut1 = self.loop.create_future()
+ fut2 = self.loop.create_future()
+ t = self.loop.create_task(coro(fut1, fut2))
+ self.assertIn(t, captured_eager_all_tasks)
+ self.assertIn(t, asyncio.all_tasks())
+ fut1.set_result(None)
+ await fut2
+ self.assertNotIn(t, asyncio.all_tasks())
+
+ self.run_coro(run())
+
+ def test_context_vars(self):
+ cv = contextvars.ContextVar('cv', default=0)
+
+ coro_first_step_ran = False
+ coro_second_step_ran = False
+
+ async def coro():
+ nonlocal coro_first_step_ran
+ nonlocal coro_second_step_ran
+ self.assertEqual(cv.get(), 1)
+ cv.set(2)
+ self.assertEqual(cv.get(), 2)
+ coro_first_step_ran = True
+ await asyncio.sleep(0.1)
+ self.assertEqual(cv.get(), 2)
+ cv.set(3)
+ self.assertEqual(cv.get(), 3)
+ coro_second_step_ran = True
+
+ async def run():
+ cv.set(1)
+ t = self.loop.create_task(coro())
+ self.assertTrue(coro_first_step_ran)
+ self.assertFalse(coro_second_step_ran)
+ self.assertEqual(cv.get(), 1)
+ await t
+ self.assertTrue(coro_second_step_ran)
+ self.assertEqual(cv.get(), 1)
+
+ self.run_coro(run())
+
+
+class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
+ Task = tasks._PyTask
+
+
+@unittest.skipUnless(hasattr(tasks, '_CTask'),
+ 'requires the C _asyncio module')
+class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
+ Task = getattr(tasks, '_CTask', None)
+
+
+class AsyncTaskCounter:
+ def __init__(self, loop, *, task_class, eager):
+ self.suspense_count = 0
+ self.task_count = 0
+
+ def CountingTask(*args, eager_start=False, **kwargs):
+ if not eager_start:
+ self.task_count += 1
+ kwargs["eager_start"] = eager_start
+ return task_class(*args, **kwargs)
+
+ if eager:
+ factory = asyncio.create_eager_task_factory(CountingTask)
+ else:
+ def factory(loop, coro, **kwargs):
+ return CountingTask(coro, loop=loop, **kwargs)
+ loop.set_task_factory(factory)
+
+ def get(self):
+ return self.task_count
+
+
+async def awaitable_chain(depth):
+ if depth == 0:
+ return 0
+ return 1 + await awaitable_chain(depth - 1)
+
+
+async def recursive_taskgroups(width, depth):
+ if depth == 0:
+ return
+
+ async with asyncio.TaskGroup() as tg:
+ futures = [
+ tg.create_task(recursive_taskgroups(width, depth - 1))
+ for _ in range(width)
+ ]
+
+
+async def recursive_gather(width, depth):
+ if depth == 0:
+ return
+
+ await asyncio.gather(
+ *[recursive_gather(width, depth - 1) for _ in range(width)]
+ )
+
+
+class BaseTaskCountingTests:
+
+ Task = None
+ eager = None
+ expected_task_count = None
+
+ def setUp(self):
+ super().setUp()
+ self.loop = asyncio.new_event_loop()
+ self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager)
+ self.set_event_loop(self.loop)
+
+ def test_awaitables_chain(self):
+ observed_depth = self.loop.run_until_complete(awaitable_chain(100))
+ self.assertEqual(observed_depth, 100)
+ self.assertEqual(self.counter.get(), 0 if self.eager else 1)
+
+ def test_recursive_taskgroups(self):
+ num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4))
+ self.assertEqual(self.counter.get(), self.expected_task_count)
+
+ def test_recursive_gather(self):
+ self.loop.run_until_complete(recursive_gather(5, 4))
+ self.assertEqual(self.counter.get(), self.expected_task_count)
+
+
+class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests):
+ eager = False
+ expected_task_count = 781 # 1 + 5 + 5^2 + 5^3 + 5^4
+
+
+class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
+ eager = True
+ expected_task_count = 0
+
+
+class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
+ Task = asyncio.Task
+
+
+class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
+ Task = asyncio.Task
+
+
+class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
+ Task = tasks._PyTask
+
+
+class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
+ Task = tasks._PyTask
+
+
+@unittest.skipUnless(hasattr(tasks, '_CTask'),
+ 'requires the C _asyncio module')
+class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
+ Task = getattr(tasks, '_CTask', None)
+
+
+@unittest.skipUnless(hasattr(tasks, '_CTask'),
+ 'requires the C _asyncio module')
+class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
+ Task = getattr(tasks, '_CTask', None)
+
+if __name__ == '__main__':
+ unittest.main()