diff options
-rw-r--r-- | jinja2/asyncsupport.py | 41 | ||||
-rw-r--r-- | jinja2/compiler.py | 6 | ||||
-rw-r--r-- | tests/test_async.py | 11 |
3 files changed, 49 insertions, 9 deletions
diff --git a/jinja2/asyncsupport.py b/jinja2/asyncsupport.py index 534fb80..e81e83c 100644 --- a/jinja2/asyncsupport.py +++ b/jinja2/asyncsupport.py @@ -2,7 +2,16 @@ import sys import asyncio import inspect -from jinja2.utils import concat +from jinja2.utils import concat, internalcode, concat, Markup + + +async def concat_async(async_gen): + rv = [] + async def collect(): + async for event in async_gen: + rv.append(event) + await collect() + return concat(rv) async def render_async(self, *args, **kwargs): @@ -12,14 +21,9 @@ async def render_async(self, *args, **kwargs): vars = dict(*args, **kwargs) ctx = self.new_context(vars) - rv = [] - async def collect(): - async for event in self.root_render_func(ctx): - rv.append(event) try: - await collect() - return concat(rv) + return await concat_async(self.root_render_func(ctx)) except Exception: exc_info = sys.exc_info() return self.environment.handle_exception(exc_info, True) @@ -34,14 +38,37 @@ def wrap_render_func(original_render): return render +def wrap_block_reference_call(original_call): + @internalcode + async def async_call(self): + rv = await concat_async(self._stack[self._depth](self._context)) + if self._context.eval_ctx.autoescape: + rv = Markup(rv) + return rv + + @internalcode + def __call__(self): + if not self._context.environment._async: + return original_call(self) + return async_call(self) + + return __call__ + + def patch_template(): from jinja2 import Template Template.render_async = render_async Template.render = wrap_render_func(Template.render) +def patch_runtime(): + from jinja2.runtime import BlockReference + BlockReference.__call__ = wrap_block_reference_call(BlockReference.__call__) + + def patch_all(): patch_template() + patch_runtime() async def auto_await(value): diff --git a/jinja2/compiler.py b/jinja2/compiler.py index a22904a..09ad42b 100644 --- a/jinja2/compiler.py +++ b/jinja2/compiler.py @@ -887,8 +887,10 @@ class CodeGenerator(NodeVisitor): self.indent() level += 1 context = node.scoped and 'context.derived(locals())' or 'context' - self.writeline('for event in context.blocks[%r][0](%s):' % ( - node.name, context), node) + + loop = self.environment._async and 'async for' or 'for' + self.writeline('%s event in context.blocks[%r][0](%s):' % ( + loop, node.name, context), node) self.indent() self.simple_write('event', frame) self.outdent(level) diff --git a/tests/test_async.py b/tests/test_async.py index 00dfffe..fd88805 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -55,3 +55,14 @@ def test_await_and_macros(): rv = run(func) assert rv == '[42][42]' + + +@pytest.mark.skipif(not have_async_gen, reason='No async generators') +def test_async_blocks(): + t = Template('{% block foo %}<Test>{% endblock %}{{ self.foo() }}', + enable_async=True, autoescape=True) + async def func(): + return await t.render_async() + + rv = run(func) + assert rv == '<Test><Test>' |