diff options
Diffstat (limited to 'src/jinja2/compiler.py')
-rw-r--r-- | src/jinja2/compiler.py | 80 |
1 files changed, 39 insertions, 41 deletions
diff --git a/src/jinja2/compiler.py b/src/jinja2/compiler.py index 38d4fd3..1d73f7d 100644 --- a/src/jinja2/compiler.py +++ b/src/jinja2/compiler.py @@ -19,6 +19,7 @@ from .idtracking import VAR_LOAD_RESOLVE from .idtracking import VAR_LOAD_UNDEFINED from .nodes import EvalContext from .optimizer import Optimizer +from .utils import _PassArg from .utils import concat from .visitor import NodeVisitor @@ -977,15 +978,11 @@ class CodeGenerator(NodeVisitor): if node.ignore_missing: self.outdent() - def visit_Import(self, node, frame): - """Visit regular imports.""" - self.writeline(f"{frame.symbols.ref(node.target)} = ", node) - if frame.toplevel: - self.write(f"context.vars[{node.target!r}] = ") - + def _import_common(self, node, frame): self.write(f"{self.choose_async('await ')}environment.get_template(") self.visit(node.template, frame) self.write(f", {self.name!r}).") + if node.with_context: f_name = f"make_module{self.choose_async('_async')}" self.write( @@ -995,26 +992,23 @@ class CodeGenerator(NodeVisitor): self.write("_get_default_module_async()") else: self.write("_get_default_module(context)") + + def visit_Import(self, node, frame): + """Visit regular imports.""" + self.writeline(f"{frame.symbols.ref(node.target)} = ", node) + if frame.toplevel: + self.write(f"context.vars[{node.target!r}] = ") + + self._import_common(node, frame) + if frame.toplevel and not node.target.startswith("_"): self.writeline(f"context.exported_vars.discard({node.target!r})") def visit_FromImport(self, node, frame): """Visit named imports.""" self.newline(node) - prefix = self.choose_async("await ") - self.write(f"included_template = {prefix}environment.get_template(") - self.visit(node.template, frame) - self.write(f", {self.name!r}).") - if node.with_context: - f_name = f"make_module{self.choose_async('_async')}" - self.write( - f"{f_name}(context.get_all(), True, {self.dump_local_context(frame)})" - ) - elif self.environment.is_async: - self.write("_get_default_module_async()") - else: - self.write("_get_default_module(context)") - + self.write("included_template = ") + self._import_common(node, frame) var_names = [] discarded_names = [] for name in node.names: @@ -1289,21 +1283,25 @@ class CodeGenerator(NodeVisitor): if self.environment.finalize: src = "environment.finalize(" env_finalize = self.environment.finalize + pass_arg = { + _PassArg.context: "context", + _PassArg.eval_context: "context.eval_ctx", + _PassArg.environment: "environment", + }.get(_PassArg.from_obj(env_finalize)) + finalize = None - def finalize(value): - return default(env_finalize(value)) - - if getattr(env_finalize, "contextfunction", False) is True: - src += "context, " - finalize = None # noqa: F811 - elif getattr(env_finalize, "evalcontextfunction", False) is True: - src += "context.eval_ctx, " - finalize = None - elif getattr(env_finalize, "environmentfunction", False) is True: - src += "environment, " + if pass_arg is None: def finalize(value): - return default(env_finalize(self.environment, value)) + return default(env_finalize(value)) + + else: + src = f"{src}{pass_arg}, " + + if pass_arg == "environment": + + def finalize(value): + return default(env_finalize(self.environment, value)) self._finalize = self._FinalizeInfo(finalize, src) return self._finalize @@ -1673,13 +1671,11 @@ class CodeGenerator(NodeVisitor): if is_filter: compiler_map = self.filters env_map = self.environment.filters - type_name = mark_name = "filter" + type_name = "filter" else: compiler_map = self.tests env_map = self.environment.tests type_name = "test" - # Filters use "contextfilter", tests and calls use "contextfunction". - mark_name = "function" if self.environment.is_async: self.write("await auto_await(") @@ -1693,12 +1689,14 @@ class CodeGenerator(NodeVisitor): if func is None and not frame.soft_frame: self.fail(f"No {type_name} named {node.name!r}.", node.lineno) - if getattr(func, f"context{mark_name}", False) is True: - self.write("context, ") - elif getattr(func, f"evalcontext{mark_name}", False) is True: - self.write("context.eval_ctx, ") - elif getattr(func, f"environment{mark_name}", False) is True: - self.write("environment, ") + pass_arg = { + _PassArg.context: "context", + _PassArg.eval_context: "context.eval_ctx", + _PassArg.environment: "environment", + }.get(_PassArg.from_obj(func)) + + if pass_arg is not None: + self.write(f"{pass_arg}, ") # Back to the visitor function to handle visiting the target of # the filter or test. |