diff options
author | Yury Selivanov <yselivanov@sprymix.com> | 2015-05-11 12:30:48 -0400 |
---|---|---|
committer | Yury Selivanov <yselivanov@sprymix.com> | 2015-05-11 12:30:48 -0400 |
commit | 71f7c249efc8a97e7e06d25c65ae96a2c001a6b3 (patch) | |
tree | 32baece7d1400024a1f2b305a4ad70a6a0ed9e0d | |
parent | 2798fb43af22c966a0c7ba15258a073cf651a3c2 (diff) | |
download | trollius-git-71f7c249efc8a97e7e06d25c65ae96a2c001a6b3.tar.gz |
Add new loop APIs: set_task_factory() and get_task_factory()
-rw-r--r-- | asyncio/base_events.py | 28 | ||||
-rw-r--r-- | asyncio/events.py | 8 | ||||
-rw-r--r-- | tests/test_base_events.py | 36 |
3 files changed, 69 insertions, 3 deletions
diff --git a/asyncio/base_events.py b/asyncio/base_events.py index bfa435c..efbb9f4 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -197,6 +197,7 @@ class BaseEventLoop(events.AbstractEventLoop): # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 self._current_handle = None + self._task_factory = None def __repr__(self): return ('<%s running=%s closed=%s debug=%s>' @@ -209,11 +210,32 @@ class BaseEventLoop(events.AbstractEventLoop): Return a task object. """ self._check_closed() - task = tasks.Task(coro, loop=self) - if task._source_traceback: - del task._source_traceback[-1] + if self._task_factory is None: + task = tasks.Task(coro, loop=self) + if task._source_traceback: + del task._source_traceback[-1] + else: + task = self._task_factory(self, coro) return task + def set_task_factory(self, factory): + """Set a task factory that will be used by loop.create_task(). + + If factory is None the default task factory will be set. + + If factory is a callable, it should have a signature matching + '(loop, coro)', where 'loop' will be a reference to the active + event loop, 'coro' will be a coroutine object. The callable + must return a Future. + """ + if factory is not None and not callable(factory): + raise TypeError('task factory must be a callable or None') + self._task_factory = factory + + def get_task_factory(self): + """Return a task factory, or None if the default one is in use.""" + return self._task_factory + def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): """Create socket transport.""" diff --git a/asyncio/events.py b/asyncio/events.py index 99e12e6..496075b 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -438,6 +438,14 @@ class AbstractEventLoop: def remove_signal_handler(self, sig): raise NotImplementedError + # Task factory. + + def set_task_factory(self, factory): + raise NotImplementedError + + def get_task_factory(self): + raise NotImplementedError + # Error handlers. def set_exception_handler(self, handler): diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9e7c50c..af6a4c3 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -623,6 +623,42 @@ class BaseEventLoopTests(test_utils.TestCase): self.assertIs(type(_context['context']['exception']), ZeroDivisionError) + def test_set_task_factory_invalid(self): + with self.assertRaisesRegex( + TypeError, 'task factory must be a callable or None'): + + self.loop.set_task_factory(1) + + self.assertIsNone(self.loop.get_task_factory()) + + def test_set_task_factory(self): + self.loop._process_events = mock.Mock() + + class MyTask(asyncio.Task): + pass + + @asyncio.coroutine + def coro(): + pass + + factory = lambda loop, coro: MyTask(coro, loop=loop) + + self.assertIsNone(self.loop.get_task_factory()) + self.loop.set_task_factory(factory) + self.assertIs(self.loop.get_task_factory(), factory) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + + self.loop.set_task_factory(None) + self.assertIsNone(self.loop.get_task_factory()) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, asyncio.Task)) + self.assertFalse(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + def test_env_var_debug(self): code = '\n'.join(( 'import asyncio', |