summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMichele Simionato <michele.simionato@gmail.com>2017-06-06 19:29:45 +0200
committerMichele Simionato <michele.simionato@gmail.com>2017-06-06 19:29:45 +0200
commit24bc55b8c3446a417ce80b15b7cbf15b6b120ffc (patch)
tree6eae8c6fd6f3b1b68582565cbe86ea51a4d391e1 /src
parent238bf589eb6100f2da4efba3a9c9297adbd63bef (diff)
downloadpython-decorator-git-24bc55b8c3446a417ce80b15b7cbf15b6b120ffc.tar.gz
Fixed bug in the decoration of coroutines
Diffstat (limited to 'src')
-rw-r--r--src/decorator.py9
-rw-r--r--src/tests/test.py23
2 files changed, 19 insertions, 13 deletions
diff --git a/src/decorator.py b/src/decorator.py
index bdb8e71..e8ee475 100644
--- a/src/decorator.py
+++ b/src/decorator.py
@@ -223,9 +223,12 @@ class FunctionMaker(object):
func = obj
self = cls(func, name, signature, defaults, doc, module)
ibody = '\n'.join(' ' + line for line in body.splitlines())
- coro = 'async ' if self.coro else ''
- return self.make(coro + 'def %(name)s(%(signature)s):\n' + ibody,
- evaldict, addsource, **attrs)
+ if self.coro:
+ body = ('async def %(name)s(%(signature)s):\n' + ibody).replace(
+ 'return', 'return await')
+ else:
+ body = 'def %(name)s(%(signature)s):\n' + ibody
+ return self.make(body, evaldict, addsource, **attrs)
def decorate(func, caller):
diff --git a/src/tests/test.py b/src/tests/test.py
index 0e31053..d882418 100644
--- a/src/tests/test.py
+++ b/src/tests/test.py
@@ -24,18 +24,21 @@ def assertRaises(etype):
raise Exception('Expected %s' % etype.__name__)
if sys.version >= '3.5':
- exec('''\
-class CoroutineTestCase(unittest.TestCase):
- def test(self):
- async def cor():
- pass
- self.assertTrue(inspect.iscoroutinefunction(cor))
+ exec('''from asyncio import get_event_loop
+
+@decorator
+async def before_after(coro, *args, **kwargs):
+ return "<before>" + (await coro(*args, **kwargs)) + "<after>"
- @decorator
- def identity(f, *args, **kwargs):
- return f(*args, **kwargs)
- self.assertTrue(inspect.iscoroutinefunction(identity(cor)))
+class CoroutineTestCase(unittest.TestCase):
+ def test(self):
+ @before_after
+ async def coro(x):
+ return x
+ self.assertTrue(inspect.iscoroutinefunction(coro))
+ out = get_event_loop().run_until_complete(coro('x'))
+ self.assertEqual(out, '<before>x<after>')
''')