summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYury Selivanov <yselivanov@sprymix.com>2015-05-11 12:30:48 -0400
committerYury Selivanov <yselivanov@sprymix.com>2015-05-11 12:30:48 -0400
commit71f7c249efc8a97e7e06d25c65ae96a2c001a6b3 (patch)
tree32baece7d1400024a1f2b305a4ad70a6a0ed9e0d
parent2798fb43af22c966a0c7ba15258a073cf651a3c2 (diff)
downloadtrollius-git-71f7c249efc8a97e7e06d25c65ae96a2c001a6b3.tar.gz
Add new loop APIs: set_task_factory() and get_task_factory()
-rw-r--r--asyncio/base_events.py28
-rw-r--r--asyncio/events.py8
-rw-r--r--tests/test_base_events.py36
3 files changed, 69 insertions, 3 deletions
diff --git a/asyncio/base_events.py b/asyncio/base_events.py
index bfa435c..efbb9f4 100644
--- a/asyncio/base_events.py
+++ b/asyncio/base_events.py
@@ -197,6 +197,7 @@ class BaseEventLoop(events.AbstractEventLoop):
# exceed this duration in seconds, the slow callback/task is logged.
self.slow_callback_duration = 0.1
self._current_handle = None
+ self._task_factory = None
def __repr__(self):
return ('<%s running=%s closed=%s debug=%s>'
@@ -209,11 +210,32 @@ class BaseEventLoop(events.AbstractEventLoop):
Return a task object.
"""
self._check_closed()
- task = tasks.Task(coro, loop=self)
- if task._source_traceback:
- del task._source_traceback[-1]
+ if self._task_factory is None:
+ task = tasks.Task(coro, loop=self)
+ if task._source_traceback:
+ del task._source_traceback[-1]
+ else:
+ task = self._task_factory(self, coro)
return task
+ def set_task_factory(self, factory):
+ """Set a task factory that will be used by loop.create_task().
+
+ If factory is None the default task factory will be set.
+
+ If factory is a callable, it should have a signature matching
+ '(loop, coro)', where 'loop' will be a reference to the active
+ event loop, 'coro' will be a coroutine object. The callable
+ must return a Future.
+ """
+ if factory is not None and not callable(factory):
+ raise TypeError('task factory must be a callable or None')
+ self._task_factory = factory
+
+ def get_task_factory(self):
+ """Return a task factory, or None if the default one is in use."""
+ return self._task_factory
+
def _make_socket_transport(self, sock, protocol, waiter=None, *,
extra=None, server=None):
"""Create socket transport."""
diff --git a/asyncio/events.py b/asyncio/events.py
index 99e12e6..496075b 100644
--- a/asyncio/events.py
+++ b/asyncio/events.py
@@ -438,6 +438,14 @@ class AbstractEventLoop:
def remove_signal_handler(self, sig):
raise NotImplementedError
+ # Task factory.
+
+ def set_task_factory(self, factory):
+ raise NotImplementedError
+
+ def get_task_factory(self):
+ raise NotImplementedError
+
# Error handlers.
def set_exception_handler(self, handler):
diff --git a/tests/test_base_events.py b/tests/test_base_events.py
index 9e7c50c..af6a4c3 100644
--- a/tests/test_base_events.py
+++ b/tests/test_base_events.py
@@ -623,6 +623,42 @@ class BaseEventLoopTests(test_utils.TestCase):
self.assertIs(type(_context['context']['exception']),
ZeroDivisionError)
+ def test_set_task_factory_invalid(self):
+ with self.assertRaisesRegex(
+ TypeError, 'task factory must be a callable or None'):
+
+ self.loop.set_task_factory(1)
+
+ self.assertIsNone(self.loop.get_task_factory())
+
+ def test_set_task_factory(self):
+ self.loop._process_events = mock.Mock()
+
+ class MyTask(asyncio.Task):
+ pass
+
+ @asyncio.coroutine
+ def coro():
+ pass
+
+ factory = lambda loop, coro: MyTask(coro, loop=loop)
+
+ self.assertIsNone(self.loop.get_task_factory())
+ self.loop.set_task_factory(factory)
+ self.assertIs(self.loop.get_task_factory(), factory)
+
+ task = self.loop.create_task(coro())
+ self.assertTrue(isinstance(task, MyTask))
+ self.loop.run_until_complete(task)
+
+ self.loop.set_task_factory(None)
+ self.assertIsNone(self.loop.get_task_factory())
+
+ task = self.loop.create_task(coro())
+ self.assertTrue(isinstance(task, asyncio.Task))
+ self.assertFalse(isinstance(task, MyTask))
+ self.loop.run_until_complete(task)
+
def test_env_var_debug(self):
code = '\n'.join((
'import asyncio',