summaryrefslogtreecommitdiff
path: root/tests/test_ext.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_ext.py')
-rw-r--r--tests/test_ext.py25
1 files changed, 21 insertions, 4 deletions
diff --git a/tests/test_ext.py b/tests/test_ext.py
index 1c349bd..7975271 100644
--- a/tests/test_ext.py
+++ b/tests/test_ext.py
@@ -109,24 +109,30 @@ newstyle_i18n_env.install_gettext_callables(gettext, ngettext, newstyle=True)
class ExampleExtension(Extension):
tags = set(['test'])
ext_attr = 42
+ context_reference_node_cls = nodes.ContextReference
def parse(self, parser):
return nodes.Output([self.call_method('_dump', [
nodes.EnvironmentAttribute('sandboxed'),
self.attr('ext_attr'),
nodes.ImportedName(__name__ + '.importable_object'),
- nodes.ContextReference()
+ self.context_reference_node_cls()
])]).set_lineno(next(parser.stream).lineno)
def _dump(self, sandboxed, ext_attr, imported_object, context):
- return '%s|%s|%s|%s' % (
+ return '%s|%s|%s|%s|%s' % (
sandboxed,
ext_attr,
imported_object,
- context.blocks
+ context.blocks,
+ context.get('test_var')
)
+class DerivedExampleExtension(ExampleExtension):
+ context_reference_node_cls = nodes.DerivedContextReference
+
+
class PreprocessorExtension(Extension):
def preprocess(self, source, name, filename=None):
@@ -205,7 +211,18 @@ class TestExtensions(object):
def test_extension_nodes(self):
env = Environment(extensions=[ExampleExtension])
tmpl = env.from_string('{% test %}')
- assert tmpl.render() == 'False|42|23|{}'
+ assert tmpl.render() == 'False|42|23|{}|None'
+
+ def test_contextreference_node_passes_context(self):
+ env = Environment(extensions=[ExampleExtension])
+ tmpl = env.from_string('{% set test_var="test_content" %}{% test %}')
+ assert tmpl.render() == 'False|42|23|{}|test_content'
+
+ def test_contextreference_node_can_pass_locals(self):
+ env = Environment(extensions=[DerivedExampleExtension])
+ tmpl = env.from_string(
+ '{% for test_var in ["test_content"] %}{% test %}{% endfor %}')
+ assert tmpl.render() == 'False|42|23|{}|test_content'
def test_identifier(self):
assert ExampleExtension.identifier == __name__ + '.ExampleExtension'