summaryrefslogtreecommitdiff
path: root/src/jinja2/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/jinja2/compiler.py')
-rw-r--r--src/jinja2/compiler.py80
1 files changed, 39 insertions, 41 deletions
diff --git a/src/jinja2/compiler.py b/src/jinja2/compiler.py
index 38d4fd3..1d73f7d 100644
--- a/src/jinja2/compiler.py
+++ b/src/jinja2/compiler.py
@@ -19,6 +19,7 @@ from .idtracking import VAR_LOAD_RESOLVE
from .idtracking import VAR_LOAD_UNDEFINED
from .nodes import EvalContext
from .optimizer import Optimizer
+from .utils import _PassArg
from .utils import concat
from .visitor import NodeVisitor
@@ -977,15 +978,11 @@ class CodeGenerator(NodeVisitor):
if node.ignore_missing:
self.outdent()
- def visit_Import(self, node, frame):
- """Visit regular imports."""
- self.writeline(f"{frame.symbols.ref(node.target)} = ", node)
- if frame.toplevel:
- self.write(f"context.vars[{node.target!r}] = ")
-
+ def _import_common(self, node, frame):
self.write(f"{self.choose_async('await ')}environment.get_template(")
self.visit(node.template, frame)
self.write(f", {self.name!r}).")
+
if node.with_context:
f_name = f"make_module{self.choose_async('_async')}"
self.write(
@@ -995,26 +992,23 @@ class CodeGenerator(NodeVisitor):
self.write("_get_default_module_async()")
else:
self.write("_get_default_module(context)")
+
+ def visit_Import(self, node, frame):
+ """Visit regular imports."""
+ self.writeline(f"{frame.symbols.ref(node.target)} = ", node)
+ if frame.toplevel:
+ self.write(f"context.vars[{node.target!r}] = ")
+
+ self._import_common(node, frame)
+
if frame.toplevel and not node.target.startswith("_"):
self.writeline(f"context.exported_vars.discard({node.target!r})")
def visit_FromImport(self, node, frame):
"""Visit named imports."""
self.newline(node)
- prefix = self.choose_async("await ")
- self.write(f"included_template = {prefix}environment.get_template(")
- self.visit(node.template, frame)
- self.write(f", {self.name!r}).")
- if node.with_context:
- f_name = f"make_module{self.choose_async('_async')}"
- self.write(
- f"{f_name}(context.get_all(), True, {self.dump_local_context(frame)})"
- )
- elif self.environment.is_async:
- self.write("_get_default_module_async()")
- else:
- self.write("_get_default_module(context)")
-
+ self.write("included_template = ")
+ self._import_common(node, frame)
var_names = []
discarded_names = []
for name in node.names:
@@ -1289,21 +1283,25 @@ class CodeGenerator(NodeVisitor):
if self.environment.finalize:
src = "environment.finalize("
env_finalize = self.environment.finalize
+ pass_arg = {
+ _PassArg.context: "context",
+ _PassArg.eval_context: "context.eval_ctx",
+ _PassArg.environment: "environment",
+ }.get(_PassArg.from_obj(env_finalize))
+ finalize = None
- def finalize(value):
- return default(env_finalize(value))
-
- if getattr(env_finalize, "contextfunction", False) is True:
- src += "context, "
- finalize = None # noqa: F811
- elif getattr(env_finalize, "evalcontextfunction", False) is True:
- src += "context.eval_ctx, "
- finalize = None
- elif getattr(env_finalize, "environmentfunction", False) is True:
- src += "environment, "
+ if pass_arg is None:
def finalize(value):
- return default(env_finalize(self.environment, value))
+ return default(env_finalize(value))
+
+ else:
+ src = f"{src}{pass_arg}, "
+
+ if pass_arg == "environment":
+
+ def finalize(value):
+ return default(env_finalize(self.environment, value))
self._finalize = self._FinalizeInfo(finalize, src)
return self._finalize
@@ -1673,13 +1671,11 @@ class CodeGenerator(NodeVisitor):
if is_filter:
compiler_map = self.filters
env_map = self.environment.filters
- type_name = mark_name = "filter"
+ type_name = "filter"
else:
compiler_map = self.tests
env_map = self.environment.tests
type_name = "test"
- # Filters use "contextfilter", tests and calls use "contextfunction".
- mark_name = "function"
if self.environment.is_async:
self.write("await auto_await(")
@@ -1693,12 +1689,14 @@ class CodeGenerator(NodeVisitor):
if func is None and not frame.soft_frame:
self.fail(f"No {type_name} named {node.name!r}.", node.lineno)
- if getattr(func, f"context{mark_name}", False) is True:
- self.write("context, ")
- elif getattr(func, f"evalcontext{mark_name}", False) is True:
- self.write("context.eval_ctx, ")
- elif getattr(func, f"environment{mark_name}", False) is True:
- self.write("environment, ")
+ pass_arg = {
+ _PassArg.context: "context",
+ _PassArg.eval_context: "context.eval_ctx",
+ _PassArg.environment: "environment",
+ }.get(_PassArg.from_obj(func))
+
+ if pass_arg is not None:
+ self.write(f"{pass_arg}, ")
# Back to the visitor function to handle visiting the target of
# the filter or test.