diff options
| author | Michele Simionato <michele.simionato@gmail.com> | 2017-06-06 19:29:45 +0200 |
|---|---|---|
| committer | Michele Simionato <michele.simionato@gmail.com> | 2017-06-06 19:29:45 +0200 |
| commit | 24bc55b8c3446a417ce80b15b7cbf15b6b120ffc (patch) | |
| tree | 6eae8c6fd6f3b1b68582565cbe86ea51a4d391e1 /src | |
| parent | 238bf589eb6100f2da4efba3a9c9297adbd63bef (diff) | |
| download | python-decorator-git-24bc55b8c3446a417ce80b15b7cbf15b6b120ffc.tar.gz | |
Fixed bug in the decoration of coroutines
Diffstat (limited to 'src')
| -rw-r--r-- | src/decorator.py | 9 | ||||
| -rw-r--r-- | src/tests/test.py | 23 |
2 files changed, 19 insertions, 13 deletions
diff --git a/src/decorator.py b/src/decorator.py index bdb8e71..e8ee475 100644 --- a/src/decorator.py +++ b/src/decorator.py @@ -223,9 +223,12 @@ class FunctionMaker(object): func = obj self = cls(func, name, signature, defaults, doc, module) ibody = '\n'.join(' ' + line for line in body.splitlines()) - coro = 'async ' if self.coro else '' - return self.make(coro + 'def %(name)s(%(signature)s):\n' + ibody, - evaldict, addsource, **attrs) + if self.coro: + body = ('async def %(name)s(%(signature)s):\n' + ibody).replace( + 'return', 'return await') + else: + body = 'def %(name)s(%(signature)s):\n' + ibody + return self.make(body, evaldict, addsource, **attrs) def decorate(func, caller): diff --git a/src/tests/test.py b/src/tests/test.py index 0e31053..d882418 100644 --- a/src/tests/test.py +++ b/src/tests/test.py @@ -24,18 +24,21 @@ def assertRaises(etype): raise Exception('Expected %s' % etype.__name__) if sys.version >= '3.5': - exec('''\ -class CoroutineTestCase(unittest.TestCase): - def test(self): - async def cor(): - pass - self.assertTrue(inspect.iscoroutinefunction(cor)) + exec('''from asyncio import get_event_loop + +@decorator +async def before_after(coro, *args, **kwargs): + return "<before>" + (await coro(*args, **kwargs)) + "<after>" - @decorator - def identity(f, *args, **kwargs): - return f(*args, **kwargs) - self.assertTrue(inspect.iscoroutinefunction(identity(cor))) +class CoroutineTestCase(unittest.TestCase): + def test(self): + @before_after + async def coro(x): + return x + self.assertTrue(inspect.iscoroutinefunction(coro)) + out = get_event_loop().run_until_complete(coro('x')) + self.assertEqual(out, '<before>x<after>') ''') |
