summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Lord <davidism@gmail.com>2021-05-08 13:03:15 -0700
committerDavid Lord <davidism@gmail.com>2021-05-08 13:03:15 -0700
commitbe15556dbaebc4ab45d3f5f5d34874279831103c (patch)
treea84a778972947c9b3f82ccf5ae6e7358984a9b4b
parentac8d8d69fe3e0eb81251867218c67193a086e427 (diff)
downloadjinja2-be15556dbaebc4ab45d3f5f5d34874279831103c.tar.gz
add type annotations
-rw-r--r--CHANGES.rst1
-rw-r--r--MANIFEST.in1
-rw-r--r--docs/examples/inline_gettext_extension.py2
-rw-r--r--setup.cfg12
-rw-r--r--src/jinja2/async_utils.py44
-rw-r--r--src/jinja2/bccache.py117
-rw-r--r--src/jinja2/compiler.py565
-rw-r--r--src/jinja2/debug.py53
-rw-r--r--src/jinja2/defaults.py7
-rw-r--r--src/jinja2/environment.py603
-rw-r--r--src/jinja2/exceptions.py63
-rw-r--r--src/jinja2/ext.py390
-rw-r--r--src/jinja2/filters.py93
-rw-r--r--src/jinja2/idtracking.py163
-rw-r--r--src/jinja2/lexer.py251
-rw-r--r--src/jinja2/loaders.py162
-rw-r--r--src/jinja2/meta.py51
-rw-r--r--src/jinja2/nativetypes.py40
-rw-r--r--src/jinja2/nodes.py304
-rw-r--r--src/jinja2/optimizer.py15
-rw-r--r--src/jinja2/parser.py289
-rw-r--r--src/jinja2/py.typed0
-rw-r--r--src/jinja2/runtime.py416
-rw-r--r--src/jinja2/sandbox.py137
-rw-r--r--src/jinja2/utils.py178
-rw-r--r--src/jinja2/visitor.py27
-rw-r--r--tests/test_ext.py2
27 files changed, 2469 insertions, 1517 deletions
diff --git a/CHANGES.rst b/CHANGES.rst
index 950b4e5..bebdb98 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -9,6 +9,7 @@ Unreleased
- Bump MarkupSafe dependency to >=1.1.
- Bump Babel optional dependency to >=2.1.
- Remove code that was marked deprecated.
+- Add type hinting. :pr:`1412`
- Use :pep:`451` API to load templates with
:class:`~loaders.PackageLoader`. :issue:`1168`
- Fix a bug that caused imported macros to not have access to the
diff --git a/MANIFEST.in b/MANIFEST.in
index 8690e35..e5b231d 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -6,4 +6,5 @@ graft docs
prune docs/_build
graft examples
graft tests
+include src/click/py.typed
global-exclude *.pyc
diff --git a/docs/examples/inline_gettext_extension.py b/docs/examples/inline_gettext_extension.py
index d75119c..bf8b9db 100644
--- a/docs/examples/inline_gettext_extension.py
+++ b/docs/examples/inline_gettext_extension.py
@@ -30,7 +30,7 @@ class InlineGettext(Extension):
pos = 0
lineno = token.lineno
- while 1:
+ while True:
if not paren_stack:
match = _outside_re.search(token.value, pos)
else:
diff --git a/setup.cfg b/setup.cfg
index acac7ca..a56d553 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -86,15 +86,21 @@ per-file-ignores =
files = src/jinja2
python_version = 3.6
disallow_subclassing_any = True
-# disallow_untyped_calls = True
-# disallow_untyped_defs = True
+disallow_untyped_calls = True
+disallow_untyped_defs = True
disallow_incomplete_defs = True
no_implicit_optional = True
local_partial_types = True
-# no_implicit_reexport = True
+no_implicit_reexport = True
strict_equality = True
warn_redundant_casts = True
warn_unused_configs = True
warn_unused_ignores = True
warn_return_any = True
warn_unreachable = True
+
+[mypy-jinja2.defaults]
+no_implicit_reexport = False
+
+[mypy-markupsafe]
+no_implicit_reexport = False
diff --git a/src/jinja2/async_utils.py b/src/jinja2/async_utils.py
index cb011b2..cedd7ba 100644
--- a/src/jinja2/async_utils.py
+++ b/src/jinja2/async_utils.py
@@ -5,27 +5,26 @@ from functools import wraps
from .utils import _PassArg
from .utils import pass_eval_context
-if t.TYPE_CHECKING:
- V = t.TypeVar("V")
+V = t.TypeVar("V")
-def async_variant(normal_func):
- def decorator(async_func):
+def async_variant(normal_func): # type: ignore
+ def decorator(async_func): # type: ignore
pass_arg = _PassArg.from_obj(normal_func)
need_eval_context = pass_arg is None
if pass_arg is _PassArg.environment:
- def is_async(args):
- return args[0].is_async
+ def is_async(args: t.Any) -> bool:
+ return t.cast(bool, args[0].is_async)
else:
- def is_async(args):
- return args[0].environment.is_async
+ def is_async(args: t.Any) -> bool:
+ return t.cast(bool, args[0].environment.is_async)
@wraps(normal_func)
- def wrapper(*args, **kwargs):
+ def wrapper(*args, **kwargs): # type: ignore
b = is_async(args)
if need_eval_context:
@@ -45,32 +44,25 @@ def async_variant(normal_func):
return decorator
-async def auto_await(value):
+async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V":
if inspect.isawaitable(value):
- return await value
+ return await t.cast("t.Awaitable[V]", value)
- return value
+ return t.cast("V", value)
-async def auto_aiter(iterable):
+async def auto_aiter(
+ iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+) -> "t.AsyncIterator[V]":
if hasattr(iterable, "__aiter__"):
- async for item in iterable:
+ async for item in t.cast("t.AsyncIterable[V]", iterable):
yield item
else:
- for item in iterable:
+ for item in t.cast("t.Iterable[V]", iterable):
yield item
async def auto_to_list(
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-) -> "t.List[V]":
- seq = []
-
- if hasattr(value, "__aiter__"):
- async for item in t.cast(t.AsyncIterable, value):
- seq.append(item)
- else:
- for item in t.cast(t.Iterable, value):
- seq.append(item)
-
- return seq
+) -> t.List["V"]:
+ return [x async for x in auto_aiter(value)]
diff --git a/src/jinja2/bccache.py b/src/jinja2/bccache.py
index 7ddcf40..bbe4bee 100644
--- a/src/jinja2/bccache.py
+++ b/src/jinja2/bccache.py
@@ -13,10 +13,22 @@ import pickle
import stat
import sys
import tempfile
+import typing as t
from hashlib import sha1
from io import BytesIO
+from types import CodeType
+
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+ from .environment import Environment
+
+ class _MemcachedClient(te.Protocol):
+ def get(self, key: str) -> bytes:
+ ...
+
+ def set(self, key: str, value: bytes, timeout: t.Optional[int] = None) -> None:
+ ...
-from .utils import open_if_exists
bc_version = 5
# Magic bytes to identify Jinja bytecode cache files. Contains the
@@ -38,17 +50,17 @@ class Bucket:
cache subclasses don't have to care about cache invalidation.
"""
- def __init__(self, environment, key, checksum):
+ def __init__(self, environment: "Environment", key: str, checksum: str) -> None:
self.environment = environment
self.key = key
self.checksum = checksum
self.reset()
- def reset(self):
+ def reset(self) -> None:
"""Resets the bucket (unloads the bytecode)."""
- self.code = None
+ self.code: t.Optional[CodeType] = None
- def load_bytecode(self, f):
+ def load_bytecode(self, f: t.BinaryIO) -> None:
"""Loads bytecode from a file or file like object."""
# make sure the magic header is correct
magic = f.read(len(bc_magic))
@@ -67,7 +79,7 @@ class Bucket:
self.reset()
return
- def write_bytecode(self, f):
+ def write_bytecode(self, f: t.BinaryIO) -> None:
"""Dump the bytecode into the file or file like object passed."""
if self.code is None:
raise TypeError("can't write empty bucket")
@@ -75,12 +87,12 @@ class Bucket:
pickle.dump(self.checksum, f, 2)
marshal.dump(self.code, f)
- def bytecode_from_string(self, string):
- """Load bytecode from a string."""
+ def bytecode_from_string(self, string: bytes) -> None:
+ """Load bytecode from bytes."""
self.load_bytecode(BytesIO(string))
- def bytecode_to_string(self):
- """Return the bytecode as string."""
+ def bytecode_to_string(self) -> bytes:
+ """Return the bytecode as bytes."""
out = BytesIO()
self.write_bytecode(out)
return out.getvalue()
@@ -115,41 +127,48 @@ class BytecodeCache:
Jinja.
"""
- def load_bytecode(self, bucket):
+ def load_bytecode(self, bucket: Bucket) -> None:
"""Subclasses have to override this method to load bytecode into a
bucket. If they are not able to find code in the cache for the
bucket, it must not do anything.
"""
raise NotImplementedError()
- def dump_bytecode(self, bucket):
+ def dump_bytecode(self, bucket: Bucket) -> None:
"""Subclasses have to override this method to write the bytecode
from a bucket back to the cache. If it unable to do so it must not
fail silently but raise an exception.
"""
raise NotImplementedError()
- def clear(self):
+ def clear(self) -> None:
"""Clears the cache. This method is not used by Jinja but should be
implemented to allow applications to clear the bytecode cache used
by a particular environment.
"""
- def get_cache_key(self, name, filename=None):
+ def get_cache_key(
+ self, name: str, filename: t.Optional[t.Union[str]] = None
+ ) -> str:
"""Returns the unique hash key for this template name."""
hash = sha1(name.encode("utf-8"))
+
if filename is not None:
- filename = "|" + filename
- if isinstance(filename, str):
- filename = filename.encode("utf-8")
- hash.update(filename)
+ hash.update(f"|{filename}".encode("utf-8"))
+
return hash.hexdigest()
- def get_source_checksum(self, source):
+ def get_source_checksum(self, source: str) -> str:
"""Returns a checksum for the source."""
return sha1(source.encode("utf-8")).hexdigest()
- def get_bucket(self, environment, name, filename, source):
+ def get_bucket(
+ self,
+ environment: "Environment",
+ name: str,
+ filename: t.Optional[str],
+ source: str,
+ ) -> Bucket:
"""Return a cache bucket for the given template. All arguments are
mandatory but filename may be `None`.
"""
@@ -159,7 +178,7 @@ class BytecodeCache:
self.load_bytecode(bucket)
return bucket
- def set_bucket(self, bucket):
+ def set_bucket(self, bucket: Bucket) -> None:
"""Put the bucket into the cache."""
self.dump_bytecode(bucket)
@@ -182,14 +201,16 @@ class FileSystemBytecodeCache(BytecodeCache):
This bytecode cache supports clearing of the cache using the clear method.
"""
- def __init__(self, directory=None, pattern="__jinja2_%s.cache"):
+ def __init__(
+ self, directory: t.Optional[str] = None, pattern: str = "__jinja2_%s.cache"
+ ) -> None:
if directory is None:
directory = self._get_default_cache_dir()
self.directory = directory
self.pattern = pattern
- def _get_default_cache_dir(self):
- def _unsafe_dir():
+ def _get_default_cache_dir(self) -> str:
+ def _unsafe_dir() -> t.NoReturn:
raise RuntimeError(
"Cannot determine safe temp directory. You "
"need to explicitly provide one."
@@ -235,25 +256,21 @@ class FileSystemBytecodeCache(BytecodeCache):
return actual_dir
- def _get_cache_filename(self, bucket):
+ def _get_cache_filename(self, bucket: Bucket) -> str:
return os.path.join(self.directory, self.pattern % (bucket.key,))
- def load_bytecode(self, bucket):
- f = open_if_exists(self._get_cache_filename(bucket), "rb")
- if f is not None:
- try:
+ def load_bytecode(self, bucket: Bucket) -> None:
+ filename = self._get_cache_filename(bucket)
+
+ if os.path.exists(filename):
+ with open(filename, "rb") as f:
bucket.load_bytecode(f)
- finally:
- f.close()
- def dump_bytecode(self, bucket):
- f = open(self._get_cache_filename(bucket), "wb")
- try:
+ def dump_bytecode(self, bucket: Bucket) -> None:
+ with open(self._get_cache_filename(bucket), "wb") as f:
bucket.write_bytecode(f)
- finally:
- f.close()
- def clear(self):
+ def clear(self) -> None:
# imported lazily here because google app-engine doesn't support
# write access on the file system and the function does not exist
# normally.
@@ -314,32 +331,34 @@ class MemcachedBytecodeCache(BytecodeCache):
def __init__(
self,
- client,
- prefix="jinja2/bytecode/",
- timeout=None,
- ignore_memcache_errors=True,
+ client: "_MemcachedClient",
+ prefix: str = "jinja2/bytecode/",
+ timeout: t.Optional[int] = None,
+ ignore_memcache_errors: bool = True,
):
self.client = client
self.prefix = prefix
self.timeout = timeout
self.ignore_memcache_errors = ignore_memcache_errors
- def load_bytecode(self, bucket):
+ def load_bytecode(self, bucket: Bucket) -> None:
try:
code = self.client.get(self.prefix + bucket.key)
except Exception:
if not self.ignore_memcache_errors:
raise
- code = None
- if code is not None:
+ else:
bucket.bytecode_from_string(code)
- def dump_bytecode(self, bucket):
- args = (self.prefix + bucket.key, bucket.bytecode_to_string())
- if self.timeout is not None:
- args += (self.timeout,)
+ def dump_bytecode(self, bucket: Bucket) -> None:
+ key = self.prefix + bucket.key
+ value = bucket.bytecode_to_string()
+
try:
- self.client.set(*args)
+ if self.timeout is not None:
+ self.client.set(key, value, self.timeout)
+ else:
+ self.client.set(key, value)
except Exception:
if not self.ignore_memcache_errors:
raise
diff --git a/src/jinja2/compiler.py b/src/jinja2/compiler.py
index b15fb67..a8cfe9e 100644
--- a/src/jinja2/compiler.py
+++ b/src/jinja2/compiler.py
@@ -1,6 +1,5 @@
"""Compiles nodes from the parser into Python code."""
import typing as t
-from collections import namedtuple
from contextlib import contextmanager
from functools import update_wrapper
from io import StringIO
@@ -23,6 +22,11 @@ from .utils import _PassArg
from .utils import concat
from .visitor import NodeVisitor
+if t.TYPE_CHECKING:
+ from .environment import Environment
+
+F = t.TypeVar("F", bound=t.Callable[..., t.Any])
+
operators = {
"eq": "==",
"ne": "!=",
@@ -35,33 +39,89 @@ operators = {
}
-def optimizeconst(f):
- def new_func(self, node, frame, **kwargs):
+def optimizeconst(f: F) -> F:
+ def new_func(
+ self: "CodeGenerator", node: nodes.Expr, frame: "Frame", **kwargs: t.Any
+ ) -> t.Any:
# Only optimize if the frame is not volatile
- if self.optimized and not frame.eval_ctx.volatile:
+ if self.optimizer is not None and not frame.eval_ctx.volatile:
new_node = self.optimizer.visit(node, frame.eval_ctx)
+
if new_node != node:
return self.visit(new_node, frame)
+
return f(self, node, frame, **kwargs)
- return update_wrapper(new_func, f)
+ return update_wrapper(t.cast(F, new_func), f)
+
+
+def _make_binop(op: str) -> t.Callable[["CodeGenerator", nodes.BinExpr, "Frame"], None]:
+ @optimizeconst
+ def visitor(self: "CodeGenerator", node: nodes.BinExpr, frame: Frame) -> None:
+ if (
+ self.environment.sandboxed
+ and op in self.environment.intercepted_binops # type: ignore
+ ):
+ self.write(f"environment.call_binop(context, {op!r}, ")
+ self.visit(node.left, frame)
+ self.write(", ")
+ self.visit(node.right, frame)
+ else:
+ self.write("(")
+ self.visit(node.left, frame)
+ self.write(f" {op} ")
+ self.visit(node.right, frame)
+
+ self.write(")")
+
+ return visitor
+
+
+def _make_unop(
+ op: str,
+) -> t.Callable[["CodeGenerator", nodes.UnaryExpr, "Frame"], None]:
+ @optimizeconst
+ def visitor(self: "CodeGenerator", node: nodes.UnaryExpr, frame: Frame) -> None:
+ if (
+ self.environment.sandboxed
+ and op in self.environment.intercepted_unops # type: ignore
+ ):
+ self.write(f"environment.call_unop(context, {op!r}, ")
+ self.visit(node.node, frame)
+ else:
+ self.write("(" + op)
+ self.visit(node.node, frame)
+
+ self.write(")")
+
+ return visitor
def generate(
- node, environment, name, filename, stream=None, defer_init=False, optimized=True
-):
+ node: nodes.Template,
+ environment: "Environment",
+ name: t.Optional[str],
+ filename: t.Optional[str],
+ stream: t.Optional[t.TextIO] = None,
+ defer_init: bool = False,
+ optimized: bool = True,
+) -> t.Optional[str]:
"""Generate the python source for a node tree."""
if not isinstance(node, nodes.Template):
raise TypeError("Can't compile non template nodes")
+
generator = environment.code_generator_class(
environment, name, filename, stream, defer_init, optimized
)
generator.visit(node)
+
if stream is None:
- return generator.stream.getvalue()
+ return generator.stream.getvalue() # type: ignore
+
+ return None
-def has_safe_repr(value):
+def has_safe_repr(value: t.Any) -> bool:
"""Does the node have a safe representation?"""
if value is None or value is NotImplemented or value is Ellipsis:
return True
@@ -78,7 +138,9 @@ def has_safe_repr(value):
return False
-def find_undeclared(nodes, names):
+def find_undeclared(
+ nodes: t.Iterable[nodes.Node], names: t.Iterable[str]
+) -> t.Set[str]:
"""Check if the names passed are accessed undeclared. The return value
is a set of all the undeclared names from the sequence of names found.
"""
@@ -92,7 +154,7 @@ def find_undeclared(nodes, names):
class MacroRef:
- def __init__(self, node):
+ def __init__(self, node: t.Union[nodes.Macro, nodes.CallBlock]) -> None:
self.node = node
self.accesses_caller = False
self.accesses_kwargs = False
@@ -102,9 +164,38 @@ class MacroRef:
class Frame:
"""Holds compile time information for us."""
- def __init__(self, eval_ctx, parent=None, level=None):
+ def __init__(
+ self,
+ eval_ctx: EvalContext,
+ parent: t.Optional["Frame"] = None,
+ level: t.Optional[int] = None,
+ ) -> None:
self.eval_ctx = eval_ctx
- self.symbols = Symbols(parent.symbols if parent else None, level=level)
+
+ # the parent of this frame
+ self.parent = parent
+
+ if parent is None:
+ self.symbols = Symbols(level=level)
+
+ # in some dynamic inheritance situations the compiler needs to add
+ # write tests around output statements.
+ self.require_output_check = False
+
+ # inside some tags we are using a buffer rather than yield statements.
+ # this for example affects {% filter %} or {% macro %}. If a frame
+ # is buffered this variable points to the name of the list used as
+ # buffer.
+ self.buffer: t.Optional[str] = None
+
+ # the name of the block we're in, otherwise None.
+ self.block: t.Optional[str] = None
+
+ else:
+ self.symbols = Symbols(parent.symbols, level=level)
+ self.require_output_check = parent.require_output_check
+ self.buffer = parent.buffer
+ self.block = parent.block
# a toplevel frame is the root + soft frames such as if conditions.
self.toplevel = False
@@ -114,25 +205,6 @@ class Frame:
# situations.
self.rootlevel = False
- # in some dynamic inheritance situations the compiler needs to add
- # write tests around output statements.
- self.require_output_check = parent and parent.require_output_check
-
- # inside some tags we are using a buffer rather than yield statements.
- # this for example affects {% filter %} or {% macro %}. If a frame
- # is buffered this variable points to the name of the list used as
- # buffer.
- self.buffer = None
-
- # the name of the block we're in, otherwise None.
- self.block = parent.block if parent else None
-
- # the parent of this frame
- self.parent = parent
-
- if parent is not None:
- self.buffer = parent.buffer
-
# variables set inside of loops and blocks should not affect outer frames,
# but they still needs to be kept track of as part of the active context.
self.loop_frame = False
@@ -143,20 +215,20 @@ class Frame:
# or compile time.
self.soft_frame = False
- def copy(self):
+ def copy(self) -> "Frame":
"""Create a copy of the current one."""
- rv = object.__new__(self.__class__)
+ rv = t.cast(Frame, object.__new__(self.__class__))
rv.__dict__.update(self.__dict__)
rv.symbols = self.symbols.copy()
return rv
- def inner(self, isolated=False):
+ def inner(self, isolated: bool = False) -> "Frame":
"""Return an inner frame."""
if isolated:
return Frame(self.eval_ctx, level=self.symbols.level + 1)
return Frame(self.eval_ctx, self)
- def soft(self):
+ def soft(self) -> "Frame":
"""Return a soft frame. A soft frame may not be modified as
standalone thing as it shares the resources with the frame it
was created of, but it's not a rootlevel frame any longer.
@@ -179,19 +251,19 @@ class VisitorExit(RuntimeError):
class DependencyFinderVisitor(NodeVisitor):
"""A visitor that collects filter and test calls."""
- def __init__(self):
- self.filters = set()
- self.tests = set()
+ def __init__(self) -> None:
+ self.filters: t.Set[str] = set()
+ self.tests: t.Set[str] = set()
- def visit_Filter(self, node):
+ def visit_Filter(self, node: nodes.Filter) -> None:
self.generic_visit(node)
self.filters.add(node.name)
- def visit_Test(self, node):
+ def visit_Test(self, node: nodes.Test) -> None:
self.generic_visit(node)
self.tests.add(node.name)
- def visit_Block(self, node):
+ def visit_Block(self, node: nodes.Block) -> None:
"""Stop visiting at blocks."""
@@ -201,11 +273,11 @@ class UndeclaredNameVisitor(NodeVisitor):
not stop at closure frames.
"""
- def __init__(self, names):
+ def __init__(self, names: t.Iterable[str]) -> None:
self.names = set(names)
- self.undeclared = set()
+ self.undeclared: t.Set[str] = set()
- def visit_Name(self, node):
+ def visit_Name(self, node: nodes.Name) -> None:
if node.ctx == "load" and node.name in self.names:
self.undeclared.add(node.name)
if self.undeclared == self.names:
@@ -213,7 +285,7 @@ class UndeclaredNameVisitor(NodeVisitor):
else:
self.names.discard(node.name)
- def visit_Block(self, node):
+ def visit_Block(self, node: nodes.Block) -> None:
"""Stop visiting a blocks."""
@@ -226,8 +298,14 @@ class CompilerExit(Exception):
class CodeGenerator(NodeVisitor):
def __init__(
- self, environment, name, filename, stream=None, defer_init=False, optimized=True
- ):
+ self,
+ environment: "Environment",
+ name: t.Optional[str],
+ filename: t.Optional[str],
+ stream: t.Optional[t.TextIO] = None,
+ defer_init: bool = False,
+ optimized: bool = True,
+ ) -> None:
if stream is None:
stream = StringIO()
self.environment = environment
@@ -236,16 +314,17 @@ class CodeGenerator(NodeVisitor):
self.stream = stream
self.created_block_context = False
self.defer_init = defer_init
- self.optimized = optimized
+ self.optimizer: t.Optional[Optimizer] = None
+
if optimized:
self.optimizer = Optimizer(environment)
# aliases for imports
- self.import_aliases = {}
+ self.import_aliases: t.Dict[str, str] = {}
# a registry for all blocks. Because blocks are moved out
# into the global python scope they are registered here
- self.blocks = {}
+ self.blocks: t.Dict[str, nodes.Block] = {}
# the number of extends statements so far
self.extends_so_far = 0
@@ -259,12 +338,12 @@ class CodeGenerator(NodeVisitor):
self.code_lineno = 1
# registry of all filters and tests (global, not block local)
- self.tests = {}
- self.filters = {}
+ self.tests: t.Dict[str, str] = {}
+ self.filters: t.Dict[str, str] = {}
# the debug information
- self.debug_info = []
- self._write_debug_info = None
+ self.debug_info: t.List[t.Tuple[int, int]] = []
+ self._write_debug_info: t.Optional[int] = None
# the number of new lines before the next write()
self._new_lines = 0
@@ -283,31 +362,37 @@ class CodeGenerator(NodeVisitor):
self._indentation = 0
# Tracks toplevel assignments
- self._assign_stack = []
+ self._assign_stack: t.List[t.Set[str]] = []
# Tracks parameter definition blocks
- self._param_def_block = []
+ self._param_def_block: t.List[t.Set[str]] = []
# Tracks the current context.
self._context_reference_stack = ["context"]
+ @property
+ def optimized(self) -> bool:
+ return self.optimizer is not None
+
# -- Various compilation helpers
- def fail(self, msg, lineno):
+ def fail(self, msg: str, lineno: int) -> t.NoReturn:
"""Fail with a :exc:`TemplateAssertionError`."""
raise TemplateAssertionError(msg, lineno, self.name, self.filename)
- def temporary_identifier(self):
+ def temporary_identifier(self) -> str:
"""Get a new unique identifier."""
self._last_identifier += 1
return f"t_{self._last_identifier}"
- def buffer(self, frame):
+ def buffer(self, frame: Frame) -> None:
"""Enable buffering for the frame from that point onwards."""
frame.buffer = self.temporary_identifier()
self.writeline(f"{frame.buffer} = []")
- def return_buffer_contents(self, frame, force_unescaped=False):
+ def return_buffer_contents(
+ self, frame: Frame, force_unescaped: bool = False
+ ) -> None:
"""Return the buffer contents of the frame."""
if not force_unescaped:
if frame.eval_ctx.volatile:
@@ -325,33 +410,35 @@ class CodeGenerator(NodeVisitor):
return
self.writeline(f"return concat({frame.buffer})")
- def indent(self):
+ def indent(self) -> None:
"""Indent by one."""
self._indentation += 1
- def outdent(self, step=1):
+ def outdent(self, step: int = 1) -> None:
"""Outdent by step."""
self._indentation -= step
- def start_write(self, frame, node=None):
+ def start_write(self, frame: Frame, node: t.Optional[nodes.Node] = None) -> None:
"""Yield or write into the frame buffer."""
if frame.buffer is None:
self.writeline("yield ", node)
else:
self.writeline(f"{frame.buffer}.append(", node)
- def end_write(self, frame):
+ def end_write(self, frame: Frame) -> None:
"""End the writing process started by `start_write`."""
if frame.buffer is not None:
self.write(")")
- def simple_write(self, s, frame, node=None):
+ def simple_write(
+ self, s: str, frame: Frame, node: t.Optional[nodes.Node] = None
+ ) -> None:
"""Simple shortcut for start_write + write + end_write."""
self.start_write(frame, node)
self.write(s)
self.end_write(frame)
- def blockvisit(self, nodes, frame):
+ def blockvisit(self, nodes: t.Iterable[nodes.Node], frame: Frame) -> None:
"""Visit a list of nodes as block in a frame. If the current frame
is no buffer a dummy ``if 0: yield None`` is written automatically.
"""
@@ -362,7 +449,7 @@ class CodeGenerator(NodeVisitor):
except CompilerExit:
pass
- def write(self, x):
+ def write(self, x: str) -> None:
"""Write a string into the output stream."""
if self._new_lines:
if not self._first_write:
@@ -376,19 +463,26 @@ class CodeGenerator(NodeVisitor):
self._new_lines = 0
self.stream.write(x)
- def writeline(self, x, node=None, extra=0):
+ def writeline(
+ self, x: str, node: t.Optional[nodes.Node] = None, extra: int = 0
+ ) -> None:
"""Combination of newline and write."""
self.newline(node, extra)
self.write(x)
- def newline(self, node=None, extra=0):
+ def newline(self, node: t.Optional[nodes.Node] = None, extra: int = 0) -> None:
"""Add one or more newlines before the next write."""
self._new_lines = max(self._new_lines, 1 + extra)
if node is not None and node.lineno != self._last_line:
self._write_debug_info = node.lineno
self._last_line = node.lineno
- def signature(self, node, frame, extra_kwargs=None):
+ def signature(
+ self,
+ node: t.Union[nodes.Call, nodes.Filter, nodes.Test],
+ frame: Frame,
+ extra_kwargs: t.Optional[t.Mapping[str, t.Any]] = None,
+ ) -> None:
"""Writes a function call to the stream for the current node.
A leading comma is added automatically. The extra keyword
arguments may not include python keywords otherwise a syntax
@@ -397,11 +491,10 @@ class CodeGenerator(NodeVisitor):
"""
# if any of the given keyword arguments is a python keyword
# we have to make sure that no invalid call is created.
- kwarg_workaround = False
- for kwarg in chain((x.key for x in node.kwargs), extra_kwargs or ()):
- if is_python_keyword(kwarg):
- kwarg_workaround = True
- break
+ kwarg_workaround = any(
+ is_python_keyword(t.cast(str, k))
+ for k in chain((x.key for x in node.kwargs), extra_kwargs or ())
+ )
for arg in node.args:
self.write(", ")
@@ -441,7 +534,7 @@ class CodeGenerator(NodeVisitor):
self.write(", **")
self.visit(node.dyn_kwargs, frame)
- def pull_dependencies(self, nodes):
+ def pull_dependencies(self, nodes: t.Iterable[nodes.Node]) -> None:
"""Find all filter and test names used in the template and
assign them to variables in the compiled namespace. Checking
that the names are registered with the environment is done when
@@ -457,23 +550,25 @@ class CodeGenerator(NodeVisitor):
for node in nodes:
visitor.visit(node)
- for dependency in "filters", "tests":
- mapping = getattr(self, dependency)
-
- for name in getattr(visitor, dependency):
- if name not in mapping:
- mapping[name] = self.temporary_identifier()
+ for id_map, names, dependency in (self.filters, visitor.filters, "filters"), (
+ self.tests,
+ visitor.tests,
+ "tests",
+ ):
+ for name in names:
+ if name not in id_map:
+ id_map[name] = self.temporary_identifier()
# add check during runtime that dependencies used inside of executed
# blocks are defined, as this step may be skipped during compile time
self.writeline("try:")
self.indent()
- self.writeline(f"{mapping[name]} = environment.{dependency}[{name!r}]")
+ self.writeline(f"{id_map[name]} = environment.{dependency}[{name!r}]")
self.outdent()
self.writeline("except KeyError:")
self.indent()
self.writeline("@internalcode")
- self.writeline(f"def {mapping[name]}(*unused):")
+ self.writeline(f"def {id_map[name]}(*unused):")
self.indent()
self.writeline(
f'raise TemplateRuntimeError("No {dependency[:-1]}'
@@ -482,7 +577,7 @@ class CodeGenerator(NodeVisitor):
self.outdent()
self.outdent()
- def enter_frame(self, frame):
+ def enter_frame(self, frame: Frame) -> None:
undefs = []
for target, (action, param) in frame.symbols.loads.items():
if action == VAR_LOAD_PARAMETER:
@@ -498,7 +593,7 @@ class CodeGenerator(NodeVisitor):
if undefs:
self.writeline(f"{' = '.join(undefs)} = missing")
- def leave_frame(self, frame, with_python_scope=False):
+ def leave_frame(self, frame: Frame, with_python_scope: bool = False) -> None:
if not with_python_scope:
undefs = []
for target in frame.symbols.loads:
@@ -506,13 +601,15 @@ class CodeGenerator(NodeVisitor):
if undefs:
self.writeline(f"{' = '.join(undefs)} = missing")
- def choose_async(self, async_value="async ", sync_value=""):
+ def choose_async(self, async_value: str = "async ", sync_value: str = "") -> str:
return async_value if self.environment.is_async else sync_value
- def func(self, name):
+ def func(self, name: str) -> str:
return f"{self.choose_async()}def {name}"
- def macro_body(self, node, frame):
+ def macro_body(
+ self, node: t.Union[nodes.Macro, nodes.CallBlock], frame: Frame
+ ) -> t.Tuple[Frame, MacroRef]:
"""Dump the function def of a macro or call block."""
frame = frame.inner()
frame.symbols.analyze_node(node)
@@ -521,6 +618,7 @@ class CodeGenerator(NodeVisitor):
explicit_caller = None
skip_special_params = set()
args = []
+
for idx, arg in enumerate(node.args):
if arg.name == "caller":
explicit_caller = idx
@@ -592,7 +690,7 @@ class CodeGenerator(NodeVisitor):
return frame, macro_ref
- def macro_def(self, macro_ref, frame):
+ def macro_def(self, macro_ref: MacroRef, frame: Frame) -> None:
"""Dump the macro definition for the def created by macro_body."""
arg_tuple = ", ".join(repr(x.name) for x in macro_ref.node.args)
name = getattr(macro_ref.node, "name", None)
@@ -604,21 +702,21 @@ class CodeGenerator(NodeVisitor):
f" {macro_ref.accesses_caller!r}, context.eval_ctx.autoescape)"
)
- def position(self, node):
+ def position(self, node: nodes.Node) -> str:
"""Return a human readable position for the node."""
rv = f"line {node.lineno}"
if self.name is not None:
rv = f"{rv} in {self.name!r}"
return rv
- def dump_local_context(self, frame):
+ def dump_local_context(self, frame: Frame) -> str:
items_kv = ", ".join(
f"{name!r}: {target}"
for name, target in frame.symbols.dump_stores().items()
)
return f"{{{items_kv}}}"
- def write_commons(self):
+ def write_commons(self) -> None:
"""Writes a common preamble that is used by root and block functions.
Primarily this sets up common local helpers and enforces a generator
through a dead branch.
@@ -630,7 +728,7 @@ class CodeGenerator(NodeVisitor):
self.writeline("cond_expr_undefined = Undefined")
self.writeline("if 0: yield None")
- def push_parameter_definitions(self, frame):
+ def push_parameter_definitions(self, frame: Frame) -> None:
"""Pushes all parameter targets from the given frame into a local
stack that permits tracking of yet to be assigned parameters. In
particular this enables the optimization from `visit_Name` to skip
@@ -639,46 +737,46 @@ class CodeGenerator(NodeVisitor):
"""
self._param_def_block.append(frame.symbols.dump_param_targets())
- def pop_parameter_definitions(self):
+ def pop_parameter_definitions(self) -> None:
"""Pops the current parameter definitions set."""
self._param_def_block.pop()
- def mark_parameter_stored(self, target):
+ def mark_parameter_stored(self, target: str) -> None:
"""Marks a parameter in the current parameter definitions as stored.
This will skip the enforced undefined checks.
"""
if self._param_def_block:
self._param_def_block[-1].discard(target)
- def push_context_reference(self, target):
+ def push_context_reference(self, target: str) -> None:
self._context_reference_stack.append(target)
- def pop_context_reference(self):
+ def pop_context_reference(self) -> None:
self._context_reference_stack.pop()
- def get_context_ref(self):
+ def get_context_ref(self) -> str:
return self._context_reference_stack[-1]
- def get_resolve_func(self):
+ def get_resolve_func(self) -> str:
target = self._context_reference_stack[-1]
if target == "context":
return "resolve"
return f"{target}.resolve"
- def derive_context(self, frame):
+ def derive_context(self, frame: Frame) -> str:
return f"{self.get_context_ref()}.derived({self.dump_local_context(frame)})"
- def parameter_is_undeclared(self, target):
+ def parameter_is_undeclared(self, target: str) -> bool:
"""Checks if a given target is an undeclared parameter."""
if not self._param_def_block:
return False
return target in self._param_def_block[-1]
- def push_assign_tracking(self):
+ def push_assign_tracking(self) -> None:
"""Pushes a new layer for assignment tracking."""
self._assign_stack.append(set())
- def pop_assign_tracking(self, frame):
+ def pop_assign_tracking(self, frame: Frame) -> None:
"""Pops the topmost level for assignment tracking and updates the
context variables if necessary.
"""
@@ -723,7 +821,9 @@ class CodeGenerator(NodeVisitor):
# -- Statement Visitors
- def visit_Template(self, node, frame=None):
+ def visit_Template(
+ self, node: nodes.Template, frame: t.Optional[Frame] = None
+ ) -> None:
assert frame is None, "no root frame allowed"
eval_ctx = EvalContext(self.environment, self.name)
@@ -840,7 +940,7 @@ class CodeGenerator(NodeVisitor):
debug_kv_str = "&".join(f"{k}={v}" for k, v in self.debug_info)
self.writeline(f"debug_info = {debug_kv_str!r}")
- def visit_Block(self, node, frame):
+ def visit_Block(self, node: nodes.Block, frame: Frame) -> None:
"""Call a block and register it for the template."""
level = 0
if frame.toplevel:
@@ -883,7 +983,7 @@ class CodeGenerator(NodeVisitor):
self.outdent(level)
- def visit_Extends(self, node, frame):
+ def visit_Extends(self, node: nodes.Extends, frame: Frame) -> None:
"""Calls the extender."""
if not frame.toplevel:
self.fail("cannot use extend from a non top-level scope", node.lineno)
@@ -926,7 +1026,7 @@ class CodeGenerator(NodeVisitor):
# and now we have one more
self.extends_so_far += 1
- def visit_Include(self, node, frame):
+ def visit_Include(self, node: nodes.Include, frame: Frame) -> None:
"""Handles includes."""
if node.ignore_missing:
self.writeline("try:")
@@ -977,7 +1077,9 @@ class CodeGenerator(NodeVisitor):
if node.ignore_missing:
self.outdent()
- def _import_common(self, node, frame):
+ def _import_common(
+ self, node: t.Union[nodes.Import, nodes.FromImport], frame: Frame
+ ) -> None:
self.write(f"{self.choose_async('await ')}environment.get_template(")
self.visit(node.template, frame)
self.write(f", {self.name!r}).")
@@ -992,7 +1094,7 @@ class CodeGenerator(NodeVisitor):
else:
self.write("_get_default_module(context)")
- def visit_Import(self, node, frame):
+ def visit_Import(self, node: nodes.Import, frame: Frame) -> None:
"""Visit regular imports."""
self.writeline(f"{frame.symbols.ref(node.target)} = ", node)
if frame.toplevel:
@@ -1003,7 +1105,7 @@ class CodeGenerator(NodeVisitor):
if frame.toplevel and not node.target.startswith("_"):
self.writeline(f"context.exported_vars.discard({node.target!r})")
- def visit_FromImport(self, node, frame):
+ def visit_FromImport(self, node: nodes.FromImport, frame: Frame) -> None:
"""Visit named imports."""
self.newline(node)
self.write("included_template = ")
@@ -1053,7 +1155,7 @@ class CodeGenerator(NodeVisitor):
f"context.exported_vars.difference_update(({names_str}))"
)
- def visit_For(self, node, frame):
+ def visit_For(self, node: nodes.For, frame: Frame) -> None:
loop_frame = frame.inner()
loop_frame.loop_frame = True
test_frame = frame.inner()
@@ -1187,7 +1289,7 @@ class CodeGenerator(NodeVisitor):
self.write(", loop)")
self.end_write(frame)
- def visit_If(self, node, frame):
+ def visit_If(self, node: nodes.If, frame: Frame) -> None:
if_frame = frame.soft()
self.writeline("if ", node)
self.visit(node.test, if_frame)
@@ -1208,7 +1310,7 @@ class CodeGenerator(NodeVisitor):
self.blockvisit(node.else_, if_frame)
self.outdent()
- def visit_Macro(self, node, frame):
+ def visit_Macro(self, node: nodes.Macro, frame: Frame) -> None:
macro_frame, macro_ref = self.macro_body(node, frame)
self.newline()
if frame.toplevel:
@@ -1218,7 +1320,7 @@ class CodeGenerator(NodeVisitor):
self.write(f"{frame.symbols.ref(node.name)} = ")
self.macro_def(macro_ref, macro_frame)
- def visit_CallBlock(self, node, frame):
+ def visit_CallBlock(self, node: nodes.CallBlock, frame: Frame) -> None:
call_frame, macro_ref = self.macro_body(node, frame)
self.writeline("caller = ")
self.macro_def(macro_ref, call_frame)
@@ -1226,7 +1328,7 @@ class CodeGenerator(NodeVisitor):
self.visit_Call(node.call, frame, forward_caller=True)
self.end_write(frame)
- def visit_FilterBlock(self, node, frame):
+ def visit_FilterBlock(self, node: nodes.FilterBlock, frame: Frame) -> None:
filter_frame = frame.inner()
filter_frame.symbols.analyze_node(node)
self.enter_frame(filter_frame)
@@ -1237,7 +1339,7 @@ class CodeGenerator(NodeVisitor):
self.end_write(frame)
self.leave_frame(filter_frame)
- def visit_With(self, node, frame):
+ def visit_With(self, node: nodes.With, frame: Frame) -> None:
with_frame = frame.inner()
with_frame.symbols.analyze_node(node)
self.enter_frame(with_frame)
@@ -1249,18 +1351,25 @@ class CodeGenerator(NodeVisitor):
self.blockvisit(node.body, with_frame)
self.leave_frame(with_frame)
- def visit_ExprStmt(self, node, frame):
+ def visit_ExprStmt(self, node: nodes.ExprStmt, frame: Frame) -> None:
self.newline(node)
self.visit(node.node, frame)
- _FinalizeInfo = namedtuple("_FinalizeInfo", ("const", "src"))
- #: The default finalize function if the environment isn't configured
- #: with one. Or if the environment has one, this is called on that
- #: function's output for constants.
- _default_finalize = str
+ class _FinalizeInfo(t.NamedTuple):
+ const: t.Optional[t.Callable[..., str]]
+ src: t.Optional[str]
+
+ @staticmethod
+ def _default_finalize(value: t.Any) -> t.Any:
+ """The default finalize function if the environment isn't
+ configured with one. Or, if the environment has one, this is
+ called on that function's output for constants.
+ """
+ return str(value)
+
_finalize: t.Optional[_FinalizeInfo] = None
- def _make_finalize(self):
+ def _make_finalize(self) -> _FinalizeInfo:
"""Build the finalize function to be used on constants and at
runtime. Cached so it's only created once for all output nodes.
@@ -1276,6 +1385,7 @@ class CodeGenerator(NodeVisitor):
if self._finalize is not None:
return self._finalize
+ finalize: t.Optional[t.Callable[..., t.Any]]
finalize = default = self._default_finalize
src = None
@@ -1286,12 +1396,14 @@ class CodeGenerator(NodeVisitor):
_PassArg.context: "context",
_PassArg.eval_context: "context.eval_ctx",
_PassArg.environment: "environment",
- }.get(_PassArg.from_obj(env_finalize))
+ }.get(
+ _PassArg.from_obj(env_finalize) # type: ignore
+ )
finalize = None
if pass_arg is None:
- def finalize(value):
+ def finalize(value: t.Any) -> t.Any:
return default(env_finalize(value))
else:
@@ -1299,20 +1411,22 @@ class CodeGenerator(NodeVisitor):
if pass_arg == "environment":
- def finalize(value):
+ def finalize(value: t.Any) -> t.Any:
return default(env_finalize(self.environment, value))
self._finalize = self._FinalizeInfo(finalize, src)
return self._finalize
- def _output_const_repr(self, group):
+ def _output_const_repr(self, group: t.Iterable[t.Any]) -> str:
"""Given a group of constant values converted from ``Output``
child nodes, produce a string to write to the template module
source.
"""
return repr(concat(group))
- def _output_child_to_const(self, node, frame, finalize):
+ def _output_child_to_const(
+ self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo
+ ) -> str:
"""Try to optimize a child of an ``Output`` node by trying to
convert it to constant, finalized data at compile time.
@@ -1329,9 +1443,11 @@ class CodeGenerator(NodeVisitor):
if isinstance(node, nodes.TemplateData):
return str(const)
- return finalize.const(const)
+ return finalize.const(const) # type: ignore
- def _output_child_pre(self, node, frame, finalize):
+ def _output_child_pre(
+ self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo
+ ) -> None:
"""Output extra source code before visiting a child of an
``Output`` node.
"""
@@ -1345,7 +1461,9 @@ class CodeGenerator(NodeVisitor):
if finalize.src is not None:
self.write(finalize.src)
- def _output_child_post(self, node, frame, finalize):
+ def _output_child_post(
+ self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo
+ ) -> None:
"""Output extra source code after visiting a child of an
``Output`` node.
"""
@@ -1354,7 +1472,7 @@ class CodeGenerator(NodeVisitor):
if finalize.src is not None:
self.write(")")
- def visit_Output(self, node, frame):
+ def visit_Output(self, node: nodes.Output, frame: Frame) -> None:
# If an extends is active, don't render outside a block.
if frame.require_output_check:
# A top-level extends is known to exist at compile time.
@@ -1365,7 +1483,7 @@ class CodeGenerator(NodeVisitor):
self.indent()
finalize = self._make_finalize()
- body = []
+ body: t.List[t.Union[t.List[t.Any], nodes.Expr]] = []
# Evaluate constants at compile time if possible. Each item in
# body will be either a list of static data or a node to be
@@ -1433,7 +1551,7 @@ class CodeGenerator(NodeVisitor):
if frame.require_output_check:
self.outdent()
- def visit_Assign(self, node, frame):
+ def visit_Assign(self, node: nodes.Assign, frame: Frame) -> None:
self.push_assign_tracking()
self.newline(node)
self.visit(node.target, frame)
@@ -1441,7 +1559,7 @@ class CodeGenerator(NodeVisitor):
self.visit(node.node, frame)
self.pop_assign_tracking(frame)
- def visit_AssignBlock(self, node, frame):
+ def visit_AssignBlock(self, node: nodes.AssignBlock, frame: Frame) -> None:
self.push_assign_tracking()
block_frame = frame.inner()
# This is a special case. Since a set block always captures we
@@ -1465,7 +1583,7 @@ class CodeGenerator(NodeVisitor):
# -- Expression Visitors
- def visit_Name(self, node, frame):
+ def visit_Name(self, node: nodes.Name, frame: Frame) -> None:
if node.ctx == "store" and (
frame.toplevel or frame.loop_frame or frame.block_frame
):
@@ -1490,7 +1608,7 @@ class CodeGenerator(NodeVisitor):
self.write(ref)
- def visit_NSRef(self, node, frame):
+ def visit_NSRef(self, node: nodes.NSRef, frame: Frame) -> None:
# NSRefs can only be used to store values; since they use the normal
# `foo.bar` notation they will be parsed as a normal attribute access
# when used anywhere but in a `set` context
@@ -1504,14 +1622,14 @@ class CodeGenerator(NodeVisitor):
self.outdent()
self.writeline(f"{ref}[{node.attr!r}]")
- def visit_Const(self, node, frame):
+ def visit_Const(self, node: nodes.Const, frame: Frame) -> None:
val = node.as_const(frame.eval_ctx)
if isinstance(val, float):
self.write(str(val))
else:
self.write(repr(val))
- def visit_TemplateData(self, node, frame):
+ def visit_TemplateData(self, node: nodes.TemplateData, frame: Frame) -> None:
try:
self.write(repr(node.as_const(frame.eval_ctx)))
except nodes.Impossible:
@@ -1519,7 +1637,7 @@ class CodeGenerator(NodeVisitor):
f"(Markup if context.eval_ctx.autoescape else identity)({node.data!r})"
)
- def visit_Tuple(self, node, frame):
+ def visit_Tuple(self, node: nodes.Tuple, frame: Frame) -> None:
self.write("(")
idx = -1
for idx, item in enumerate(node.items):
@@ -1528,7 +1646,7 @@ class CodeGenerator(NodeVisitor):
self.visit(item, frame)
self.write(",)" if idx == 0 else ")")
- def visit_List(self, node, frame):
+ def visit_List(self, node: nodes.List, frame: Frame) -> None:
self.write("[")
for idx, item in enumerate(node.items):
if idx:
@@ -1536,7 +1654,7 @@ class CodeGenerator(NodeVisitor):
self.visit(item, frame)
self.write("]")
- def visit_Dict(self, node, frame):
+ def visit_Dict(self, node: nodes.Dict, frame: Frame) -> None:
self.write("{")
for idx, item in enumerate(node.items):
if idx:
@@ -1546,58 +1664,21 @@ class CodeGenerator(NodeVisitor):
self.visit(item.value, frame)
self.write("}")
- def binop(operator, interceptable=True): # noqa: B902
- @optimizeconst
- def visitor(self, node, frame):
- if (
- self.environment.sandboxed
- and operator in self.environment.intercepted_binops
- ):
- self.write(f"environment.call_binop(context, {operator!r}, ")
- self.visit(node.left, frame)
- self.write(", ")
- self.visit(node.right, frame)
- else:
- self.write("(")
- self.visit(node.left, frame)
- self.write(f" {operator} ")
- self.visit(node.right, frame)
- self.write(")")
-
- return visitor
-
- def uaop(operator, interceptable=True): # noqa: B902
- @optimizeconst
- def visitor(self, node, frame):
- if (
- self.environment.sandboxed
- and operator in self.environment.intercepted_unops
- ):
- self.write(f"environment.call_unop(context, {operator!r}, ")
- self.visit(node.node, frame)
- else:
- self.write("(" + operator)
- self.visit(node.node, frame)
- self.write(")")
-
- return visitor
-
- visit_Add = binop("+") # type:ignore
- visit_Sub = binop("-") # type:ignore
- visit_Mul = binop("*") # type:ignore
- visit_Div = binop("/") # type:ignore
- visit_FloorDiv = binop("//") # type:ignore
- visit_Pow = binop("**") # type:ignore
- visit_Mod = binop("%") # type:ignore
- visit_And = binop("and", interceptable=False) # type:ignore
- visit_Or = binop("or", interceptable=False) # type:ignore
- visit_Pos = uaop("+") # type:ignore
- visit_Neg = uaop("-") # type:ignore
- visit_Not = uaop("not ", interceptable=False) # type:ignore
- del binop, uaop
+ visit_Add = _make_binop("+")
+ visit_Sub = _make_binop("-")
+ visit_Mul = _make_binop("*")
+ visit_Div = _make_binop("/")
+ visit_FloorDiv = _make_binop("//")
+ visit_Pow = _make_binop("**")
+ visit_Mod = _make_binop("%")
+ visit_And = _make_binop("and")
+ visit_Or = _make_binop("or")
+ visit_Pos = _make_unop("+")
+ visit_Neg = _make_unop("-")
+ visit_Not = _make_unop("not ")
@optimizeconst
- def visit_Concat(self, node, frame):
+ def visit_Concat(self, node: nodes.Concat, frame: Frame) -> None:
if frame.eval_ctx.volatile:
func_name = "(markup_join if context.eval_ctx.volatile else str_join)"
elif frame.eval_ctx.autoescape:
@@ -1611,19 +1692,19 @@ class CodeGenerator(NodeVisitor):
self.write("))")
@optimizeconst
- def visit_Compare(self, node, frame):
+ def visit_Compare(self, node: nodes.Compare, frame: Frame) -> None:
self.write("(")
self.visit(node.expr, frame)
for op in node.ops:
self.visit(op, frame)
self.write(")")
- def visit_Operand(self, node, frame):
+ def visit_Operand(self, node: nodes.Operand, frame: Frame) -> None:
self.write(f" {operators[node.op]} ")
self.visit(node.expr, frame)
@optimizeconst
- def visit_Getattr(self, node, frame):
+ def visit_Getattr(self, node: nodes.Getattr, frame: Frame) -> None:
if self.environment.is_async:
self.write("(await auto_await(")
@@ -1635,7 +1716,7 @@ class CodeGenerator(NodeVisitor):
self.write("))")
@optimizeconst
- def visit_Getitem(self, node, frame):
+ def visit_Getitem(self, node: nodes.Getitem, frame: Frame) -> None:
# slices bypass the environment getitem method.
if isinstance(node.arg, nodes.Slice):
self.visit(node.node, frame)
@@ -1655,7 +1736,7 @@ class CodeGenerator(NodeVisitor):
if self.environment.is_async:
self.write("))")
- def visit_Slice(self, node, frame):
+ def visit_Slice(self, node: nodes.Slice, frame: Frame) -> None:
if node.start is not None:
self.visit(node.start, frame)
self.write(":")
@@ -1666,33 +1747,33 @@ class CodeGenerator(NodeVisitor):
self.visit(node.step, frame)
@contextmanager
- def _filter_test_common(self, node, frame, is_filter):
- if is_filter:
- compiler_map = self.filters
- env_map = self.environment.filters
- type_name = "filter"
- else:
- compiler_map = self.tests
- env_map = self.environment.tests
- type_name = "test"
-
+ def _filter_test_common(
+ self, node: t.Union[nodes.Filter, nodes.Test], frame: Frame, is_filter: bool
+ ) -> t.Iterator[None]:
if self.environment.is_async:
self.write("await auto_await(")
- self.write(compiler_map[node.name] + "(")
- func = env_map.get(node.name)
+ if is_filter:
+ self.write(f"{self.filters[node.name]}(")
+ func = self.environment.filters.get(node.name)
+ else:
+ self.write(f"{self.tests[node.name]}(")
+ func = self.environment.tests.get(node.name)
# When inside an If or CondExpr frame, allow the filter to be
# undefined at compile time and only raise an error if it's
# actually called at runtime. See pull_dependencies.
if func is None and not frame.soft_frame:
+ type_name = "filter" if is_filter else "test"
self.fail(f"No {type_name} named {node.name!r}.", node.lineno)
pass_arg = {
_PassArg.context: "context",
_PassArg.eval_context: "context.eval_ctx",
_PassArg.environment: "environment",
- }.get(_PassArg.from_obj(func))
+ }.get(
+ _PassArg.from_obj(func) # type: ignore
+ )
if pass_arg is not None:
self.write(f"{pass_arg}, ")
@@ -1708,7 +1789,7 @@ class CodeGenerator(NodeVisitor):
self.write(")")
@optimizeconst
- def visit_Filter(self, node, frame):
+ def visit_Filter(self, node: nodes.Filter, frame: Frame) -> None:
with self._filter_test_common(node, frame, True):
# if the filter node is None we are inside a filter block
# and want to write to the current buffer
@@ -1725,17 +1806,19 @@ class CodeGenerator(NodeVisitor):
self.write(f"concat({frame.buffer})")
@optimizeconst
- def visit_Test(self, node, frame):
+ def visit_Test(self, node: nodes.Test, frame: Frame) -> None:
with self._filter_test_common(node, frame, False):
self.visit(node.node, frame)
@optimizeconst
- def visit_CondExpr(self, node, frame):
+ def visit_CondExpr(self, node: nodes.CondExpr, frame: Frame) -> None:
frame = frame.soft()
- def write_expr2():
+ def write_expr2() -> None:
if node.expr2 is not None:
- return self.visit(node.expr2, frame)
+ self.visit(node.expr2, frame)
+ return
+
self.write(
f'cond_expr_undefined("the inline if-expression on'
f" {self.position(node)} evaluated to false and no else"
@@ -1751,7 +1834,9 @@ class CodeGenerator(NodeVisitor):
self.write(")")
@optimizeconst
- def visit_Call(self, node, frame, forward_caller=False):
+ def visit_Call(
+ self, node: nodes.Call, frame: Frame, forward_caller: bool = False
+ ) -> None:
if self.environment.is_async:
self.write("await auto_await(")
if self.environment.sandboxed:
@@ -1771,54 +1856,64 @@ class CodeGenerator(NodeVisitor):
if self.environment.is_async:
self.write(")")
- def visit_Keyword(self, node, frame):
+ def visit_Keyword(self, node: nodes.Keyword, frame: Frame) -> None:
self.write(node.key + "=")
self.visit(node.value, frame)
# -- Unused nodes for extensions
- def visit_MarkSafe(self, node, frame):
+ def visit_MarkSafe(self, node: nodes.MarkSafe, frame: Frame) -> None:
self.write("Markup(")
self.visit(node.expr, frame)
self.write(")")
- def visit_MarkSafeIfAutoescape(self, node, frame):
+ def visit_MarkSafeIfAutoescape(
+ self, node: nodes.MarkSafeIfAutoescape, frame: Frame
+ ) -> None:
self.write("(Markup if context.eval_ctx.autoescape else identity)(")
self.visit(node.expr, frame)
self.write(")")
- def visit_EnvironmentAttribute(self, node, frame):
+ def visit_EnvironmentAttribute(
+ self, node: nodes.EnvironmentAttribute, frame: Frame
+ ) -> None:
self.write("environment." + node.name)
- def visit_ExtensionAttribute(self, node, frame):
+ def visit_ExtensionAttribute(
+ self, node: nodes.ExtensionAttribute, frame: Frame
+ ) -> None:
self.write(f"environment.extensions[{node.identifier!r}].{node.name}")
- def visit_ImportedName(self, node, frame):
+ def visit_ImportedName(self, node: nodes.ImportedName, frame: Frame) -> None:
self.write(self.import_aliases[node.importname])
- def visit_InternalName(self, node, frame):
+ def visit_InternalName(self, node: nodes.InternalName, frame: Frame) -> None:
self.write(node.name)
- def visit_ContextReference(self, node, frame):
+ def visit_ContextReference(
+ self, node: nodes.ContextReference, frame: Frame
+ ) -> None:
self.write("context")
- def visit_DerivedContextReference(self, node, frame):
+ def visit_DerivedContextReference(
+ self, node: nodes.DerivedContextReference, frame: Frame
+ ) -> None:
self.write(self.derive_context(frame))
- def visit_Continue(self, node, frame):
+ def visit_Continue(self, node: nodes.Continue, frame: Frame) -> None:
self.writeline("continue", node)
- def visit_Break(self, node, frame):
+ def visit_Break(self, node: nodes.Break, frame: Frame) -> None:
self.writeline("break", node)
- def visit_Scope(self, node, frame):
+ def visit_Scope(self, node: nodes.Scope, frame: Frame) -> None:
scope_frame = frame.inner()
scope_frame.symbols.analyze_node(node)
self.enter_frame(scope_frame)
self.blockvisit(node.body, scope_frame)
self.leave_frame(scope_frame)
- def visit_OverlayScope(self, node, frame):
+ def visit_OverlayScope(self, node: nodes.OverlayScope, frame: Frame) -> None:
ctx = self.temporary_identifier()
self.writeline(f"{ctx} = {self.derive_context(frame)}")
self.writeline(f"{ctx}.vars = ")
@@ -1832,7 +1927,9 @@ class CodeGenerator(NodeVisitor):
self.leave_frame(scope_frame)
self.pop_context_reference()
- def visit_EvalContextModifier(self, node, frame):
+ def visit_EvalContextModifier(
+ self, node: nodes.EvalContextModifier, frame: Frame
+ ) -> None:
for keyword in node.options:
self.writeline(f"context.eval_ctx.{keyword.key} = ")
self.visit(keyword.value, frame)
@@ -1843,7 +1940,9 @@ class CodeGenerator(NodeVisitor):
else:
setattr(frame.eval_ctx, keyword.key, val)
- def visit_ScopedEvalContextModifier(self, node, frame):
+ def visit_ScopedEvalContextModifier(
+ self, node: nodes.ScopedEvalContextModifier, frame: Frame
+ ) -> None:
old_ctx_name = self.temporary_identifier()
saved_ctx = frame.eval_ctx.save()
self.writeline(f"{old_ctx_name} = context.eval_ctx.save()")
diff --git a/src/jinja2/debug.py b/src/jinja2/debug.py
index 8b5cd65..02de4ee 100644
--- a/src/jinja2/debug.py
+++ b/src/jinja2/debug.py
@@ -1,13 +1,18 @@
import platform
import sys
+import typing as t
from types import CodeType
+from types import TracebackType
-from . import TemplateSyntaxError
+from .exceptions import TemplateSyntaxError
from .utils import internal_code
from .utils import missing
+if t.TYPE_CHECKING:
+ from .runtime import Context
-def rewrite_traceback_stack(source=None):
+
+def rewrite_traceback_stack(source: t.Optional[str] = None) -> BaseException:
"""Rewrite the current exception to replace any tracebacks from
within compiled template code with tracebacks that look like they
came from the template source.
@@ -19,6 +24,8 @@ def rewrite_traceback_stack(source=None):
:return: The original exception with the rewritten traceback.
"""
_, exc_value, tb = sys.exc_info()
+ exc_value = t.cast(BaseException, exc_value)
+ tb = t.cast(TracebackType, tb)
if isinstance(exc_value, TemplateSyntaxError) and not exc_value.translated:
exc_value.translated = True
@@ -66,7 +73,9 @@ def rewrite_traceback_stack(source=None):
return exc_value.with_traceback(tb_next)
-def fake_traceback(exc_value, tb, filename, lineno):
+def fake_traceback( # type: ignore
+ exc_value: BaseException, tb: t.Optional[TracebackType], filename: str, lineno: int
+) -> TracebackType:
"""Produce a new traceback object that looks like it came from the
template source instead of the compiled code. The filename, line
number, and location name will point to the template, and the local
@@ -139,7 +148,7 @@ def fake_traceback(exc_value, tb, filename, lineno):
try:
# Copy original value if it exists.
- code_args.append(getattr(code, "co_" + attr))
+ code_args.append(getattr(code, "co_" + t.cast(str, attr)))
except AttributeError:
# Some arguments were added later.
continue
@@ -155,18 +164,18 @@ def fake_traceback(exc_value, tb, filename, lineno):
try:
exec(code, globals, locals)
except BaseException:
- return sys.exc_info()[2].tb_next
+ return sys.exc_info()[2].tb_next # type: ignore
-def get_template_locals(real_locals):
+def get_template_locals(real_locals: t.Mapping[str, t.Any]) -> t.Dict[str, t.Any]:
"""Based on the runtime locals, get the context that would be
available at that point in the template.
"""
# Start with the current template context.
- ctx = real_locals.get("context")
+ ctx: "t.Optional[Context]" = real_locals.get("context")
- if ctx:
- data = ctx.get_all().copy()
+ if ctx is not None:
+ data: t.Dict[str, t.Any] = ctx.get_all().copy()
else:
data = {}
@@ -174,7 +183,7 @@ def get_template_locals(real_locals):
# rather than pushing a context. Local variables follow the scheme
# l_depth_name. Find the highest-depth local that has a value for
# each name.
- local_overrides = {}
+ local_overrides: t.Dict[str, t.Tuple[int, t.Any]] = {}
for name, value in real_locals.items():
if not name.startswith("l_") or value is missing:
@@ -182,8 +191,8 @@ def get_template_locals(real_locals):
continue
try:
- _, depth, name = name.split("_", 2)
- depth = int(depth)
+ _, depth_str, name = name.split("_", 2)
+ depth = int(depth_str)
except ValueError:
continue
@@ -204,7 +213,9 @@ def get_template_locals(real_locals):
if sys.version_info >= (3, 7):
# tb_next is directly assignable as of Python 3.7
- def tb_set_next(tb, tb_next):
+ def tb_set_next(
+ tb: TracebackType, tb_next: t.Optional[TracebackType]
+ ) -> TracebackType:
tb.tb_next = tb_next
return tb
@@ -215,20 +226,24 @@ elif platform.python_implementation() == "PyPy":
import tputil # type: ignore
except ImportError:
# Without tproxy support, use the original traceback.
- def tb_set_next(tb, tb_next):
+ def tb_set_next(
+ tb: TracebackType, tb_next: t.Optional[TracebackType]
+ ) -> TracebackType:
return tb
else:
# With tproxy support, create a proxy around the traceback that
# returns the new tb_next.
- def tb_set_next(tb, tb_next):
- def controller(op):
+ def tb_set_next(
+ tb: TracebackType, tb_next: t.Optional[TracebackType]
+ ) -> TracebackType:
+ def controller(op): # type: ignore
if op.opname == "__getattribute__" and op.args[0] == "tb_next":
return tb_next
return op.delegate()
- return tputil.make_proxy(controller, obj=tb)
+ return tputil.make_proxy(controller, obj=tb) # type: ignore
else:
@@ -244,7 +259,9 @@ else:
("tb_next", ctypes.py_object),
]
- def tb_set_next(tb, tb_next):
+ def tb_set_next(
+ tb: TracebackType, tb_next: t.Optional[TracebackType]
+ ) -> TracebackType:
c_tb = _CTraceback.from_address(id(tb))
# Clear out the old tb_next.
diff --git a/src/jinja2/defaults.py b/src/jinja2/defaults.py
index a841f61..638cad3 100644
--- a/src/jinja2/defaults.py
+++ b/src/jinja2/defaults.py
@@ -7,6 +7,9 @@ from .utils import generate_lorem_ipsum
from .utils import Joiner
from .utils import Namespace
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+
# defaults for the parser / lexer
BLOCK_START_STRING = "{%"
BLOCK_END_STRING = "%}"
@@ -18,7 +21,7 @@ LINE_STATEMENT_PREFIX: t.Optional[str] = None
LINE_COMMENT_PREFIX: t.Optional[str] = None
TRIM_BLOCKS = False
LSTRIP_BLOCKS = False
-NEWLINE_SEQUENCE = "\n"
+NEWLINE_SEQUENCE: "te.Literal['\\n', '\\r\\n', '\\r']" = "\n"
KEEP_TRAILING_NEWLINE = False
# default filters, tests and namespace
@@ -33,7 +36,7 @@ DEFAULT_NAMESPACE = {
}
# default policies
-DEFAULT_POLICIES = {
+DEFAULT_POLICIES: t.Dict[str, t.Any] = {
"compiler.ascii_str": True,
"urlize.rel": "noopener",
"urlize.target": None,
diff --git a/src/jinja2/environment.py b/src/jinja2/environment.py
index ae68738..d63653f 100644
--- a/src/jinja2/environment.py
+++ b/src/jinja2/environment.py
@@ -2,12 +2,14 @@
options.
"""
import os
-import sys
+import typing
import typing as t
import weakref
from collections import ChainMap
+from functools import lru_cache
from functools import partial
from functools import reduce
+from types import CodeType
from markupsafe import Markup
@@ -36,6 +38,7 @@ from .exceptions import TemplatesNotFound
from .exceptions import TemplateSyntaxError
from .exceptions import UndefinedError
from .lexer import get_lexer
+from .lexer import Lexer
from .lexer import TokenStream
from .nodes import EvalContext
from .parser import Parser
@@ -50,11 +53,18 @@ from .utils import internalcode
from .utils import LRUCache
from .utils import missing
-# for direct template usage we have up to ten living environments
-_spontaneous_environments = LRUCache(10)
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+ from .bccache import BytecodeCache
+ from .ext import Extension
+ from .loaders import BaseLoader
+
+_env_bound = t.TypeVar("_env_bound", bound="Environment")
-def get_spontaneous_environment(cls, *args):
+# for direct template usage we have up to ten living environments
+@lru_cache(maxsize=10)
+def get_spontaneous_environment(cls: t.Type[_env_bound], *args: t.Any) -> _env_bound:
"""Return a new spontaneous environment. A spontaneous environment
is used for templates created directly rather than through an
existing environment.
@@ -62,61 +72,70 @@ def get_spontaneous_environment(cls, *args):
:param cls: Environment class to create.
:param args: Positional arguments passed to environment.
"""
- key = (cls, args)
-
- try:
- return _spontaneous_environments[key]
- except KeyError:
- _spontaneous_environments[key] = env = cls(*args)
- env.shared = True
- return env
+ env = cls(*args)
+ env.shared = True
+ return env
-def create_cache(size):
+def create_cache(
+ size: int,
+) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]:
"""Return the cache class for the given size."""
if size == 0:
return None
+
if size < 0:
return {}
- return LRUCache(size)
+
+ return LRUCache(size) # type: ignore
-def copy_cache(cache):
+def copy_cache(
+ cache: t.Optional[t.MutableMapping],
+) -> t.Optional[t.MutableMapping[t.Tuple[weakref.ref, str], "Template"]]:
"""Create an empty copy of the given cache."""
if cache is None:
return None
- elif type(cache) is dict:
+
+ if type(cache) is dict:
return {}
- return LRUCache(cache.capacity)
+
+ return LRUCache(cache.capacity) # type: ignore
-def load_extensions(environment, extensions):
+def load_extensions(
+ environment: "Environment",
+ extensions: t.Sequence[t.Union[str, t.Type["Extension"]]],
+) -> t.Dict[str, "Extension"]:
"""Load the extensions from the list and bind it to the environment.
- Returns a dict of instantiated environments.
+ Returns a dict of instantiated extensions.
"""
result = {}
+
for extension in extensions:
if isinstance(extension, str):
- extension = import_string(extension)
+ extension = t.cast(t.Type["Extension"], import_string(extension))
+
result[extension.identifier] = extension(environment)
+
return result
-def _environment_sanity_check(environment):
+def _environment_config_check(environment: "Environment") -> "Environment":
"""Perform a sanity check on the environment."""
assert issubclass(
environment.undefined, Undefined
- ), "undefined must be a subclass of undefined because filters depend on it."
+ ), "'undefined' must be a subclass of 'jinja2.Undefined'."
assert (
environment.block_start_string
!= environment.variable_start_string
!= environment.comment_start_string
- ), "block, variable and comment start strings must be different"
+ ), "block, variable and comment start strings must be different."
assert environment.newline_sequence in {
"\r",
"\r\n",
"\n",
- }, "newline_sequence set to unknown line ending string."
+ }, "'newline_sequence' must be one of '\\n', '\\r\\n', or '\\r'."
return environment
@@ -260,38 +279,38 @@ class Environment:
#: the class that is used for code generation. See
#: :class:`~jinja2.compiler.CodeGenerator` for more information.
- code_generator_class = CodeGenerator
+ code_generator_class: t.Type["CodeGenerator"] = CodeGenerator
#: the context class that is used for templates. See
#: :class:`~jinja2.runtime.Context` for more information.
- context_class = Context
+ context_class: t.Type[Context] = Context
template_class: t.Type["Template"]
def __init__(
self,
- block_start_string=BLOCK_START_STRING,
- block_end_string=BLOCK_END_STRING,
- variable_start_string=VARIABLE_START_STRING,
- variable_end_string=VARIABLE_END_STRING,
- comment_start_string=COMMENT_START_STRING,
- comment_end_string=COMMENT_END_STRING,
- line_statement_prefix=LINE_STATEMENT_PREFIX,
- line_comment_prefix=LINE_COMMENT_PREFIX,
- trim_blocks=TRIM_BLOCKS,
- lstrip_blocks=LSTRIP_BLOCKS,
- newline_sequence=NEWLINE_SEQUENCE,
- keep_trailing_newline=KEEP_TRAILING_NEWLINE,
- extensions=(),
- optimized=True,
- undefined=Undefined,
- finalize=None,
- autoescape=False,
- loader=None,
- cache_size=400,
- auto_reload=True,
- bytecode_cache=None,
- enable_async=False,
+ block_start_string: str = BLOCK_START_STRING,
+ block_end_string: str = BLOCK_END_STRING,
+ variable_start_string: str = VARIABLE_START_STRING,
+ variable_end_string: str = VARIABLE_END_STRING,
+ comment_start_string: str = COMMENT_START_STRING,
+ comment_end_string: str = COMMENT_END_STRING,
+ line_statement_prefix: t.Optional[str] = LINE_STATEMENT_PREFIX,
+ line_comment_prefix: t.Optional[str] = LINE_COMMENT_PREFIX,
+ trim_blocks: bool = TRIM_BLOCKS,
+ lstrip_blocks: bool = LSTRIP_BLOCKS,
+ newline_sequence: "te.Literal['\\n', '\\r\\n', '\\r']" = NEWLINE_SEQUENCE,
+ keep_trailing_newline: bool = KEEP_TRAILING_NEWLINE,
+ extensions: t.Sequence[t.Union[str, t.Type["Extension"]]] = (),
+ optimized: bool = True,
+ undefined: t.Type[Undefined] = Undefined,
+ finalize: t.Optional[t.Callable[..., t.Any]] = None,
+ autoescape: t.Union[bool, t.Callable[[t.Optional[str]], bool]] = False,
+ loader: t.Optional["BaseLoader"] = None,
+ cache_size: int = 400,
+ auto_reload: bool = True,
+ bytecode_cache: t.Optional["BytecodeCache"] = None,
+ enable_async: bool = False,
):
# !!Important notice!!
# The constructor accepts quite a few arguments that should be
@@ -342,16 +361,16 @@ class Environment:
self.extensions = load_extensions(self, extensions)
self.is_async = enable_async
- _environment_sanity_check(self)
+ _environment_config_check(self)
- def add_extension(self, extension):
+ def add_extension(self, extension: t.Union[str, t.Type["Extension"]]) -> None:
"""Adds an extension after the environment was created.
.. versionadded:: 2.5
"""
self.extensions.update(load_extensions(self, [extension]))
- def extend(self, **attributes):
+ def extend(self, **attributes: t.Any) -> None:
"""Add the items to the instance of the environment if they do not exist
yet. This is used by :ref:`extensions <writing-extensions>` to register
callbacks and configuration values without breaking inheritance.
@@ -362,26 +381,26 @@ class Environment:
def overlay(
self,
- block_start_string=missing,
- block_end_string=missing,
- variable_start_string=missing,
- variable_end_string=missing,
- comment_start_string=missing,
- comment_end_string=missing,
- line_statement_prefix=missing,
- line_comment_prefix=missing,
- trim_blocks=missing,
- lstrip_blocks=missing,
- extensions=missing,
- optimized=missing,
- undefined=missing,
- finalize=missing,
- autoescape=missing,
- loader=missing,
- cache_size=missing,
- auto_reload=missing,
- bytecode_cache=missing,
- ):
+ block_start_string: str = missing,
+ block_end_string: str = missing,
+ variable_start_string: str = missing,
+ variable_end_string: str = missing,
+ comment_start_string: str = missing,
+ comment_end_string: str = missing,
+ line_statement_prefix: t.Optional[str] = missing,
+ line_comment_prefix: t.Optional[str] = missing,
+ trim_blocks: bool = missing,
+ lstrip_blocks: bool = missing,
+ extensions: t.Sequence[t.Union[str, t.Type["Extension"]]] = missing,
+ optimized: bool = missing,
+ undefined: t.Type[Undefined] = missing,
+ finalize: t.Optional[t.Callable[..., t.Any]] = missing,
+ autoescape: t.Union[bool, t.Callable[[t.Optional[str]], bool]] = missing,
+ loader: t.Optional["BaseLoader"] = missing,
+ cache_size: int = missing,
+ auto_reload: bool = missing,
+ bytecode_cache: t.Optional["BytecodeCache"] = missing,
+ ) -> "Environment":
"""Create a new overlay environment that shares all the data with the
current environment except for cache and the overridden attributes.
Extensions cannot be removed for an overlayed environment. An overlayed
@@ -416,15 +435,20 @@ class Environment:
if extensions is not missing:
rv.extensions.update(load_extensions(rv, extensions))
- return _environment_sanity_check(rv)
+ return _environment_config_check(rv)
- lexer = property(get_lexer, doc="The lexer for this environment.")
+ @property
+ def lexer(self) -> Lexer:
+ """The lexer for this environment."""
+ return get_lexer(self)
- def iter_extensions(self):
+ def iter_extensions(self) -> t.Iterator["Extension"]:
"""Iterates over the extensions by priority."""
return iter(sorted(self.extensions.values(), key=lambda x: x.priority))
- def getitem(self, obj, argument):
+ def getitem(
+ self, obj: t.Any, argument: t.Union[str, t.Any]
+ ) -> t.Union[t.Any, Undefined]:
"""Get an item or attribute of an object but prefer the item."""
try:
return obj[argument]
@@ -441,9 +465,9 @@ class Environment:
pass
return self.undefined(obj=obj, name=argument)
- def getattr(self, obj, attribute):
+ def getattr(self, obj: t.Any, attribute: str) -> t.Any:
"""Get an item or attribute of an object but prefer the attribute.
- Unlike :meth:`getitem` the attribute *must* be a bytestring.
+ Unlike :meth:`getitem` the attribute *must* be a string.
"""
try:
return getattr(obj, attribute)
@@ -455,8 +479,15 @@ class Environment:
return self.undefined(obj=obj, name=attribute)
def _filter_test_common(
- self, name, value, args, kwargs, context, eval_ctx, is_filter
- ):
+ self,
+ name: t.Union[str, Undefined],
+ value: t.Any,
+ args: t.Optional[t.Sequence[t.Any]],
+ kwargs: t.Optional[t.Mapping[str, t.Any]],
+ context: t.Optional[Context],
+ eval_ctx: t.Optional[EvalContext],
+ is_filter: bool,
+ ) -> t.Any:
if is_filter:
env_map = self.filters
type_name = "filter"
@@ -464,7 +495,7 @@ class Environment:
env_map = self.tests
type_name = "test"
- func = env_map.get(name)
+ func = env_map.get(name) # type: ignore
if func is None:
msg = f"No {type_name} named {name!r}."
@@ -502,8 +533,14 @@ class Environment:
return func(*args, **kwargs)
def call_filter(
- self, name, value, args=None, kwargs=None, context=None, eval_ctx=None
- ):
+ self,
+ name: str,
+ value: t.Any,
+ args: t.Optional[t.Sequence[t.Any]] = None,
+ kwargs: t.Optional[t.Mapping[str, t.Any]] = None,
+ context: t.Optional[Context] = None,
+ eval_ctx: t.Optional[EvalContext] = None,
+ ) -> t.Any:
"""Invoke a filter on a value the same way the compiler does.
This might return a coroutine if the filter is running from an
@@ -517,8 +554,14 @@ class Environment:
)
def call_test(
- self, name, value, args=None, kwargs=None, context=None, eval_ctx=None
- ):
+ self,
+ name: str,
+ value: t.Any,
+ args: t.Optional[t.Sequence[t.Any]] = None,
+ kwargs: t.Optional[t.Mapping[str, t.Any]] = None,
+ context: t.Optional[Context] = None,
+ eval_ctx: t.Optional[EvalContext] = None,
+ ) -> t.Any:
"""Invoke a test on a value the same way the compiler does.
This might return a coroutine if the test is running from an
@@ -536,7 +579,12 @@ class Environment:
)
@internalcode
- def parse(self, source, name=None, filename=None):
+ def parse(
+ self,
+ source: str,
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ ) -> nodes.Template:
"""Parse the sourcecode and return the abstract syntax tree. This
tree of nodes is used by the compiler to convert the template into
executable source- or bytecode. This is useful for debugging or to
@@ -550,11 +598,18 @@ class Environment:
except TemplateSyntaxError:
self.handle_exception(source=source)
- def _parse(self, source, name, filename):
+ def _parse(
+ self, source: str, name: t.Optional[str], filename: t.Optional[str]
+ ) -> nodes.Template:
"""Internal parsing function used by `parse` and `compile`."""
return Parser(self, source, name, filename).parse()
- def lex(self, source, name=None, filename=None):
+ def lex(
+ self,
+ source: str,
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ ) -> t.Iterator[t.Tuple[int, str, str]]:
"""Lex the given sourcecode and return a generator that yields
tokens as tuples in the form ``(lineno, token_type, value)``.
This can be useful for :ref:`extension development <writing-extensions>`
@@ -570,7 +625,12 @@ class Environment:
except TemplateSyntaxError:
self.handle_exception(source=source)
- def preprocess(self, source, name=None, filename=None):
+ def preprocess(
+ self,
+ source: str,
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ ) -> str:
"""Preprocesses the source with all extensions. This is automatically
called for all parsing and compiling methods but *not* for :meth:`lex`
because there you usually only want the actual source tokenized.
@@ -581,25 +641,40 @@ class Environment:
str(source),
)
- def _tokenize(self, source, name, filename=None, state=None):
+ def _tokenize(
+ self,
+ source: str,
+ name: t.Optional[str],
+ filename: t.Optional[str] = None,
+ state: t.Optional[str] = None,
+ ) -> TokenStream:
"""Called by the parser to do the preprocessing and filtering
for all the extensions. Returns a :class:`~jinja2.lexer.TokenStream`.
"""
source = self.preprocess(source, name, filename)
stream = self.lexer.tokenize(source, name, filename, state)
+
for ext in self.iter_extensions():
- stream = ext.filter_stream(stream)
+ stream = ext.filter_stream(stream) # type: ignore
+
if not isinstance(stream, TokenStream):
- stream = TokenStream(stream, name, filename)
+ stream = TokenStream(stream, name, filename) # type: ignore
+
return stream
- def _generate(self, source, name, filename, defer_init=False):
+ def _generate(
+ self,
+ source: nodes.Template,
+ name: t.Optional[str],
+ filename: t.Optional[str],
+ defer_init: bool = False,
+ ) -> str:
"""Internal hook that can be overridden to hook a different generate
method in.
.. versionadded:: 2.5
"""
- return generate(
+ return generate( # type: ignore
source,
self,
name,
@@ -608,16 +683,45 @@ class Environment:
optimized=self.optimized,
)
- def _compile(self, source, filename):
+ def _compile(self, source: str, filename: str) -> CodeType:
"""Internal hook that can be overridden to hook a different compile
method in.
.. versionadded:: 2.5
"""
- return compile(source, filename, "exec")
+ return compile(source, filename, "exec") # type: ignore
+
+ @typing.overload
+ def compile( # type: ignore
+ self,
+ source: t.Union[str, nodes.Template],
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ raw: "te.Literal[False]" = False,
+ defer_init: bool = False,
+ ) -> CodeType:
+ ...
+
+ @typing.overload
+ def compile(
+ self,
+ source: t.Union[str, nodes.Template],
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ raw: "te.Literal[True]" = ...,
+ defer_init: bool = False,
+ ) -> str:
+ ...
@internalcode
- def compile(self, source, name=None, filename=None, raw=False, defer_init=False):
+ def compile(
+ self,
+ source: t.Union[str, nodes.Template],
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ raw: bool = False,
+ defer_init: bool = False,
+ ) -> t.Union[str, CodeType]:
"""Compile a node or template source code. The `name` parameter is
the load name of the template after it was joined using
:meth:`join_path` if necessary, not the filename on the file system.
@@ -651,7 +755,9 @@ class Environment:
except TemplateSyntaxError:
self.handle_exception(source=source_hint)
- def compile_expression(self, source, undefined_to_none=True):
+ def compile_expression(
+ self, source: str, undefined_to_none: bool = True
+ ) -> "TemplateExpression":
"""A handy helper method that returns a callable that accepts keyword
arguments that appear as variables in the expression. If called it
returns the result of the expression.
@@ -688,8 +794,7 @@ class Environment:
)
expr.set_environment(self)
except TemplateSyntaxError:
- if sys.exc_info() is not None:
- self.handle_exception(source=source)
+ self.handle_exception(source=source)
body = [nodes.Assign(nodes.Name("result", "store"), expr, lineno=1)]
template = self.from_string(nodes.Template(body, lineno=1))
@@ -697,13 +802,13 @@ class Environment:
def compile_templates(
self,
- target,
- extensions=None,
- filter_func=None,
- zip="deflated",
- log_function=None,
- ignore_errors=True,
- ):
+ target: t.Union[str, os.PathLike],
+ extensions: t.Optional[t.Collection[str]] = None,
+ filter_func: t.Optional[t.Callable[[str], bool]] = None,
+ zip: t.Optional[str] = "deflated",
+ log_function: t.Optional[t.Callable[[str], None]] = None,
+ ignore_errors: bool = True,
+ ) -> None:
"""Finds all the templates the loader can find, compiles them
and stores them in `target`. If `zip` is `None`, instead of in a
zipfile, the templates will be stored in a directory.
@@ -725,20 +830,20 @@ class Environment:
if log_function is None:
- def log_function(x):
+ def log_function(x: str) -> None:
pass
- def write_file(filename, data):
+ assert log_function is not None
+ assert self.loader is not None, "No loader configured."
+
+ def write_file(filename: str, data: str) -> None:
if zip:
info = ZipInfo(filename)
info.external_attr = 0o755 << 16
zip_file.writestr(info, data)
else:
- if isinstance(data, str):
- data = data.encode("utf8")
-
with open(os.path.join(target, filename), "wb") as f:
- f.write(data)
+ f.write(data.encode("utf8"))
if zip is not None:
from zipfile import ZipFile, ZipInfo, ZIP_DEFLATED, ZIP_STORED
@@ -773,7 +878,11 @@ class Environment:
log_function("Finished compiling templates")
- def list_templates(self, extensions=None, filter_func=None):
+ def list_templates(
+ self,
+ extensions: t.Optional[t.Collection[str]] = None,
+ filter_func: t.Optional[t.Callable[[str], bool]] = None,
+ ) -> t.List[str]:
"""Returns a list of templates for this environment. This requires
that the loader supports the loader's
:meth:`~BaseLoader.list_templates` method.
@@ -789,6 +898,7 @@ class Environment:
.. versionadded:: 2.4
"""
+ assert self.loader is not None, "No loader configured."
names = self.loader.list_templates()
if extensions is not None:
@@ -797,15 +907,15 @@ class Environment:
"either extensions or filter_func can be passed, but not both"
)
- def filter_func(x):
- return "." in x and x.rsplit(".", 1)[1] in extensions
+ def filter_func(x: str) -> bool:
+ return "." in x and x.rsplit(".", 1)[1] in extensions # type: ignore
if filter_func is not None:
names = [name for name in names if filter_func(name)]
return names
- def handle_exception(self, source=None):
+ def handle_exception(self, source: t.Optional[str] = None) -> t.NoReturn:
"""Exception handling helper. This is used internally to either raise
rewritten exceptions or return a rendered traceback for the template.
"""
@@ -813,7 +923,7 @@ class Environment:
raise rewrite_traceback_stack(source=source)
- def join_path(self, template, parent):
+ def join_path(self, template: str, parent: str) -> str:
"""Join a template with the parent. By default all the lookups are
relative to the loader root so this method returns the `template`
parameter unchanged, but if the paths should be relative to the
@@ -826,7 +936,9 @@ class Environment:
return template
@internalcode
- def _load_template(self, name, globals):
+ def _load_template(
+ self, name: str, globals: t.Optional[t.Mapping[str, t.Any]]
+ ) -> "Template":
if self.loader is None:
raise TypeError("no loader for this environment specified")
cache_key = (weakref.ref(self.loader), name)
@@ -849,7 +961,12 @@ class Environment:
return template
@internalcode
- def get_template(self, name, parent=None, globals=None):
+ def get_template(
+ self,
+ name: t.Union[str, "Template"],
+ parent: t.Optional[str] = None,
+ globals: t.Optional[t.Mapping[str, t.Any]] = None,
+ ) -> "Template":
"""Load a template by name with :attr:`loader` and return a
:class:`Template`. If the template does not exist a
:exc:`TemplateNotFound` exception is raised.
@@ -879,7 +996,12 @@ class Environment:
return self._load_template(name, globals)
@internalcode
- def select_template(self, names, parent=None, globals=None):
+ def select_template(
+ self,
+ names: t.Iterable[t.Union[str, "Template"]],
+ parent: t.Optional[str] = None,
+ globals: t.Optional[t.Mapping[str, t.Any]] = None,
+ ) -> "Template":
"""Like :meth:`get_template`, but tries loading multiple names.
If none of the names can be loaded a :exc:`TemplatesNotFound`
exception is raised.
@@ -925,10 +1047,17 @@ class Environment:
return self._load_template(name, globals)
except (TemplateNotFound, UndefinedError):
pass
- raise TemplatesNotFound(names)
+ raise TemplatesNotFound(names) # type: ignore
@internalcode
- def get_or_select_template(self, template_name_or_list, parent=None, globals=None):
+ def get_or_select_template(
+ self,
+ template_name_or_list: t.Union[
+ str, "Template", t.List[t.Union[str, "Template"]]
+ ],
+ parent: t.Optional[str] = None,
+ globals: t.Optional[t.Mapping[str, t.Any]] = None,
+ ) -> "Template":
"""Use :meth:`select_template` if an iterable of template names
is given, or :meth:`get_template` if one name is given.
@@ -940,7 +1069,12 @@ class Environment:
return template_name_or_list
return self.select_template(template_name_or_list, parent, globals)
- def from_string(self, source, globals=None, template_class=None):
+ def from_string(
+ self,
+ source: t.Union[str, nodes.Template],
+ globals: t.Optional[t.Mapping[str, t.Any]] = None,
+ template_class: t.Optional[t.Type["Template"]] = None,
+ ) -> "Template":
"""Load a template from a source string without using
:attr:`loader`.
@@ -952,11 +1086,13 @@ class Environment:
:param template_class: Return an instance of this
:class:`Template` class.
"""
- globals = self.make_globals(globals)
+ gs = self.make_globals(globals)
cls = template_class or self.template_class
- return cls.from_code(self, self.compile(source), globals, None)
+ return cls.from_code(self, self.compile(source), gs, None)
- def make_globals(self, d):
+ def make_globals(
+ self, d: t.Optional[t.Mapping[str, t.Any]]
+ ) -> t.MutableMapping[str, t.Any]:
"""Make the globals map for a template. Any given template
globals overlay the environment :attr:`globals`.
@@ -1009,32 +1145,42 @@ class Template:
#: Type of environment to create when creating a template directly
#: rather than through an existing environment.
- environment_class = Environment
+ environment_class: t.Type[Environment] = Environment
+
+ environment: Environment
+ globals: t.MutableMapping[str, t.Any]
+ name: t.Optional[str]
+ filename: t.Optional[str]
+ blocks: t.Dict[str, t.Callable[[Context], t.Iterator[str]]]
+ root_render_func: t.Callable[[Context], t.Iterator[str]]
+ _module: t.Optional["TemplateModule"]
+ _debug_info: str
+ _uptodate: t.Optional[t.Callable[[], bool]]
def __new__(
cls,
- source,
- block_start_string=BLOCK_START_STRING,
- block_end_string=BLOCK_END_STRING,
- variable_start_string=VARIABLE_START_STRING,
- variable_end_string=VARIABLE_END_STRING,
- comment_start_string=COMMENT_START_STRING,
- comment_end_string=COMMENT_END_STRING,
- line_statement_prefix=LINE_STATEMENT_PREFIX,
- line_comment_prefix=LINE_COMMENT_PREFIX,
- trim_blocks=TRIM_BLOCKS,
- lstrip_blocks=LSTRIP_BLOCKS,
- newline_sequence=NEWLINE_SEQUENCE,
- keep_trailing_newline=KEEP_TRAILING_NEWLINE,
- extensions=(),
- optimized=True,
- undefined=Undefined,
- finalize=None,
- autoescape=False,
- enable_async=False,
- ):
+ source: t.Union[str, nodes.Template],
+ block_start_string: str = BLOCK_START_STRING,
+ block_end_string: str = BLOCK_END_STRING,
+ variable_start_string: str = VARIABLE_START_STRING,
+ variable_end_string: str = VARIABLE_END_STRING,
+ comment_start_string: str = COMMENT_START_STRING,
+ comment_end_string: str = COMMENT_END_STRING,
+ line_statement_prefix: t.Optional[str] = LINE_STATEMENT_PREFIX,
+ line_comment_prefix: t.Optional[str] = LINE_COMMENT_PREFIX,
+ trim_blocks: bool = TRIM_BLOCKS,
+ lstrip_blocks: bool = LSTRIP_BLOCKS,
+ newline_sequence: "te.Literal['\\n', '\\r\\n', '\\r']" = NEWLINE_SEQUENCE,
+ keep_trailing_newline: bool = KEEP_TRAILING_NEWLINE,
+ extensions: t.Sequence[t.Union[str, t.Type["Extension"]]] = (),
+ optimized: bool = True,
+ undefined: t.Type[Undefined] = Undefined,
+ finalize: t.Optional[t.Callable[..., t.Any]] = None,
+ autoescape: t.Union[bool, t.Callable[[t.Optional[str]], bool]] = False,
+ enable_async: bool = False,
+ ) -> "Template":
env = get_spontaneous_environment(
- cls.environment_class,
+ cls.environment_class, # type: ignore
block_start_string,
block_end_string,
variable_start_string,
@@ -1049,7 +1195,7 @@ class Template:
keep_trailing_newline,
frozenset(extensions),
optimized,
- undefined,
+ undefined, # type: ignore
finalize,
autoescape,
None,
@@ -1061,7 +1207,13 @@ class Template:
return env.from_string(source, template_class=cls)
@classmethod
- def from_code(cls, environment, code, globals, uptodate=None):
+ def from_code(
+ cls,
+ environment: Environment,
+ code: CodeType,
+ globals: t.MutableMapping[str, t.Any],
+ uptodate: t.Optional[t.Callable[[], bool]] = None,
+ ) -> "Template":
"""Creates a template object from compiled code and the globals. This
is used by the loaders and environment to create a template object.
"""
@@ -1072,7 +1224,12 @@ class Template:
return rv
@classmethod
- def from_module_dict(cls, environment, module_dict, globals):
+ def from_module_dict(
+ cls,
+ environment: Environment,
+ module_dict: t.MutableMapping[str, t.Any],
+ globals: t.MutableMapping[str, t.Any],
+ ) -> "Template":
"""Creates a template object from a module. This is used by the
module loader to create a template object.
@@ -1081,8 +1238,13 @@ class Template:
return cls._from_namespace(environment, module_dict, globals)
@classmethod
- def _from_namespace(cls, environment, namespace, globals):
- t = object.__new__(cls)
+ def _from_namespace(
+ cls,
+ environment: Environment,
+ namespace: t.MutableMapping[str, t.Any],
+ globals: t.MutableMapping[str, t.Any],
+ ) -> "Template":
+ t: "Template" = object.__new__(cls)
t.environment = environment
t.globals = globals
t.name = namespace["name"]
@@ -1090,7 +1252,7 @@ class Template:
t.blocks = namespace["blocks"]
# render function and module
- t.root_render_func = namespace["root"]
+ t.root_render_func = namespace["root"] # type: ignore
t._module = None
# debug and loader helpers
@@ -1103,7 +1265,7 @@ class Template:
return t
- def render(self, *args, **kwargs):
+ def render(self, *args: t.Any, **kwargs: t.Any) -> str:
"""This method accepts the same arguments as the `dict` constructor:
A dict, a dict subclass or some keyword arguments. If no arguments
are given the context will be empty. These two calls do the same::
@@ -1122,11 +1284,11 @@ class Template:
ctx = self.new_context(dict(*args, **kwargs))
try:
- return concat(self.root_render_func(ctx))
+ return concat(self.root_render_func(ctx)) # type: ignore
except Exception:
self.environment.handle_exception()
- async def render_async(self, *args, **kwargs):
+ async def render_async(self, *args: t.Any, **kwargs: t.Any) -> str:
"""This works similar to :meth:`render` but returns a coroutine
that when awaited returns the entire rendered template string. This
requires the async feature to be enabled.
@@ -1143,17 +1305,17 @@ class Template:
ctx = self.new_context(dict(*args, **kwargs))
try:
- return concat([n async for n in self.root_render_func(ctx)])
+ return concat([n async for n in self.root_render_func(ctx)]) # type: ignore
except Exception:
return self.environment.handle_exception()
- def stream(self, *args, **kwargs):
+ def stream(self, *args: t.Any, **kwargs: t.Any) -> "TemplateStream":
"""Works exactly like :meth:`generate` but returns a
:class:`TemplateStream`.
"""
return TemplateStream(self.generate(*args, **kwargs))
- def generate(self, *args, **kwargs):
+ def generate(self, *args: t.Any, **kwargs: t.Any) -> t.Iterator[str]:
"""For very large templates it can be useful to not render the whole
template at once but evaluate each statement after another and yield
piece for piece. This method basically does exactly that and returns
@@ -1176,11 +1338,13 @@ class Template:
ctx = self.new_context(dict(*args, **kwargs))
try:
- yield from self.root_render_func(ctx)
+ yield from self.root_render_func(ctx) # type: ignore
except Exception:
yield self.environment.handle_exception()
- async def generate_async(self, *args, **kwargs):
+ async def generate_async(
+ self, *args: t.Any, **kwargs: t.Any
+ ) -> t.AsyncIterator[str]:
"""An async version of :meth:`generate`. Works very similarly but
returns an async iterator instead.
"""
@@ -1192,12 +1356,17 @@ class Template:
ctx = self.new_context(dict(*args, **kwargs))
try:
- async for event in self.root_render_func(ctx):
+ async for event in self.root_render_func(ctx): # type: ignore
yield event
except Exception:
yield self.environment.handle_exception()
- def new_context(self, vars=None, shared=False, locals=None):
+ def new_context(
+ self,
+ vars: t.Optional[t.Dict[str, t.Any]] = None,
+ shared: bool = False,
+ locals: t.Optional[t.Mapping[str, t.Any]] = None,
+ ) -> Context:
"""Create a new :class:`Context` for this template. The vars
provided will be passed to the template. Per default the globals
are added to the context. If shared is set to `True` the data
@@ -1209,7 +1378,12 @@ class Template:
self.environment, self.name, self.blocks, vars, shared, self.globals, locals
)
- def make_module(self, vars=None, shared=False, locals=None):
+ def make_module(
+ self,
+ vars: t.Optional[t.Dict[str, t.Any]] = None,
+ shared: bool = False,
+ locals: t.Optional[t.Mapping[str, t.Any]] = None,
+ ) -> "TemplateModule":
"""This method works like the :attr:`module` attribute when called
without arguments but it will evaluate the template on every call
rather than caching it. It's also possible to provide
@@ -1219,17 +1393,24 @@ class Template:
ctx = self.new_context(vars, shared, locals)
return TemplateModule(self, ctx)
- async def make_module_async(self, vars=None, shared=False, locals=None):
+ async def make_module_async(
+ self,
+ vars: t.Optional[t.Dict[str, t.Any]] = None,
+ shared: bool = False,
+ locals: t.Optional[t.Mapping[str, t.Any]] = None,
+ ) -> "TemplateModule":
"""As template module creation can invoke template code for
asynchronous executions this method must be used instead of the
normal :meth:`make_module` one. Likewise the module attribute
becomes unavailable in async mode.
"""
ctx = self.new_context(vars, shared, locals)
- return TemplateModule(self, ctx, [x async for x in self.root_render_func(ctx)])
+ return TemplateModule(
+ self, ctx, [x async for x in self.root_render_func(ctx)] # type: ignore
+ )
@internalcode
- def _get_default_module(self, ctx=None):
+ def _get_default_module(self, ctx: t.Optional[Context] = None) -> "TemplateModule":
"""If a context is passed in, this means that the template was
imported. Imported templates have access to the current
template's globals by default, but they can only be accessed via
@@ -1255,7 +1436,9 @@ class Template:
return self._module
- async def _get_default_module_async(self, ctx=None):
+ async def _get_default_module_async(
+ self, ctx: t.Optional[Context] = None
+ ) -> "TemplateModule":
if ctx is not None:
keys = ctx.globals_keys - self.globals.keys()
@@ -1268,7 +1451,7 @@ class Template:
return self._module
@property
- def module(self):
+ def module(self) -> "TemplateModule":
"""The template as module. This is used for imports in the
template runtime but is also useful if one wants to access
exported template variables from the Python layer:
@@ -1283,7 +1466,7 @@ class Template:
"""
return self._get_default_module()
- def get_corresponding_lineno(self, lineno):
+ def get_corresponding_lineno(self, lineno: int) -> int:
"""Return the source line number of a line number in the
generated bytecode as they are not in sync.
"""
@@ -1293,25 +1476,29 @@ class Template:
return 1
@property
- def is_up_to_date(self):
+ def is_up_to_date(self) -> bool:
"""If this variable is `False` there is a newer version available."""
if self._uptodate is None:
return True
return self._uptodate()
@property
- def debug_info(self):
+ def debug_info(self) -> t.List[t.Tuple[int, int]]:
"""The debug info mapping."""
if self._debug_info:
- return [tuple(map(int, x.split("="))) for x in self._debug_info.split("&")]
+ return [
+ tuple(map(int, x.split("="))) # type: ignore
+ for x in self._debug_info.split("&")
+ ]
+
return []
- def __repr__(self):
+ def __repr__(self) -> str:
if self.name is None:
name = f"memory:{id(self):x}"
else:
name = repr(self.name)
- return f"<{self.__class__.__name__} {name}>"
+ return f"<{type(self).__name__} {name}>"
class TemplateModule:
@@ -1320,32 +1507,38 @@ class TemplateModule:
converting it into a string renders the contents.
"""
- def __init__(self, template, context, body_stream=None):
+ def __init__(
+ self,
+ template: Template,
+ context: Context,
+ body_stream: t.Optional[t.Iterable[str]] = None,
+ ) -> None:
if body_stream is None:
if context.environment.is_async:
raise RuntimeError(
- "Async mode requires a body stream "
- "to be passed to a template module. Use "
- "the async methods of the API you are "
- "using."
+ "Async mode requires a body stream to be passed to"
+ " a template module. Use the async methods of the"
+ " API you are using."
)
- body_stream = list(template.root_render_func(context))
+
+ body_stream = list(template.root_render_func(context)) # type: ignore
+
self._body_stream = body_stream
self.__dict__.update(context.get_exported())
self.__name__ = template.name
- def __html__(self):
+ def __html__(self) -> Markup:
return Markup(concat(self._body_stream))
- def __str__(self):
+ def __str__(self) -> str:
return concat(self._body_stream)
- def __repr__(self):
+ def __repr__(self) -> str:
if self.__name__ is None:
name = f"memory:{id(self):x}"
else:
name = repr(self.__name__)
- return f"<{self.__class__.__name__} {name}>"
+ return f"<{type(self).__name__} {name}>"
class TemplateExpression:
@@ -1354,13 +1547,13 @@ class TemplateExpression:
to the template with an expression it wraps.
"""
- def __init__(self, template, undefined_to_none):
+ def __init__(self, template: Template, undefined_to_none: bool) -> None:
self._template = template
self._undefined_to_none = undefined_to_none
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Optional[t.Any]:
context = self._template.new_context(dict(*args, **kwargs))
- consume(self._template.root_render_func(context))
+ consume(self._template.root_render_func(context)) # type: ignore
rv = context.vars["result"]
if self._undefined_to_none and isinstance(rv, Undefined):
rv = None
@@ -1378,11 +1571,16 @@ class TemplateStream:
big templates to a client via WSGI which flushes after each iteration.
"""
- def __init__(self, gen):
+ def __init__(self, gen: t.Iterator[str]) -> None:
self._gen = gen
self.disable_buffering()
- def dump(self, fp, encoding=None, errors="strict"):
+ def dump(
+ self,
+ fp: t.Union[str, t.IO],
+ encoding: t.Optional[str] = None,
+ errors: t.Optional[str] = "strict",
+ ) -> None:
"""Dump the complete stream into a file or file-like object.
Per default strings are written, if you want to encode
before writing specify an `encoding`.
@@ -1392,16 +1590,19 @@ class TemplateStream:
Template('Hello {{ name }}!').stream(name='foo').dump('hello.html')
"""
close = False
+
if isinstance(fp, str):
if encoding is None:
encoding = "utf-8"
+
fp = open(fp, "wb")
close = True
try:
if encoding is not None:
- iterable = (x.encode(encoding, errors) for x in self)
+ iterable = (x.encode(encoding, errors) for x in self) # type: ignore
else:
- iterable = self
+ iterable = self # type: ignore
+
if hasattr(fp, "writelines"):
fp.writelines(iterable)
else:
@@ -1411,17 +1612,17 @@ class TemplateStream:
if close:
fp.close()
- def disable_buffering(self):
+ def disable_buffering(self) -> None:
"""Disable the output buffering."""
self._next = partial(next, self._gen)
self.buffered = False
- def _buffered_generator(self, size):
- buf = []
+ def _buffered_generator(self, size: int) -> t.Iterator[str]:
+ buf: t.List[str] = []
c_size = 0
push = buf.append
- while 1:
+ while True:
try:
while c_size < size:
c = next(self._gen)
@@ -1435,7 +1636,7 @@ class TemplateStream:
del buf[:]
c_size = 0
- def enable_buffering(self, size=5):
+ def enable_buffering(self, size: int = 5) -> None:
"""Enable buffering. Buffer `size` items before yielding them."""
if size <= 1:
raise ValueError("buffer size too small")
@@ -1443,11 +1644,11 @@ class TemplateStream:
self.buffered = True
self._next = partial(next, self._buffered_generator(size))
- def __iter__(self):
+ def __iter__(self) -> "TemplateStream":
return self
- def __next__(self):
- return self._next()
+ def __next__(self) -> str:
+ return self._next() # type: ignore
# hook in default template class. if anyone reads this comment: ignore that
diff --git a/src/jinja2/exceptions.py b/src/jinja2/exceptions.py
index 07cfba2..082ebe8 100644
--- a/src/jinja2/exceptions.py
+++ b/src/jinja2/exceptions.py
@@ -1,13 +1,18 @@
+import typing as t
+
+if t.TYPE_CHECKING:
+ from .runtime import Undefined
+
+
class TemplateError(Exception):
"""Baseclass for all template errors."""
- def __init__(self, message=None):
+ def __init__(self, message: t.Optional[str] = None) -> None:
super().__init__(message)
@property
- def message(self):
- if self.args:
- return self.args[0]
+ def message(self) -> t.Optional[str]:
+ return self.args[0] if self.args else None
class TemplateNotFound(IOError, LookupError, TemplateError):
@@ -20,9 +25,13 @@ class TemplateNotFound(IOError, LookupError, TemplateError):
# Silence the Python warning about message being deprecated since
# it's not valid here.
- message = None
+ message: t.Optional[str] = None
- def __init__(self, name, message=None):
+ def __init__(
+ self,
+ name: t.Optional[t.Union[str, "Undefined"]],
+ message: t.Optional[str] = None,
+ ) -> None:
IOError.__init__(self, name)
if message is None:
@@ -37,8 +46,8 @@ class TemplateNotFound(IOError, LookupError, TemplateError):
self.name = name
self.templates = [name]
- def __str__(self):
- return self.message
+ def __str__(self) -> str:
+ return str(self.message)
class TemplatesNotFound(TemplateNotFound):
@@ -53,7 +62,11 @@ class TemplatesNotFound(TemplateNotFound):
.. versionadded:: 2.2
"""
- def __init__(self, names=(), message=None):
+ def __init__(
+ self,
+ names: t.Sequence[t.Union[str, "Undefined"]] = (),
+ message: t.Optional[str] = None,
+ ) -> None:
if message is None:
from .runtime import Undefined
@@ -65,51 +78,57 @@ class TemplatesNotFound(TemplateNotFound):
else:
parts.append(name)
- message = "none of the templates given were found: " + ", ".join(
- map(str, parts)
- )
- TemplateNotFound.__init__(self, names[-1] if names else None, message)
+ parts_str = ", ".join(map(str, parts))
+ message = f"none of the templates given were found: {parts_str}"
+
+ super().__init__(names[-1] if names else None, message)
self.templates = list(names)
class TemplateSyntaxError(TemplateError):
"""Raised to tell the user that there is a problem with the template."""
- def __init__(self, message, lineno, name=None, filename=None):
- TemplateError.__init__(self, message)
+ def __init__(
+ self,
+ message: str,
+ lineno: int,
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ ) -> None:
+ super().__init__(message)
self.lineno = lineno
self.name = name
self.filename = filename
- self.source = None
+ self.source: t.Optional[str] = None
# this is set to True if the debug.translate_syntax_error
# function translated the syntax error into a new traceback
self.translated = False
- def __str__(self):
+ def __str__(self) -> str:
# for translated errors we only return the message
if self.translated:
- return self.message
+ return t.cast(str, self.message)
# otherwise attach some stuff
location = f"line {self.lineno}"
name = self.filename or self.name
if name:
location = f'File "{name}", {location}'
- lines = [self.message, " " + location]
+ lines = [t.cast(str, self.message), " " + location]
# if the source is set, add the line to the output
if self.source is not None:
try:
line = self.source.splitlines()[self.lineno - 1]
except IndexError:
- line = None
- if line:
+ pass
+ else:
lines.append(" " + line.strip())
return "\n".join(lines)
- def __reduce__(self):
+ def __reduce__(self): # type: ignore
# https://bugs.python.org/issue1692335 Exceptions that take
# multiple required arguments have problems with pickling.
# Without this, raises TypeError: __init__() missing 1 required
diff --git a/src/jinja2/ext.py b/src/jinja2/ext.py
index 2cca1b3..9c5498b 100644
--- a/src/jinja2/ext.py
+++ b/src/jinja2/ext.py
@@ -1,49 +1,58 @@
"""Extension API for adding custom tags and behavior."""
import pprint
import re
+import typing as t
import warnings
-from sys import version_info
-from typing import Set
from markupsafe import Markup
+from . import defaults
from . import nodes
-from .defaults import BLOCK_END_STRING
-from .defaults import BLOCK_START_STRING
-from .defaults import COMMENT_END_STRING
-from .defaults import COMMENT_START_STRING
-from .defaults import KEEP_TRAILING_NEWLINE
-from .defaults import LINE_COMMENT_PREFIX
-from .defaults import LINE_STATEMENT_PREFIX
-from .defaults import LSTRIP_BLOCKS
-from .defaults import NEWLINE_SEQUENCE
-from .defaults import TRIM_BLOCKS
-from .defaults import VARIABLE_END_STRING
-from .defaults import VARIABLE_START_STRING
from .environment import Environment
from .exceptions import TemplateAssertionError
from .exceptions import TemplateSyntaxError
-from .nodes import ContextReference
-from .runtime import concat
+from .runtime import concat # type: ignore
+from .runtime import Context
+from .runtime import Undefined
from .utils import import_string
from .utils import pass_context
-# I18N functions available in Jinja templates. If the I18N library
-# provides ugettext, it will be assigned to gettext.
-GETTEXT_FUNCTIONS = ("_", "gettext", "ngettext", "pgettext", "npgettext")
-_ws_re = re.compile(r"\s*\n\s*")
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+ from .lexer import Token
+ from .lexer import TokenStream
+ from .parser import Parser
+ class _TranslationsBasic(te.Protocol):
+ def gettext(self, message: str) -> str:
+ ...
-class ExtensionRegistry(type):
- """Gives the extension an unique identifier."""
+ def ngettext(self, singular: str, plural: str, n: int) -> str:
+ pass
- def __new__(mcs, name, bases, d):
- rv = type.__new__(mcs, name, bases, d)
- rv.identifier = f"{rv.__module__}.{rv.__name__}"
- return rv
+ class _TranslationsContext(_TranslationsBasic):
+ def pgettext(self, context: str, message: str) -> str:
+ ...
+
+ def npgettext(self, context: str, singular: str, plural: str, n: int) -> str:
+ ...
+ _SupportedTranslations = t.Union[_TranslationsBasic, _TranslationsContext]
-class Extension(metaclass=ExtensionRegistry):
+
+# I18N functions available in Jinja templates. If the I18N library
+# provides ugettext, it will be assigned to gettext.
+GETTEXT_FUNCTIONS: t.Tuple[str, ...] = (
+ "_",
+ "gettext",
+ "ngettext",
+ "pgettext",
+ "npgettext",
+)
+_ws_re = re.compile(r"\s*\n\s*")
+
+
+class Extension:
"""Extensions can be used to add extra functionality to the Jinja template
system at the parser level. Custom extensions are bound to an environment
but may not store environment specific data on `self`. The reason for
@@ -62,8 +71,13 @@ class Extension(metaclass=ExtensionRegistry):
name as includes the name of the extension (fragment cache).
"""
+ identifier: t.ClassVar[str]
+
+ def __init_subclass__(cls) -> None:
+ cls.identifier = f"{cls.__module__}.{cls.__name__}"
+
#: if this extension parses this is the list of tags it's listening to.
- tags: Set[str] = set()
+ tags: t.Set[str] = set()
#: the priority of that extension. This is especially useful for
#: extensions that preprocess values. A lower value means higher
@@ -72,24 +86,28 @@ class Extension(metaclass=ExtensionRegistry):
#: .. versionadded:: 2.4
priority = 100
- def __init__(self, environment):
+ def __init__(self, environment: Environment) -> None:
self.environment = environment
- def bind(self, environment):
+ def bind(self, environment: Environment) -> "Extension":
"""Create a copy of this extension bound to another environment."""
- rv = object.__new__(self.__class__)
+ rv = t.cast(Extension, object.__new__(self.__class__))
rv.__dict__.update(self.__dict__)
rv.environment = environment
return rv
- def preprocess(self, source, name, filename=None):
+ def preprocess(
+ self, source: str, name: t.Optional[str], filename: t.Optional[str] = None
+ ) -> str:
"""This method is called before the actual lexing and can be used to
preprocess the source. The `filename` is optional. The return value
must be the preprocessed source.
"""
return source
- def filter_stream(self, stream):
+ def filter_stream(
+ self, stream: "TokenStream"
+ ) -> t.Union["TokenStream", t.Iterable["Token"]]:
"""It's passed a :class:`~jinja2.lexer.TokenStream` that can be used
to filter tokens returned. This method has to return an iterable of
:class:`~jinja2.lexer.Token`\\s, but it doesn't have to return a
@@ -97,7 +115,7 @@ class Extension(metaclass=ExtensionRegistry):
"""
return stream
- def parse(self, parser):
+ def parse(self, parser: "Parser") -> t.Union[nodes.Node, t.List[nodes.Node]]:
"""If any of the :attr:`tags` matched this method is called with the
parser as first argument. The token the parser stream is pointing at
is the name token that matched. This method has to return one or a
@@ -105,7 +123,9 @@ class Extension(metaclass=ExtensionRegistry):
"""
raise NotImplementedError()
- def attr(self, name, lineno=None):
+ def attr(
+ self, name: str, lineno: t.Optional[int] = None
+ ) -> nodes.ExtensionAttribute:
"""Return an attribute node for the current extension. This is useful
to pass constants on extensions to generated template code.
@@ -116,8 +136,14 @@ class Extension(metaclass=ExtensionRegistry):
return nodes.ExtensionAttribute(self.identifier, name, lineno=lineno)
def call_method(
- self, name, args=None, kwargs=None, dyn_args=None, dyn_kwargs=None, lineno=None
- ):
+ self,
+ name: str,
+ args: t.Optional[t.List[nodes.Expr]] = None,
+ kwargs: t.Optional[t.List[nodes.Keyword]] = None,
+ dyn_args: t.Optional[nodes.Expr] = None,
+ dyn_kwargs: t.Optional[nodes.Expr] = None,
+ lineno: t.Optional[int] = None,
+ ) -> nodes.Call:
"""Call a method of the extension. This is a shortcut for
:meth:`attr` + :class:`jinja2.nodes.Call`.
"""
@@ -136,40 +162,50 @@ class Extension(metaclass=ExtensionRegistry):
@pass_context
-def _gettext_alias(__context, *args, **kwargs):
+def _gettext_alias(
+ __context: Context, *args: t.Any, **kwargs: t.Any
+) -> t.Union[t.Any, Undefined]:
return __context.call(__context.resolve("gettext"), *args, **kwargs)
-def _make_new_gettext(func):
+def _make_new_gettext(func: t.Callable[[str], str]) -> t.Callable[..., str]:
@pass_context
- def gettext(__context, __string, **variables):
+ def gettext(__context: Context, __string: str, **variables: t.Any) -> str:
rv = __context.call(func, __string)
if __context.eval_ctx.autoescape:
rv = Markup(rv)
# Always treat as a format string, even if there are no
# variables. This makes translation strings more consistent
# and predictable. This requires escaping
- return rv % variables
+ return rv % variables # type: ignore
return gettext
-def _make_new_ngettext(func):
+def _make_new_ngettext(func: t.Callable[[str, str, int], str]) -> t.Callable[..., str]:
@pass_context
- def ngettext(__context, __singular, __plural, __num, **variables):
+ def ngettext(
+ __context: Context,
+ __singular: str,
+ __plural: str,
+ __num: int,
+ **variables: t.Any,
+ ) -> str:
variables.setdefault("num", __num)
rv = __context.call(func, __singular, __plural, __num)
if __context.eval_ctx.autoescape:
rv = Markup(rv)
# Always treat as a format string, see gettext comment above.
- return rv % variables
+ return rv % variables # type: ignore
return ngettext
-def _make_new_pgettext(func):
+def _make_new_pgettext(func: t.Callable[[str, str], str]) -> t.Callable[..., str]:
@pass_context
- def pgettext(__context, __string_ctx, __string, **variables):
+ def pgettext(
+ __context: Context, __string_ctx: str, __string: str, **variables: t.Any
+ ) -> str:
variables.setdefault("context", __string_ctx)
rv = __context.call(func, __string_ctx, __string)
@@ -177,14 +213,23 @@ def _make_new_pgettext(func):
rv = Markup(rv)
# Always treat as a format string, see gettext comment above.
- return rv % variables
+ return rv % variables # type: ignore
return pgettext
-def _make_new_npgettext(func):
+def _make_new_npgettext(
+ func: t.Callable[[str, str, str, int], str]
+) -> t.Callable[..., str]:
@pass_context
- def npgettext(__context, __string_ctx, __singular, __plural, __num, **variables):
+ def npgettext(
+ __context: Context,
+ __string_ctx: str,
+ __singular: str,
+ __plural: str,
+ __num: int,
+ **variables: t.Any,
+ ) -> str:
variables.setdefault("context", __string_ctx)
variables.setdefault("num", __num)
rv = __context.call(func, __string_ctx, __singular, __plural, __num)
@@ -193,7 +238,7 @@ def _make_new_npgettext(func):
rv = Markup(rv)
# Always treat as a format string, see gettext comment above.
- return rv % variables
+ return rv % variables # type: ignore
return npgettext
@@ -210,8 +255,8 @@ class InternationalizationExtension(Extension):
# something is called twice here. One time for the gettext value and
# the other time for the n-parameter of the ngettext function.
- def __init__(self, environment):
- Extension.__init__(self, environment)
+ def __init__(self, environment: Environment) -> None:
+ super().__init__(environment)
environment.globals["_"] = _gettext_alias
environment.extend(
install_gettext_translations=self._install,
@@ -222,7 +267,9 @@ class InternationalizationExtension(Extension):
newstyle_gettext=False,
)
- def _install(self, translations, newstyle=None):
+ def _install(
+ self, translations: "_SupportedTranslations", newstyle: t.Optional[bool] = None
+ ) -> None:
# ugettext and ungettext are preferred in case the I18N library
# is providing compatibility with older Python versions.
gettext = getattr(translations, "ugettext", None)
@@ -238,21 +285,45 @@ class InternationalizationExtension(Extension):
gettext, ngettext, newstyle=newstyle, pgettext=pgettext, npgettext=npgettext
)
- def _install_null(self, newstyle=None):
+ def _install_null(self, newstyle: t.Optional[bool] = None) -> None:
+ import gettext
+
+ translations = gettext.NullTranslations()
+
+ if hasattr(translations, "pgettext"):
+ # Python < 3.8
+ pgettext = translations.pgettext # type: ignore
+ else:
+
+ def pgettext(c: str, s: str) -> str:
+ return s
+
+ if hasattr(translations, "npgettext"):
+ npgettext = translations.npgettext # type: ignore
+ else:
+
+ def npgettext(c: str, s: str, p: str, n: int) -> str:
+ return s if n == 1 else p
+
self._install_callables(
- lambda s: s,
- lambda s, p, n: s if n == 1 else p,
+ gettext=translations.gettext,
+ ngettext=translations.ngettext,
newstyle=newstyle,
- pgettext=lambda c, s: s,
- npgettext=lambda c, s, p, n: s if n == 1 else p,
+ pgettext=pgettext,
+ npgettext=npgettext,
)
def _install_callables(
- self, gettext, ngettext, newstyle=None, pgettext=None, npgettext=None
- ):
+ self,
+ gettext: t.Callable[[str], str],
+ ngettext: t.Callable[[str, str, int], str],
+ newstyle: t.Optional[bool] = None,
+ pgettext: t.Optional[t.Callable[[str, str], str]] = None,
+ npgettext: t.Optional[t.Callable[[str, str, str, int], str]] = None,
+ ) -> None:
if newstyle is not None:
- self.environment.newstyle_gettext = newstyle
- if self.environment.newstyle_gettext:
+ self.environment.newstyle_gettext = newstyle # type: ignore
+ if self.environment.newstyle_gettext: # type: ignore
gettext = _make_new_gettext(gettext)
ngettext = _make_new_ngettext(ngettext)
@@ -266,16 +337,22 @@ class InternationalizationExtension(Extension):
gettext=gettext, ngettext=ngettext, pgettext=pgettext, npgettext=npgettext
)
- def _uninstall(self, translations):
+ def _uninstall(self, translations: "_SupportedTranslations") -> None:
for key in ("gettext", "ngettext", "pgettext", "npgettext"):
self.environment.globals.pop(key, None)
- def _extract(self, source, gettext_functions=GETTEXT_FUNCTIONS):
+ def _extract(
+ self,
+ source: t.Union[str, nodes.Template],
+ gettext_functions: t.Sequence[str] = GETTEXT_FUNCTIONS,
+ ) -> t.Iterator[
+ t.Tuple[int, str, t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]]]
+ ]:
if isinstance(source, str):
source = self.environment.parse(source)
return extract_from_ast(source, gettext_functions)
- def parse(self, parser):
+ def parse(self, parser: "Parser") -> t.Union[nodes.Node, t.List[nodes.Node]]:
"""Parse a translatable tag."""
lineno = next(parser.stream).lineno
num_called_num = False
@@ -283,9 +360,9 @@ class InternationalizationExtension(Extension):
# find all the variables referenced. Additionally a variable can be
# defined in the body of the trans block too, but this is checked at
# a later state.
- plural_expr = None
- plural_expr_assignment = None
- variables = {}
+ plural_expr: t.Optional[nodes.Expr] = None
+ plural_expr_assignment: t.Optional[nodes.Assign] = None
+ variables: t.Dict[str, nodes.Expr] = {}
trimmed = None
while parser.stream.current.type != "block_end":
if variables:
@@ -295,34 +372,34 @@ class InternationalizationExtension(Extension):
if parser.stream.skip_if("colon"):
break
- name = parser.stream.expect("name")
- if name.value in variables:
+ token = parser.stream.expect("name")
+ if token.value in variables:
parser.fail(
- f"translatable variable {name.value!r} defined twice.",
- name.lineno,
+ f"translatable variable {token.value!r} defined twice.",
+ token.lineno,
exc=TemplateAssertionError,
)
# expressions
if parser.stream.current.type == "assign":
next(parser.stream)
- variables[name.value] = var = parser.parse_expression()
- elif trimmed is None and name.value in ("trimmed", "notrimmed"):
- trimmed = name.value == "trimmed"
+ variables[token.value] = var = parser.parse_expression()
+ elif trimmed is None and token.value in ("trimmed", "notrimmed"):
+ trimmed = token.value == "trimmed"
continue
else:
- variables[name.value] = var = nodes.Name(name.value, "load")
+ variables[token.value] = var = nodes.Name(token.value, "load")
if plural_expr is None:
if isinstance(var, nodes.Call):
plural_expr = nodes.Name("_trans", "load")
- variables[name.value] = plural_expr
+ variables[token.value] = plural_expr
plural_expr_assignment = nodes.Assign(
nodes.Name("_trans", "store"), var
)
else:
plural_expr = var
- num_called_num = name.value == "num"
+ num_called_num = token.value == "num"
parser.stream.expect("block_end")
@@ -343,15 +420,15 @@ class InternationalizationExtension(Extension):
have_plural = True
next(parser.stream)
if parser.stream.current.type != "block_end":
- name = parser.stream.expect("name")
- if name.value not in variables:
+ token = parser.stream.expect("name")
+ if token.value not in variables:
parser.fail(
- f"unknown variable {name.value!r} for pluralization",
- name.lineno,
+ f"unknown variable {token.value!r} for pluralization",
+ token.lineno,
exc=TemplateAssertionError,
)
- plural_expr = variables[name.value]
- num_called_num = name.value == "num"
+ plural_expr = variables[token.value]
+ num_called_num = token.value == "num"
parser.stream.expect("block_end")
plural_names, plural = self._parse_block(parser, False)
next(parser.stream)
@@ -360,9 +437,9 @@ class InternationalizationExtension(Extension):
next(parser.stream)
# register free names as simple name expressions
- for var in referenced:
- if var not in variables:
- variables[var] = nodes.Name(var, "load")
+ for name in referenced:
+ if name not in variables:
+ variables[name] = nodes.Name(name, "load")
if not have_plural:
plural_expr = None
@@ -390,14 +467,17 @@ class InternationalizationExtension(Extension):
else:
return node
- def _trim_whitespace(self, string, _ws_re=_ws_re):
+ def _trim_whitespace(self, string: str, _ws_re: t.Pattern[str] = _ws_re) -> str:
return _ws_re.sub(" ", string.strip())
- def _parse_block(self, parser, allow_pluralize):
+ def _parse_block(
+ self, parser: "Parser", allow_pluralize: bool
+ ) -> t.Tuple[t.List[str], str]:
"""Parse until the next block tag with a given name."""
referenced = []
buf = []
- while 1:
+
+ while True:
if parser.stream.current.type == "data":
buf.append(parser.stream.current.value.replace("%", "%%"))
next(parser.stream)
@@ -428,12 +508,21 @@ class InternationalizationExtension(Extension):
return referenced, concat(buf)
def _make_node(
- self, singular, plural, variables, plural_expr, vars_referenced, num_called_num
- ):
+ self,
+ singular: str,
+ plural: t.Optional[str],
+ variables: t.Dict[str, nodes.Expr],
+ plural_expr: t.Optional[nodes.Expr],
+ vars_referenced: bool,
+ num_called_num: bool,
+ ) -> nodes.Output:
"""Generates a useful node from the data provided."""
+ newstyle = self.environment.newstyle_gettext # type: ignore
+ node: nodes.Expr
+
# no variables referenced? no need to escape for old style
# gettext invocations only if there are vars.
- if not vars_referenced and not self.environment.newstyle_gettext:
+ if not vars_referenced and not newstyle:
singular = singular.replace("%%", "%")
if plural:
plural = plural.replace("%%", "%")
@@ -457,7 +546,7 @@ class InternationalizationExtension(Extension):
# in case newstyle gettext is used, the method is powerful
# enough to handle the variable expansion and autoescape
# handling itself
- if self.environment.newstyle_gettext:
+ if newstyle:
for key, value in variables.items():
# the function adds that later anyways in case num was
# called num, so just skip it.
@@ -490,7 +579,7 @@ class ExprStmtExtension(Extension):
tags = {"do"}
- def parse(self, parser):
+ def parse(self, parser: "Parser") -> nodes.ExprStmt:
node = nodes.ExprStmt(lineno=next(parser.stream).lineno)
node.node = parser.parse_tuple()
return node
@@ -501,7 +590,7 @@ class LoopControlExtension(Extension):
tags = {"break", "continue"}
- def parse(self, parser):
+ def parse(self, parser: "Parser") -> t.Union[nodes.Break, nodes.Continue]:
token = next(parser.stream)
if token.value == "break":
return nodes.Break(lineno=token.lineno)
@@ -509,7 +598,7 @@ class LoopControlExtension(Extension):
class WithExtension(Extension):
- def __init__(self, environment):
+ def __init__(self, environment: Environment) -> None:
super().__init__(environment)
warnings.warn(
"The 'with' extension is deprecated and will be removed in"
@@ -520,7 +609,7 @@ class WithExtension(Extension):
class AutoEscapeExtension(Extension):
- def __init__(self, environment):
+ def __init__(self, environment: Environment) -> None:
super().__init__(environment)
warnings.warn(
"The 'autoescape' extension is deprecated and will be"
@@ -553,13 +642,13 @@ class DebugExtension(Extension):
tags = {"debug"}
- def parse(self, parser):
+ def parse(self, parser: "Parser") -> nodes.Output:
lineno = parser.stream.expect("name:debug").lineno
- context = ContextReference()
+ context = nodes.ContextReference()
result = self.call_method("_render", [context], lineno=lineno)
return nodes.Output([result], lineno=lineno)
- def _render(self, context):
+ def _render(self, context: Context) -> str:
result = {
"context": context.get_all(),
"filters": sorted(self.environment.filters.keys()),
@@ -567,13 +656,16 @@ class DebugExtension(Extension):
}
# Set the depth since the intent is to show the top few names.
- if version_info[:2] >= (3, 4):
- return pprint.pformat(result, depth=3, compact=True)
- else:
- return pprint.pformat(result, depth=3)
+ return pprint.pformat(result, depth=3, compact=True)
-def extract_from_ast(node, gettext_functions=GETTEXT_FUNCTIONS, babel_style=True):
+def extract_from_ast(
+ ast: nodes.Template,
+ gettext_functions: t.Sequence[str] = GETTEXT_FUNCTIONS,
+ babel_style: bool = True,
+) -> t.Iterator[
+ t.Tuple[int, str, t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]]]
+]:
"""Extract localizable strings from the given template node. Per
default this function returns matches in babel style that means non string
parameters as well as keyword arguments are returned as `None`. This
@@ -608,14 +700,17 @@ def extract_from_ast(node, gettext_functions=GETTEXT_FUNCTIONS, babel_style=True
to extract any comments. For comment support you have to use the babel
extraction interface or extract comments yourself.
"""
- for node in node.find_all(nodes.Call):
+ out: t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]]
+
+ for node in ast.find_all(nodes.Call):
if (
not isinstance(node.node, nodes.Name)
or node.node.name not in gettext_functions
):
continue
- strings = []
+ strings: t.List[t.Optional[str]] = []
+
for arg in node.args:
if isinstance(arg, nodes.Const) and isinstance(arg.value, str):
strings.append(arg.value)
@@ -630,15 +725,17 @@ def extract_from_ast(node, gettext_functions=GETTEXT_FUNCTIONS, babel_style=True
strings.append(None)
if not babel_style:
- strings = tuple(x for x in strings if x is not None)
- if not strings:
+ out = tuple(x for x in strings if x is not None)
+
+ if not out:
continue
else:
if len(strings) == 1:
- strings = strings[0]
+ out = strings[0]
else:
- strings = tuple(strings)
- yield node.lineno, node.node.name, strings
+ out = tuple(strings)
+
+ yield node.lineno, node.node.name, out
class _CommentFinder:
@@ -648,13 +745,15 @@ class _CommentFinder:
usable value.
"""
- def __init__(self, tokens, comment_tags):
+ def __init__(
+ self, tokens: t.Sequence[t.Tuple[int, str, str]], comment_tags: t.Sequence[str]
+ ) -> None:
self.tokens = tokens
self.comment_tags = comment_tags
self.offset = 0
self.last_lineno = 0
- def find_backwards(self, offset):
+ def find_backwards(self, offset: int) -> t.List[str]:
try:
for _, token_type, token_value in reversed(
self.tokens[self.offset : offset]
@@ -670,7 +769,7 @@ class _CommentFinder:
finally:
self.offset = offset
- def find_comments(self, lineno):
+ def find_comments(self, lineno: int) -> t.List[str]:
if not self.comment_tags or self.last_lineno > lineno:
return []
for idx, (token_lineno, _, _) in enumerate(self.tokens[self.offset :]):
@@ -679,7 +778,16 @@ class _CommentFinder:
return self.find_backwards(len(self.tokens))
-def babel_extract(fileobj, keywords, comment_tags, options):
+def babel_extract(
+ fileobj: t.BinaryIO,
+ keywords: t.Sequence[str],
+ comment_tags: t.Sequence[str],
+ options: t.Dict[str, t.Any],
+) -> t.Iterator[
+ t.Tuple[
+ int, str, t.Union[t.Optional[str], t.Tuple[t.Optional[str], ...]], t.List[str]
+ ]
+]:
"""Babel extraction method for Jinja templates.
.. versionchanged:: 2.3
@@ -707,33 +815,37 @@ def babel_extract(fileobj, keywords, comment_tags, options):
:return: an iterator over ``(lineno, funcname, message, comments)`` tuples.
(comments will be empty currently)
"""
- extensions = set()
- for extension in options.get("extensions", "").split(","):
- extension = extension.strip()
- if not extension:
+ extensions: t.Dict[t.Type[Extension], None] = {}
+
+ for extension_name in options.get("extensions", "").split(","):
+ extension_name = extension_name.strip()
+
+ if not extension_name:
continue
- extensions.add(import_string(extension))
+
+ extensions[import_string(extension_name)] = None
+
if InternationalizationExtension not in extensions:
- extensions.add(InternationalizationExtension)
+ extensions[InternationalizationExtension] = None
- def getbool(options, key, default=False):
- return options.get(key, str(default)).lower() in ("1", "on", "yes", "true")
+ def getbool(options: t.Mapping[str, str], key: str, default: bool = False) -> bool:
+ return options.get(key, str(default)).lower() in {"1", "on", "yes", "true"}
silent = getbool(options, "silent", True)
environment = Environment(
- options.get("block_start_string", BLOCK_START_STRING),
- options.get("block_end_string", BLOCK_END_STRING),
- options.get("variable_start_string", VARIABLE_START_STRING),
- options.get("variable_end_string", VARIABLE_END_STRING),
- options.get("comment_start_string", COMMENT_START_STRING),
- options.get("comment_end_string", COMMENT_END_STRING),
- options.get("line_statement_prefix") or LINE_STATEMENT_PREFIX,
- options.get("line_comment_prefix") or LINE_COMMENT_PREFIX,
- getbool(options, "trim_blocks", TRIM_BLOCKS),
- getbool(options, "lstrip_blocks", LSTRIP_BLOCKS),
- NEWLINE_SEQUENCE,
- getbool(options, "keep_trailing_newline", KEEP_TRAILING_NEWLINE),
- frozenset(extensions),
+ options.get("block_start_string", defaults.BLOCK_START_STRING),
+ options.get("block_end_string", defaults.BLOCK_END_STRING),
+ options.get("variable_start_string", defaults.VARIABLE_START_STRING),
+ options.get("variable_end_string", defaults.VARIABLE_END_STRING),
+ options.get("comment_start_string", defaults.COMMENT_START_STRING),
+ options.get("comment_end_string", defaults.COMMENT_END_STRING),
+ options.get("line_statement_prefix") or defaults.LINE_STATEMENT_PREFIX,
+ options.get("line_comment_prefix") or defaults.LINE_COMMENT_PREFIX,
+ getbool(options, "trim_blocks", defaults.TRIM_BLOCKS),
+ getbool(options, "lstrip_blocks", defaults.LSTRIP_BLOCKS),
+ defaults.NEWLINE_SEQUENCE,
+ getbool(options, "keep_trailing_newline", defaults.KEEP_TRAILING_NEWLINE),
+ tuple(extensions),
cache_size=0,
auto_reload=False,
)
@@ -741,7 +853,7 @@ def babel_extract(fileobj, keywords, comment_tags, options):
if getbool(options, "trimmed"):
environment.policies["ext.i18n.trimmed"] = True
if getbool(options, "newstyle_gettext"):
- environment.newstyle_gettext = True
+ environment.newstyle_gettext = True # type: ignore
source = fileobj.read().decode(options.get("encoding", "utf-8"))
try:
diff --git a/src/jinja2/filters.py b/src/jinja2/filters.py
index 2684683..c8d41da 100644
--- a/src/jinja2/filters.py
+++ b/src/jinja2/filters.py
@@ -34,15 +34,17 @@ if t.TYPE_CHECKING:
from .runtime import Context
from .sandbox import SandboxedEnvironment # noqa: F401
- K = t.TypeVar("K")
- V = t.TypeVar("V")
-
class HasHTML(te.Protocol):
def __html__(self) -> str:
pass
-def contextfilter(f):
+F = t.TypeVar("F", bound=t.Callable[..., t.Any])
+K = t.TypeVar("K")
+V = t.TypeVar("V")
+
+
+def contextfilter(f: F) -> F:
"""Pass the context as the first argument to the decorated function.
.. deprecated:: 3.0
@@ -58,7 +60,7 @@ def contextfilter(f):
return pass_context(f)
-def evalcontextfilter(f):
+def evalcontextfilter(f: F) -> F:
"""Pass the eval context as the first argument to the decorated
function.
@@ -77,7 +79,7 @@ def evalcontextfilter(f):
return pass_eval_context(f)
-def environmentfilter(f):
+def environmentfilter(f: F) -> F:
"""Pass the environment as the first argument to the decorated
function.
@@ -94,11 +96,11 @@ def environmentfilter(f):
return pass_environment(f)
-def ignore_case(value: "V") -> "V":
+def ignore_case(value: V) -> V:
"""For use as a postprocessor for :func:`make_attrgetter`. Converts strings
to lowercase and returns other types as-is."""
if isinstance(value, str):
- return t.cast("V", value.lower())
+ return t.cast(V, value.lower())
return value
@@ -334,11 +336,11 @@ def do_title(s: str) -> str:
def do_dictsort(
- value: "t.Mapping[K, V]",
+ value: t.Mapping[K, V],
case_sensitive: bool = False,
by: 'te.Literal["key", "value"]' = "key",
reverse: bool = False,
-) -> "t.List[t.Tuple[K, V]]":
+) -> t.List[t.Tuple[K, V]]:
"""Sort a dict and yield (key, value) pairs. Python dicts may not
be in the order you want to display them in, so sort them first.
@@ -363,7 +365,7 @@ def do_dictsort(
else:
raise FilterArgumentError('You can only sort by either "key" or "value"')
- def sort_func(item):
+ def sort_func(item: t.Tuple[t.Any, t.Any]) -> t.Any:
value = item[pos]
if not case_sensitive:
@@ -524,10 +526,10 @@ def do_max(
def do_default(
- value: "V",
- default_value: "V" = "", # type: ignore
+ value: V,
+ default_value: V = "", # type: ignore
boolean: bool = False,
-) -> "V":
+) -> V:
"""If the value is undefined it will return the passed default value,
otherwise the value of the variable:
@@ -614,7 +616,7 @@ def sync_do_join(
return soft_str(d).join(map(soft_str, value))
-@async_variant(sync_do_join)
+@async_variant(sync_do_join) # type: ignore
async def do_join(
eval_ctx: "EvalContext",
value: t.Union[t.AsyncIterable, t.Iterable],
@@ -640,12 +642,12 @@ def sync_do_first(
return environment.undefined("No first item, sequence was empty.")
-@async_variant(sync_do_first)
+@async_variant(sync_do_first) # type: ignore
async def do_first(
environment: "Environment", seq: "t.Union[t.AsyncIterable[V], t.Iterable[V]]"
) -> "t.Union[V, Undefined]":
try:
- return t.cast("V", await auto_aiter(seq).__anext__())
+ return await auto_aiter(seq).__anext__()
except StopAsyncIteration:
return environment.undefined("No first item, sequence was empty.")
@@ -716,7 +718,7 @@ def do_filesizeformat(value: t.Union[str, float, int], binary: bool = False) ->
def do_pprint(value: t.Any) -> str:
"""Pretty print a variable. Useful for debugging."""
- return t.cast(str, pformat(value))
+ return pformat(value)
_uri_scheme_re = re.compile(r"^([\w.+-]{2,}:(/){0,2})$")
@@ -1079,7 +1081,7 @@ def sync_do_slice(
yield tmp
-@async_variant(sync_do_slice)
+@async_variant(sync_do_slice) # type: ignore
async def do_slice(
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
slices: int,
@@ -1171,10 +1173,10 @@ class _GroupTuple(t.NamedTuple):
# Use the regular tuple repr to hide this subclass if users print
# out the value during debugging.
- def __repr__(self):
+ def __repr__(self) -> str:
return tuple.__repr__(self)
- def __str__(self):
+ def __str__(self) -> str:
return tuple.__str__(self)
@@ -1237,7 +1239,7 @@ def sync_do_groupby(
]
-@async_variant(sync_do_groupby)
+@async_variant(sync_do_groupby) # type: ignore
async def do_groupby(
environment: "Environment",
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
@@ -1256,8 +1258,8 @@ def sync_do_sum(
environment: "Environment",
iterable: "t.Iterable[V]",
attribute: t.Optional[t.Union[str, int]] = None,
- start: "V" = 0, # type: ignore
-) -> "V":
+ start: V = 0, # type: ignore
+) -> V:
"""Returns the sum of a sequence of numbers plus the value of parameter
'start' (which defaults to 0). When the sequence is empty it returns
start.
@@ -1278,20 +1280,20 @@ def sync_do_sum(
return sum(iterable, start)
-@async_variant(sync_do_sum)
+@async_variant(sync_do_sum) # type: ignore
async def do_sum(
environment: "Environment",
iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
attribute: t.Optional[t.Union[str, int]] = None,
- start: "V" = 0, # type: ignore
-) -> "V":
+ start: V = 0, # type: ignore
+) -> V:
rv = start
if attribute is not None:
func = make_attrgetter(environment, attribute)
else:
- def func(x):
+ def func(x: V) -> V:
return x
async for item in auto_aiter(iterable):
@@ -1307,7 +1309,7 @@ def sync_do_list(value: "t.Iterable[V]") -> "t.List[V]":
return list(value)
-@async_variant(sync_do_list)
+@async_variant(sync_do_list) # type: ignore
async def do_list(value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]") -> "t.List[V]":
return await auto_to_list(value)
@@ -1334,7 +1336,7 @@ def do_reverse(value: "t.Iterable[V]") -> "t.Iterable[V]":
...
-def do_reverse(value):
+def do_reverse(value: t.Union[str, t.Iterable[V]]) -> t.Union[str, t.Iterable[V]]:
"""Reverse the object or return an iterator that iterates over it the other
way round.
"""
@@ -1342,7 +1344,7 @@ def do_reverse(value):
return value[::-1]
try:
- return reversed(value)
+ return reversed(value) # type: ignore
except TypeError:
try:
rv = list(value)
@@ -1402,7 +1404,9 @@ def sync_do_map(
@pass_context
-def sync_do_map(context, value, *args, **kwargs):
+def sync_do_map(
+ context: "Context", value: t.Iterable, *args: t.Any, **kwargs: t.Any
+) -> t.Iterable:
"""Applies a filter on a sequence of objects or looks up an attribute.
This is useful when dealing with lists of objects but you are really
only interested in a certain value of it.
@@ -1471,8 +1475,13 @@ def do_map(
...
-@async_variant(sync_do_map)
-async def do_map(context, value, *args, **kwargs):
+@async_variant(sync_do_map) # type: ignore
+async def do_map(
+ context: "Context",
+ value: t.Union[t.AsyncIterable, t.Iterable],
+ *args: t.Any,
+ **kwargs: t.Any,
+) -> t.AsyncIterable:
if value:
func = prepare_map(context, args, kwargs)
@@ -1511,7 +1520,7 @@ def sync_do_select(
return select_or_reject(context, value, args, kwargs, lambda x: x, False)
-@async_variant(sync_do_select)
+@async_variant(sync_do_select) # type: ignore
async def do_select(
context: "Context",
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
@@ -1547,7 +1556,7 @@ def sync_do_reject(
return select_or_reject(context, value, args, kwargs, lambda x: not x, False)
-@async_variant(sync_do_reject)
+@async_variant(sync_do_reject) # type: ignore
async def do_reject(
context: "Context",
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
@@ -1587,7 +1596,7 @@ def sync_do_selectattr(
return select_or_reject(context, value, args, kwargs, lambda x: x, True)
-@async_variant(sync_do_selectattr)
+@async_variant(sync_do_selectattr) # type: ignore
async def do_selectattr(
context: "Context",
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
@@ -1625,7 +1634,7 @@ def sync_do_rejectattr(
return select_or_reject(context, value, args, kwargs, lambda x: not x, True)
-@async_variant(sync_do_rejectattr)
+@async_variant(sync_do_rejectattr) # type: ignore
async def do_rejectattr(
context: "Context",
value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
@@ -1684,7 +1693,7 @@ def prepare_map(
except LookupError:
raise FilterArgumentError("map requires a filter argument")
- def func(item):
+ def func(item: t.Any) -> t.Any:
return context.environment.call_filter(
name, item, args, kwargs, context=context
)
@@ -1710,18 +1719,18 @@ def prepare_select_or_reject(
else:
off = 0
- def transfunc(x):
+ def transfunc(x: V) -> V:
return x
try:
name = args[off]
args = args[1 + off :]
- def func(item):
+ def func(item: t.Any) -> t.Any:
return context.environment.call_test(name, item, args, kwargs)
except LookupError:
- func = bool
+ func = bool # type: ignore
return lambda item: modfunc(func(transfunc(item)))
diff --git a/src/jinja2/idtracking.py b/src/jinja2/idtracking.py
index 78cad91..a2e7a05 100644
--- a/src/jinja2/idtracking.py
+++ b/src/jinja2/idtracking.py
@@ -1,3 +1,6 @@
+import typing as t
+
+from . import nodes
from .visitor import NodeVisitor
VAR_LOAD_PARAMETER = "param"
@@ -6,7 +9,9 @@ VAR_LOAD_ALIAS = "alias"
VAR_LOAD_UNDEFINED = "undefined"
-def find_symbols(nodes, parent_symbols=None):
+def find_symbols(
+ nodes: t.Iterable[nodes.Node], parent_symbols: t.Optional["Symbols"] = None
+) -> "Symbols":
sym = Symbols(parent=parent_symbols)
visitor = FrameSymbolVisitor(sym)
for node in nodes:
@@ -14,49 +19,62 @@ def find_symbols(nodes, parent_symbols=None):
return sym
-def symbols_for_node(node, parent_symbols=None):
+def symbols_for_node(
+ node: nodes.Node, parent_symbols: t.Optional["Symbols"] = None
+) -> "Symbols":
sym = Symbols(parent=parent_symbols)
sym.analyze_node(node)
return sym
class Symbols:
- def __init__(self, parent=None, level=None):
+ def __init__(
+ self, parent: t.Optional["Symbols"] = None, level: t.Optional[int] = None
+ ) -> None:
if level is None:
if parent is None:
level = 0
else:
level = parent.level + 1
- self.level = level
+
+ self.level: int = level
self.parent = parent
- self.refs = {}
- self.loads = {}
- self.stores = set()
+ self.refs: t.Dict[str, str] = {}
+ self.loads: t.Dict[str, t.Any] = {}
+ self.stores: t.Set[str] = set()
- def analyze_node(self, node, **kwargs):
+ def analyze_node(self, node: nodes.Node, **kwargs: t.Any) -> None:
visitor = RootVisitor(self)
visitor.visit(node, **kwargs)
- def _define_ref(self, name, load=None):
+ def _define_ref(
+ self, name: str, load: t.Optional[t.Tuple[str, t.Optional[str]]] = None
+ ) -> str:
ident = f"l_{self.level}_{name}"
self.refs[name] = ident
if load is not None:
self.loads[ident] = load
return ident
- def find_load(self, target):
+ def find_load(self, target: str) -> t.Optional[t.Any]:
if target in self.loads:
return self.loads[target]
+
if self.parent is not None:
return self.parent.find_load(target)
- def find_ref(self, name):
+ return None
+
+ def find_ref(self, name: str) -> t.Optional[str]:
if name in self.refs:
return self.refs[name]
+
if self.parent is not None:
return self.parent.find_ref(name)
- def ref(self, name):
+ return None
+
+ def ref(self, name: str) -> str:
rv = self.find_ref(name)
if rv is None:
raise AssertionError(
@@ -65,15 +83,15 @@ class Symbols:
)
return rv
- def copy(self):
- rv = object.__new__(self.__class__)
+ def copy(self) -> "Symbols":
+ rv = t.cast(Symbols, object.__new__(self.__class__))
rv.__dict__.update(self.__dict__)
rv.refs = self.refs.copy()
rv.loads = self.loads.copy()
rv.stores = self.stores.copy()
return rv
- def store(self, name):
+ def store(self, name: str) -> None:
self.stores.add(name)
# If we have not see the name referenced yet, we need to figure
@@ -91,17 +109,16 @@ class Symbols:
# Otherwise we can just set it to undefined.
self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))
- def declare_parameter(self, name):
+ def declare_parameter(self, name: str) -> str:
self.stores.add(name)
return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))
- def load(self, name):
- target = self.find_ref(name)
- if target is None:
+ def load(self, name: str) -> None:
+ if self.find_ref(name) is None:
self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))
- def branch_update(self, branch_symbols):
- stores = {}
+ def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None:
+ stores: t.Dict[str, int] = {}
for branch in branch_symbols:
for target in branch.stores:
if target in self.stores:
@@ -116,7 +133,8 @@ class Symbols:
for name, branch_count in stores.items():
if branch_count == len(branch_symbols):
continue
- target = self.find_ref(name)
+
+ target = self.find_ref(name) # type: ignore
assert target is not None, "should not happen"
if self.parent is not None:
@@ -126,56 +144,64 @@ class Symbols:
continue
self.loads[target] = (VAR_LOAD_RESOLVE, name)
- def dump_stores(self):
- rv = {}
- node = self
+ def dump_stores(self) -> t.Dict[str, str]:
+ rv: t.Dict[str, str] = {}
+ node: t.Optional["Symbols"] = self
+
while node is not None:
for name in node.stores:
if name not in rv:
- rv[name] = self.find_ref(name)
+ rv[name] = self.find_ref(name) # type: ignore
+
node = node.parent
+
return rv
- def dump_param_targets(self):
+ def dump_param_targets(self) -> t.Set[str]:
rv = set()
- node = self
+ node: t.Optional["Symbols"] = self
+
while node is not None:
for target, (instr, _) in self.loads.items():
if instr == VAR_LOAD_PARAMETER:
rv.add(target)
+
node = node.parent
+
return rv
class RootVisitor(NodeVisitor):
- def __init__(self, symbols):
+ def __init__(self, symbols: "Symbols") -> None:
self.sym_visitor = FrameSymbolVisitor(symbols)
- def _simple_visit(self, node, **kwargs):
+ def _simple_visit(self, node: nodes.Node, **kwargs: t.Any) -> None:
for child in node.iter_child_nodes():
self.sym_visitor.visit(child)
- visit_Template = (
- visit_Block
- ) = (
- visit_Macro
- ) = (
- visit_FilterBlock
- ) = visit_Scope = visit_If = visit_ScopedEvalContextModifier = _simple_visit
+ visit_Template = _simple_visit
+ visit_Block = _simple_visit
+ visit_Macro = _simple_visit
+ visit_FilterBlock = _simple_visit
+ visit_Scope = _simple_visit
+ visit_If = _simple_visit
+ visit_ScopedEvalContextModifier = _simple_visit
- def visit_AssignBlock(self, node, **kwargs):
+ def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
for child in node.body:
self.sym_visitor.visit(child)
- def visit_CallBlock(self, node, **kwargs):
+ def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
for child in node.iter_child_nodes(exclude=("call",)):
self.sym_visitor.visit(child)
- def visit_OverlayScope(self, node, **kwargs):
+ def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
for child in node.body:
self.sym_visitor.visit(child)
- def visit_For(self, node, for_branch="body", **kwargs):
+ def visit_For(
+ self, node: nodes.For, for_branch: str = "body", **kwargs: t.Any
+ ) -> None:
if for_branch == "body":
self.sym_visitor.visit(node.target, store_as_param=True)
branch = node.body
@@ -188,28 +214,30 @@ class RootVisitor(NodeVisitor):
return
else:
raise RuntimeError("Unknown for branch")
- for item in branch or ():
- self.sym_visitor.visit(item)
- def visit_With(self, node, **kwargs):
+ if branch:
+ for item in branch:
+ self.sym_visitor.visit(item)
+
+ def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
for target in node.targets:
self.sym_visitor.visit(target)
for child in node.body:
self.sym_visitor.visit(child)
- def generic_visit(self, node, *args, **kwargs):
- raise NotImplementedError(
- f"Cannot find symbols for {node.__class__.__name__!r}"
- )
+ def generic_visit(self, node: nodes.Node, *args: t.Any, **kwargs: t.Any) -> None:
+ raise NotImplementedError(f"Cannot find symbols for {type(node).__name__!r}")
class FrameSymbolVisitor(NodeVisitor):
"""A visitor for `Frame.inspect`."""
- def __init__(self, symbols):
+ def __init__(self, symbols: "Symbols") -> None:
self.symbols = symbols
- def visit_Name(self, node, store_as_param=False, **kwargs):
+ def visit_Name(
+ self, node: nodes.Name, store_as_param: bool = False, **kwargs: t.Any
+ ) -> None:
"""All assignments to names go through this function."""
if store_as_param or node.ctx == "param":
self.symbols.declare_parameter(node.name)
@@ -218,72 +246,73 @@ class FrameSymbolVisitor(NodeVisitor):
elif node.ctx == "load":
self.symbols.load(node.name)
- def visit_NSRef(self, node, **kwargs):
+ def visit_NSRef(self, node: nodes.NSRef, **kwargs: t.Any) -> None:
self.symbols.load(node.name)
- def visit_If(self, node, **kwargs):
+ def visit_If(self, node: nodes.If, **kwargs: t.Any) -> None:
self.visit(node.test, **kwargs)
-
original_symbols = self.symbols
- def inner_visit(nodes):
+ def inner_visit(nodes: t.Iterable[nodes.Node]) -> "Symbols":
self.symbols = rv = original_symbols.copy()
+
for subnode in nodes:
self.visit(subnode, **kwargs)
+
self.symbols = original_symbols
return rv
body_symbols = inner_visit(node.body)
elif_symbols = inner_visit(node.elif_)
else_symbols = inner_visit(node.else_ or ())
-
self.symbols.branch_update([body_symbols, elif_symbols, else_symbols])
- def visit_Macro(self, node, **kwargs):
+ def visit_Macro(self, node: nodes.Macro, **kwargs: t.Any) -> None:
self.symbols.store(node.name)
- def visit_Import(self, node, **kwargs):
+ def visit_Import(self, node: nodes.Import, **kwargs: t.Any) -> None:
self.generic_visit(node, **kwargs)
self.symbols.store(node.target)
- def visit_FromImport(self, node, **kwargs):
+ def visit_FromImport(self, node: nodes.FromImport, **kwargs: t.Any) -> None:
self.generic_visit(node, **kwargs)
+
for name in node.names:
if isinstance(name, tuple):
self.symbols.store(name[1])
else:
self.symbols.store(name)
- def visit_Assign(self, node, **kwargs):
+ def visit_Assign(self, node: nodes.Assign, **kwargs: t.Any) -> None:
"""Visit assignments in the correct order."""
self.visit(node.node, **kwargs)
self.visit(node.target, **kwargs)
- def visit_For(self, node, **kwargs):
+ def visit_For(self, node: nodes.For, **kwargs: t.Any) -> None:
"""Visiting stops at for blocks. However the block sequence
is visited as part of the outer scope.
"""
self.visit(node.iter, **kwargs)
- def visit_CallBlock(self, node, **kwargs):
+ def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
self.visit(node.call, **kwargs)
- def visit_FilterBlock(self, node, **kwargs):
+ def visit_FilterBlock(self, node: nodes.FilterBlock, **kwargs: t.Any) -> None:
self.visit(node.filter, **kwargs)
- def visit_With(self, node, **kwargs):
+ def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
for target in node.values:
self.visit(target)
- def visit_AssignBlock(self, node, **kwargs):
+ def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
"""Stop visiting at block assigns."""
self.visit(node.target, **kwargs)
- def visit_Scope(self, node, **kwargs):
+ def visit_Scope(self, node: nodes.Scope, **kwargs: t.Any) -> None:
"""Stop visiting at scopes."""
- def visit_Block(self, node, **kwargs):
+ def visit_Block(self, node: nodes.Block, **kwargs: t.Any) -> None:
"""Stop visiting at blocks."""
- def visit_OverlayScope(self, node, **kwargs):
+ def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
"""Do not visit into overlay scopes."""
diff --git a/src/jinja2/lexer.py b/src/jinja2/lexer.py
index 0cade7a..c151582 100644
--- a/src/jinja2/lexer.py
+++ b/src/jinja2/lexer.py
@@ -4,18 +4,21 @@ the bitshift operators we don't allow in templates. It separates
template code and python code in expressions.
"""
import re
+import typing as t
from ast import literal_eval
from collections import deque
-from operator import itemgetter
from sys import intern
from ._identifier import pattern as name_re
from .exceptions import TemplateSyntaxError
from .utils import LRUCache
+if t.TYPE_CHECKING:
+ from .environment import Environment
+
# cache for the lexers. Exists in order to be able to have multiple
# environments with the same lexer
-_lexer_cache = LRUCache(50)
+_lexer_cache: t.MutableMapping[t.Tuple, "Lexer"] = LRUCache(50) # type: ignore
# static regular expressions
whitespace_re = re.compile(r"\s+")
@@ -156,9 +159,10 @@ ignore_if_empty = frozenset(
)
-def _describe_token_type(token_type):
+def _describe_token_type(token_type: str) -> str:
if token_type in reverse_operators:
return reverse_operators[token_type]
+
return {
TOKEN_COMMENT_BEGIN: "begin of comment",
TOKEN_COMMENT_END: "end of comment",
@@ -175,32 +179,35 @@ def _describe_token_type(token_type):
}.get(token_type, token_type)
-def describe_token(token):
+def describe_token(token: "Token") -> str:
"""Returns a description of the token."""
if token.type == TOKEN_NAME:
return token.value
+
return _describe_token_type(token.type)
-def describe_token_expr(expr):
+def describe_token_expr(expr: str) -> str:
"""Like `describe_token` but for token expressions."""
if ":" in expr:
type, value = expr.split(":", 1)
+
if type == TOKEN_NAME:
return value
else:
type = expr
+
return _describe_token_type(type)
-def count_newlines(value):
+def count_newlines(value: str) -> int:
"""Count the number of newline characters in the string. This is
useful for extensions that filter a stream.
"""
return len(newline_re.findall(value))
-def compile_rules(environment):
+def compile_rules(environment: "Environment") -> t.List[t.Tuple[str, str]]:
"""Compiles all the rules from the environment into a list of rules."""
e = re.escape
rules = [
@@ -246,31 +253,25 @@ class Failure:
Used by the `Lexer` to specify known errors.
"""
- def __init__(self, message, cls=TemplateSyntaxError):
+ def __init__(
+ self, message: str, cls: t.Type[TemplateSyntaxError] = TemplateSyntaxError
+ ) -> None:
self.message = message
self.error_class = cls
- def __call__(self, lineno, filename):
+ def __call__(self, lineno: int, filename: str) -> t.NoReturn:
raise self.error_class(self.message, lineno, filename)
-class Token(tuple):
- """Token class."""
-
- __slots__ = ()
- lineno, type, value = (property(itemgetter(x)) for x in range(3))
-
- def __new__(cls, lineno, type, value):
- return tuple.__new__(cls, (lineno, intern(str(type)), value))
+class Token(t.NamedTuple):
+ lineno: int
+ type: str
+ value: str
- def __str__(self):
- if self.type in reverse_operators:
- return reverse_operators[self.type]
- elif self.type == "name":
- return self.value
- return self.type
+ def __str__(self) -> str:
+ return describe_token(self)
- def test(self, expr):
+ def test(self, expr: str) -> bool:
"""Test a token against a token expression. This can either be a
token type or ``'token_type:token_value'``. This can only test
against string values and types.
@@ -279,19 +280,15 @@ class Token(tuple):
# passed an iterable of not interned strings.
if self.type == expr:
return True
- elif ":" in expr:
+
+ if ":" in expr:
return expr.split(":", 1) == [self.type, self.value]
- return False
- def test_any(self, *iterable):
- """Test against multiple token expressions."""
- for expr in iterable:
- if self.test(expr):
- return True
return False
- def __repr__(self):
- return f"Token({self.lineno!r}, {self.type!r}, {self.value!r})"
+ def test_any(self, *iterable: str) -> bool:
+ """Test against multiple token expressions."""
+ return any(self.test(expr) for expr in iterable)
class TokenStreamIterator:
@@ -299,17 +296,19 @@ class TokenStreamIterator:
until the eof token is reached.
"""
- def __init__(self, stream):
+ def __init__(self, stream: "TokenStream") -> None:
self.stream = stream
- def __iter__(self):
+ def __iter__(self) -> "TokenStreamIterator":
return self
- def __next__(self):
+ def __next__(self) -> Token:
token = self.stream.current
+
if token.type is TOKEN_EOF:
self.stream.close()
- raise StopIteration()
+ raise StopIteration
+
next(self.stream)
return token
@@ -320,33 +319,36 @@ class TokenStream:
one token ahead. The current active token is stored as :attr:`current`.
"""
- def __init__(self, generator, name, filename):
+ def __init__(
+ self,
+ generator: t.Iterable[Token],
+ name: t.Optional[str],
+ filename: t.Optional[str],
+ ):
self._iter = iter(generator)
- self._pushed = deque()
+ self._pushed: t.Deque[Token] = deque()
self.name = name
self.filename = filename
self.closed = False
self.current = Token(1, TOKEN_INITIAL, "")
next(self)
- def __iter__(self):
+ def __iter__(self) -> TokenStreamIterator:
return TokenStreamIterator(self)
- def __bool__(self):
+ def __bool__(self) -> bool:
return bool(self._pushed) or self.current.type is not TOKEN_EOF
- __nonzero__ = __bool__ # py2
-
@property
- def eos(self):
+ def eos(self) -> bool:
"""Are we at the end of the stream?"""
return not self
- def push(self, token):
+ def push(self, token: Token) -> None:
"""Push a token back to the stream."""
self._pushed.append(token)
- def look(self):
+ def look(self) -> Token:
"""Look at the next token."""
old_token = next(self)
result = self.current
@@ -354,28 +356,31 @@ class TokenStream:
self.current = old_token
return result
- def skip(self, n=1):
+ def skip(self, n: int = 1) -> None:
"""Got n tokens ahead."""
for _ in range(n):
next(self)
- def next_if(self, expr):
+ def next_if(self, expr: str) -> t.Optional[Token]:
"""Perform the token test and return the token if it matched.
Otherwise the return value is `None`.
"""
if self.current.test(expr):
return next(self)
- def skip_if(self, expr):
+ return None
+
+ def skip_if(self, expr: str) -> bool:
"""Like :meth:`next_if` but only returns `True` or `False`."""
return self.next_if(expr) is not None
- def __next__(self):
+ def __next__(self) -> Token:
"""Go one token ahead and return the old one.
Use the built-in :func:`next` instead of calling this directly.
"""
rv = self.current
+
if self._pushed:
self.current = self._pushed.popleft()
elif self.current.type is not TOKEN_EOF:
@@ -383,20 +388,22 @@ class TokenStream:
self.current = next(self._iter)
except StopIteration:
self.close()
+
return rv
- def close(self):
+ def close(self) -> None:
"""Close the stream."""
self.current = Token(self.current.lineno, TOKEN_EOF, "")
- self._iter = None
+ self._iter = iter(())
self.closed = True
- def expect(self, expr):
+ def expect(self, expr: str) -> Token:
"""Expect a given token type and return it. This accepts the same
argument as :meth:`jinja2.lexer.Token.test`.
"""
if not self.current.test(expr):
expr = describe_token_expr(expr)
+
if self.current.type is TOKEN_EOF:
raise TemplateSyntaxError(
f"unexpected end of template, expected {expr!r}.",
@@ -404,19 +411,18 @@ class TokenStream:
self.name,
self.filename,
)
+
raise TemplateSyntaxError(
f"expected token {expr!r}, got {describe_token(self.current)!r}",
self.current.lineno,
self.name,
self.filename,
)
- try:
- return self.current
- finally:
- next(self)
+ return next(self)
-def get_lexer(environment):
+
+def get_lexer(environment: "Environment") -> "Lexer":
"""Return a lexer which is probably cached."""
key = (
environment.block_start_string,
@@ -433,9 +439,10 @@ def get_lexer(environment):
environment.keep_trailing_newline,
)
lexer = _lexer_cache.get(key)
+
if lexer is None:
- lexer = Lexer(environment)
- _lexer_cache[key] = lexer
+ _lexer_cache[key] = lexer = Lexer(environment)
+
return lexer
@@ -448,10 +455,16 @@ class OptionalLStrip(tuple):
# Even though it looks like a no-op, creating instances fails
# without this.
- def __new__(cls, *members, **kwargs):
+ def __new__(cls, *members, **kwargs): # type: ignore
return super().__new__(cls, members)
+class _Rule(t.NamedTuple):
+ pattern: t.Pattern[str]
+ tokens: t.Union[str, t.Tuple[str, ...], t.Tuple[Failure]]
+ command: t.Optional[str]
+
+
class Lexer:
"""Class that implements a lexer for a given environment. Automatically
created by the environment class, usually you don't have to do that.
@@ -460,21 +473,21 @@ class Lexer:
Multiple environments can share the same lexer.
"""
- def __init__(self, environment):
+ def __init__(self, environment: "Environment") -> None:
# shortcuts
e = re.escape
- def c(x):
+ def c(x: str) -> t.Pattern[str]:
return re.compile(x, re.M | re.S)
# lexing rules for tags
- tag_rules = [
- (whitespace_re, TOKEN_WHITESPACE, None),
- (float_re, TOKEN_FLOAT, None),
- (integer_re, TOKEN_INTEGER, None),
- (name_re, TOKEN_NAME, None),
- (string_re, TOKEN_STRING, None),
- (operator_re, TOKEN_OPERATOR, None),
+ tag_rules: t.List[_Rule] = [
+ _Rule(whitespace_re, TOKEN_WHITESPACE, None),
+ _Rule(float_re, TOKEN_FLOAT, None),
+ _Rule(integer_re, TOKEN_INTEGER, None),
+ _Rule(name_re, TOKEN_NAME, None),
+ _Rule(string_re, TOKEN_STRING, None),
+ _Rule(operator_re, TOKEN_OPERATOR, None),
]
# assemble the root lexing rule. because "|" is ungreedy
@@ -509,20 +522,20 @@ class Lexer:
)
# global lexing rules
- self.rules = {
+ self.rules: t.Dict[str, t.List[_Rule]] = {
"root": [
# directives
- (
+ _Rule(
c(fr"(.*?)(?:{root_parts_re})"),
- OptionalLStrip(TOKEN_DATA, "#bygroup"),
+ OptionalLStrip(TOKEN_DATA, "#bygroup"), # type: ignore
"#bygroup",
),
# data
- (c(".+"), TOKEN_DATA, None),
+ _Rule(c(".+"), TOKEN_DATA, None),
],
# comments
TOKEN_COMMENT_BEGIN: [
- (
+ _Rule(
c(
fr"(.*?)((?:\+{comment_end_re}|\-{comment_end_re}\s*"
fr"|{comment_end_re}{block_suffix_re}))"
@@ -530,11 +543,11 @@ class Lexer:
(TOKEN_COMMENT, TOKEN_COMMENT_END),
"#pop",
),
- (c(r"(.)"), (Failure("Missing end of comment tag"),), None),
+ _Rule(c(r"(.)"), (Failure("Missing end of comment tag"),), None),
],
# blocks
TOKEN_BLOCK_BEGIN: [
- (
+ _Rule(
c(
fr"(?:\+{block_end_re}|\-{block_end_re}\s*"
fr"|{block_end_re}{block_suffix_re})"
@@ -546,7 +559,7 @@ class Lexer:
+ tag_rules,
# variables
TOKEN_VARIABLE_BEGIN: [
- (
+ _Rule(
c(fr"\-{variable_end_re}\s*|{variable_end_re}"),
TOKEN_VARIABLE_END,
"#pop",
@@ -555,25 +568,25 @@ class Lexer:
+ tag_rules,
# raw block
TOKEN_RAW_BEGIN: [
- (
+ _Rule(
c(
fr"(.*?)((?:{block_start_re}(\-|\+|))\s*endraw\s*"
fr"(?:\+{block_end_re}|\-{block_end_re}\s*"
fr"|{block_end_re}{block_suffix_re}))"
),
- OptionalLStrip(TOKEN_DATA, TOKEN_RAW_END),
+ OptionalLStrip(TOKEN_DATA, TOKEN_RAW_END), # type: ignore
"#pop",
),
- (c(r"(.)"), (Failure("Missing end of raw directive"),), None),
+ _Rule(c(r"(.)"), (Failure("Missing end of raw directive"),), None),
],
# line statements
TOKEN_LINESTATEMENT_BEGIN: [
- (c(r"\s*(\n|$)"), TOKEN_LINESTATEMENT_END, "#pop")
+ _Rule(c(r"\s*(\n|$)"), TOKEN_LINESTATEMENT_END, "#pop")
]
+ tag_rules,
# line comments
TOKEN_LINECOMMENT_BEGIN: [
- (
+ _Rule(
c(r"(.*?)()(?=\n|$)"),
(TOKEN_LINECOMMENT, TOKEN_LINECOMMENT_END),
"#pop",
@@ -581,25 +594,39 @@ class Lexer:
],
}
- def _normalize_newlines(self, value):
+ def _normalize_newlines(self, value: str) -> str:
"""Replace all newlines with the configured sequence in strings
and template data.
"""
return newline_re.sub(self.newline_sequence, value)
- def tokenize(self, source, name=None, filename=None, state=None):
+ def tokenize(
+ self,
+ source: str,
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ state: t.Optional[str] = None,
+ ) -> TokenStream:
"""Calls tokeniter + tokenize and wraps it in a token stream."""
stream = self.tokeniter(source, name, filename, state)
return TokenStream(self.wrap(stream, name, filename), name, filename)
- def wrap(self, stream, name=None, filename=None):
+ def wrap(
+ self,
+ stream: t.Iterable[t.Tuple[int, str, str]],
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ ) -> t.Iterator[Token]:
"""This is called with the stream as returned by `tokenize` and wraps
every token in a :class:`Token` and converts the value.
"""
- for lineno, token, value in stream:
+ for lineno, token, value_str in stream:
if token in ignored_tokens:
continue
- elif token == TOKEN_LINESTATEMENT_BEGIN:
+
+ value: t.Any = value_str
+
+ if token == TOKEN_LINESTATEMENT_BEGIN:
token = TOKEN_BLOCK_BEGIN
elif token == TOKEN_LINESTATEMENT_END:
token = TOKEN_BLOCK_END
@@ -607,11 +634,12 @@ class Lexer:
elif token in (TOKEN_RAW_BEGIN, TOKEN_RAW_END):
continue
elif token == TOKEN_DATA:
- value = self._normalize_newlines(value)
+ value = self._normalize_newlines(value_str)
elif token == "keyword":
- token = value
+ token = value_str
elif token == TOKEN_NAME:
- value = str(value)
+ value = value_str
+
if not value.isidentifier():
raise TemplateSyntaxError(
"Invalid character in identifier", lineno, name, filename
@@ -620,7 +648,7 @@ class Lexer:
# try to unescape string
try:
value = (
- self._normalize_newlines(value[1:-1])
+ self._normalize_newlines(value_str[1:-1])
.encode("ascii", "backslashreplace")
.decode("unicode-escape")
)
@@ -628,15 +656,22 @@ class Lexer:
msg = str(e).split(":")[-1].strip()
raise TemplateSyntaxError(msg, lineno, name, filename)
elif token == TOKEN_INTEGER:
- value = int(value.replace("_", ""), 0)
+ value = int(value_str.replace("_", ""), 0)
elif token == TOKEN_FLOAT:
# remove all "_" first to support more Python versions
- value = literal_eval(value.replace("_", ""))
+ value = literal_eval(value_str.replace("_", ""))
elif token == TOKEN_OPERATOR:
- token = operators[value]
+ token = operators[value_str]
+
yield Token(lineno, token, value)
- def tokeniter(self, source, name, filename=None, state=None):
+ def tokeniter(
+ self,
+ source: str,
+ name: t.Optional[str],
+ filename: t.Optional[str] = None,
+ state: t.Optional[str] = None,
+ ) -> t.Iterator[t.Tuple[int, str, str]]:
"""This method tokenizes the text and returns the tokens in a
generator. Use this method if you just want to tokenize a template.
@@ -653,20 +688,23 @@ class Lexer:
pos = 0
lineno = 1
stack = ["root"]
+
if state is not None and state != "root":
assert state in ("variable", "block"), "invalid state"
stack.append(state + "_begin")
+
statetokens = self.rules[stack[-1]]
source_length = len(source)
- balancing_stack = []
+ balancing_stack: t.List[str] = []
lstrip_unless_re = self.lstrip_unless_re
newlines_stripped = 0
line_starting = True
- while 1:
+ while True:
# tokenizer loop
for regex, tokens, new_state in statetokens:
m = regex.match(source, pos)
+
# if no match we try again with the next rule
if m is None:
continue
@@ -690,7 +728,6 @@ class Lexer:
# Rule supports lstrip. Match will look like
# text, block type, whitespace control, type, control, ...
text = groups[0]
-
# Skipping the text and first type, every other group is the
# whitespace control for each type. One of the groups will be
# -, +, or empty string instead of None.
@@ -700,7 +737,7 @@ class Lexer:
# Strip all whitespace between the text and the tag.
stripped = text.rstrip()
newlines_stripped = text[len(stripped) :].count("\n")
- groups = (stripped,) + groups[1:]
+ groups = [stripped, *groups[1:]]
elif (
# Not marked for preserving whitespace.
strip_sign != "+"
@@ -711,11 +748,12 @@ class Lexer:
):
# The start of text between the last newline and the tag.
l_pos = text.rfind("\n") + 1
+
if l_pos > 0 or line_starting:
# If there's only whitespace between the newline and the
# tag, strip it.
if not lstrip_unless_re.search(text, l_pos):
- groups = (text[:l_pos],) + groups[1:]
+ groups = [text[:l_pos], *groups[1:]]
for idx, token in enumerate(tokens):
# failure group
@@ -738,14 +776,17 @@ class Lexer:
# normal group
else:
data = groups[idx]
+
if data or token not in ignore_if_empty:
yield lineno, token, data
+
lineno += data.count("\n") + newlines_stripped
newlines_stripped = 0
# strings as token just are yielded as it.
else:
data = m.group()
+
# update brace/parentheses balance
if tokens == TOKEN_OPERATOR:
if data == "{":
@@ -759,7 +800,9 @@ class Lexer:
raise TemplateSyntaxError(
f"unexpected '{data}'", lineno, name, filename
)
+
expected_op = balancing_stack.pop()
+
if expected_op != data:
raise TemplateSyntaxError(
f"unexpected '{data}', expected '{expected_op}'",
@@ -767,13 +810,14 @@ class Lexer:
name,
filename,
)
+
# yield items
if data or tokens not in ignore_if_empty:
yield lineno, tokens, data
+
lineno += data.count("\n")
line_starting = m.group()[-1:] == "\n"
-
# fetch new position into new variable so that we can check
# if there is a internal parsing error which would result
# in an infinite loop
@@ -798,6 +842,7 @@ class Lexer:
# direct state name given
else:
stack.append(new_state)
+
statetokens = self.rules[stack[-1]]
# we are still at the same position and no stack change.
# this means a loop without break condition, avoid that and
@@ -806,6 +851,7 @@ class Lexer:
raise RuntimeError(
f"{regex!r} yielded empty string without stack change"
)
+
# publish new function and start again
pos = pos2
break
@@ -815,6 +861,7 @@ class Lexer:
# end of text
if pos >= source_length:
return
+
# something went wrong
raise TemplateSyntaxError(
f"unexpected char {source[pos]!r} at {pos}", lineno, name, filename
diff --git a/src/jinja2/loaders.py b/src/jinja2/loaders.py
index 6b71b83..bde6a1c 100644
--- a/src/jinja2/loaders.py
+++ b/src/jinja2/loaders.py
@@ -4,6 +4,7 @@ sources.
import importlib.util
import os
import sys
+import typing as t
import weakref
import zipimport
from collections import abc
@@ -15,8 +16,12 @@ from .exceptions import TemplateNotFound
from .utils import internalcode
from .utils import open_if_exists
+if t.TYPE_CHECKING:
+ from .environment import Environment
+ from .environment import Template
-def split_template_path(template):
+
+def split_template_path(template: str) -> t.List[str]:
"""Split a path into segments and perform a sanity check. If it detects
'..' in the path it will raise a `TemplateNotFound` error.
"""
@@ -66,7 +71,9 @@ class BaseLoader:
#: .. versionadded:: 2.4
has_source_access = True
- def get_source(self, environment, template):
+ def get_source(
+ self, environment: "Environment", template: str
+ ) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
"""Get the template source, filename and reload helper for a template.
It's passed the environment and template name and has to return a
tuple in the form ``(source, filename, uptodate)`` or raise a
@@ -86,18 +93,23 @@ class BaseLoader:
"""
if not self.has_source_access:
raise RuntimeError(
- f"{self.__class__.__name__} cannot provide access to the source"
+ f"{type(self).__name__} cannot provide access to the source"
)
raise TemplateNotFound(template)
- def list_templates(self):
+ def list_templates(self) -> t.List[str]:
"""Iterates over all templates. If the loader does not support that
it should raise a :exc:`TypeError` which is the default behavior.
"""
raise TypeError("this loader cannot iterate over all templates")
@internalcode
- def load(self, environment, name, globals=None):
+ def load(
+ self,
+ environment: "Environment",
+ name: str,
+ globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
+ ) -> "Template":
"""Loads a template. This method looks up the template in the cache
or loads one by calling :meth:`get_source`. Subclasses should not
override this method as loaders working on collections of other
@@ -163,15 +175,22 @@ class FileSystemLoader(BaseLoader):
Added the ``followlinks`` parameter.
"""
- def __init__(self, searchpath, encoding="utf-8", followlinks=False):
+ def __init__(
+ self,
+ searchpath: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]],
+ encoding: str = "utf-8",
+ followlinks: bool = False,
+ ) -> None:
if not isinstance(searchpath, abc.Iterable) or isinstance(searchpath, str):
searchpath = [searchpath]
- self.searchpath = list(searchpath)
+ self.searchpath = [os.fspath(p) for p in searchpath]
self.encoding = encoding
self.followlinks = followlinks
- def get_source(self, environment, template):
+ def get_source(
+ self, environment: "Environment", template: str
+ ) -> t.Tuple[str, str, t.Callable[[], bool]]:
pieces = split_template_path(template)
for searchpath in self.searchpath:
filename = os.path.join(searchpath, *pieces)
@@ -185,7 +204,7 @@ class FileSystemLoader(BaseLoader):
mtime = os.path.getmtime(filename)
- def uptodate():
+ def uptodate() -> bool:
try:
return os.path.getmtime(filename) == mtime
except OSError:
@@ -194,7 +213,7 @@ class FileSystemLoader(BaseLoader):
return contents, filename, uptodate
raise TemplateNotFound(template)
- def list_templates(self):
+ def list_templates(self) -> t.List[str]:
found = set()
for searchpath in self.searchpath:
walk_dir = os.walk(searchpath, followlinks=self.followlinks)
@@ -245,7 +264,12 @@ class PackageLoader(BaseLoader):
Limited PEP 420 namespace package support.
"""
- def __init__(self, package_name, package_path="templates", encoding="utf-8"):
+ def __init__(
+ self,
+ package_name: str,
+ package_path: "str" = "templates",
+ encoding: str = "utf-8",
+ ) -> None:
if package_path == os.path.curdir:
package_path = ""
elif package_path[:2] == os.path.curdir + os.path.sep:
@@ -260,14 +284,17 @@ class PackageLoader(BaseLoader):
# packages work, otherwise get_loader returns None.
import_module(package_name)
spec = importlib.util.find_spec(package_name)
- self._loader = loader = spec.loader
+ assert spec is not None, "An import spec was not found for the package."
+ loader = spec.loader
+ assert loader is not None, "A loader was not found for the package."
+ self._loader = loader
self._archive = None
- self._template_root = None
+ template_root = None
if isinstance(loader, zipimport.zipimporter):
self._archive = loader.archive
- pkgdir = next(iter(spec.submodule_search_locations))
- self._template_root = os.path.join(pkgdir, package_path)
+ pkgdir = next(iter(spec.submodule_search_locations)) # type: ignore
+ template_root = os.path.join(pkgdir, package_path)
elif spec.submodule_search_locations:
# This will be one element for regular packages and multiple
# for namespace packages.
@@ -275,17 +302,22 @@ class PackageLoader(BaseLoader):
root = os.path.join(root, package_path)
if os.path.isdir(root):
- self._template_root = root
+ template_root = root
break
- if self._template_root is None:
+ if template_root is None:
raise ValueError(
f"The {package_name!r} package was not installed in a"
" way that PackageLoader understands."
)
- def get_source(self, environment, template):
+ self._template_root = template_root
+
+ def get_source(
+ self, environment: "Environment", template: str
+ ) -> t.Tuple[str, str, t.Optional[t.Callable[[], bool]]]:
p = os.path.join(self._template_root, *split_template_path(template))
+ up_to_date: t.Optional[t.Callable[[], bool]]
if self._archive is None:
# Package is a directory.
@@ -297,13 +329,13 @@ class PackageLoader(BaseLoader):
mtime = os.path.getmtime(p)
- def up_to_date():
+ def up_to_date() -> bool:
return os.path.isfile(p) and os.path.getmtime(p) == mtime
else:
# Package is a zip file.
try:
- source = self._loader.get_data(p)
+ source = self._loader.get_data(p) # type: ignore
except OSError:
raise TemplateNotFound(template)
@@ -314,8 +346,8 @@ class PackageLoader(BaseLoader):
return source.decode(self.encoding), p, up_to_date
- def list_templates(self):
- results = []
+ def list_templates(self) -> t.List[str]:
+ results: t.List[str] = []
if self._archive is None:
# Package is a directory.
@@ -341,7 +373,7 @@ class PackageLoader(BaseLoader):
)
offset = len(prefix)
- for name in self._loader._files.keys():
+ for name in self._loader._files.keys(): # type: ignore
# Find names under the templates directory that aren't directories.
if name.startswith(prefix) and name[-1] != os.path.sep:
results.append(name[offset:].replace(os.path.sep, "/"))
@@ -359,16 +391,18 @@ class DictLoader(BaseLoader):
Because auto reloading is rarely useful this is disabled per default.
"""
- def __init__(self, mapping):
+ def __init__(self, mapping: t.Mapping[str, str]) -> None:
self.mapping = mapping
- def get_source(self, environment, template):
+ def get_source(
+ self, environment: "Environment", template: str
+ ) -> t.Tuple[str, None, t.Callable[[], bool]]:
if template in self.mapping:
source = self.mapping[template]
return source, None, lambda: source == self.mapping.get(template)
raise TemplateNotFound(template)
- def list_templates(self):
+ def list_templates(self) -> t.List[str]:
return sorted(self.mapping)
@@ -390,15 +424,30 @@ class FunctionLoader(BaseLoader):
return value.
"""
- def __init__(self, load_func):
+ def __init__(
+ self,
+ load_func: t.Callable[
+ [str],
+ t.Optional[
+ t.Union[
+ str, t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]
+ ]
+ ],
+ ],
+ ) -> None:
self.load_func = load_func
- def get_source(self, environment, template):
+ def get_source(
+ self, environment: "Environment", template: str
+ ) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
rv = self.load_func(template)
+
if rv is None:
raise TemplateNotFound(template)
- elif isinstance(rv, str):
+
+ if isinstance(rv, str):
return rv, None, None
+
return rv
@@ -417,11 +466,13 @@ class PrefixLoader(BaseLoader):
by loading ``'app2/index.html'`` the file from the second.
"""
- def __init__(self, mapping, delimiter="/"):
+ def __init__(
+ self, mapping: t.Mapping[str, BaseLoader], delimiter: str = "/"
+ ) -> None:
self.mapping = mapping
self.delimiter = delimiter
- def get_loader(self, template):
+ def get_loader(self, template: str) -> t.Tuple[BaseLoader, str]:
try:
prefix, name = template.split(self.delimiter, 1)
loader = self.mapping[prefix]
@@ -429,7 +480,9 @@ class PrefixLoader(BaseLoader):
raise TemplateNotFound(template)
return loader, name
- def get_source(self, environment, template):
+ def get_source(
+ self, environment: "Environment", template: str
+ ) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
loader, name = self.get_loader(template)
try:
return loader.get_source(environment, name)
@@ -439,7 +492,12 @@ class PrefixLoader(BaseLoader):
raise TemplateNotFound(template)
@internalcode
- def load(self, environment, name, globals=None):
+ def load(
+ self,
+ environment: "Environment",
+ name: str,
+ globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
+ ) -> "Template":
loader, local_name = self.get_loader(name)
try:
return loader.load(environment, local_name, globals)
@@ -448,7 +506,7 @@ class PrefixLoader(BaseLoader):
# (the one that includes the prefix)
raise TemplateNotFound(name)
- def list_templates(self):
+ def list_templates(self) -> t.List[str]:
result = []
for prefix, loader in self.mapping.items():
for template in loader.list_templates():
@@ -470,10 +528,12 @@ class ChoiceLoader(BaseLoader):
from a different location.
"""
- def __init__(self, loaders):
+ def __init__(self, loaders: t.Sequence[BaseLoader]) -> None:
self.loaders = loaders
- def get_source(self, environment, template):
+ def get_source(
+ self, environment: "Environment", template: str
+ ) -> t.Tuple[str, t.Optional[str], t.Optional[t.Callable[[], bool]]]:
for loader in self.loaders:
try:
return loader.get_source(environment, template)
@@ -482,7 +542,12 @@ class ChoiceLoader(BaseLoader):
raise TemplateNotFound(template)
@internalcode
- def load(self, environment, name, globals=None):
+ def load(
+ self,
+ environment: "Environment",
+ name: str,
+ globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
+ ) -> "Template":
for loader in self.loaders:
try:
return loader.load(environment, name, globals)
@@ -490,7 +555,7 @@ class ChoiceLoader(BaseLoader):
pass
raise TemplateNotFound(name)
- def list_templates(self):
+ def list_templates(self) -> t.List[str]:
found = set()
for loader in self.loaders:
found.update(loader.list_templates())
@@ -516,7 +581,9 @@ class ModuleLoader(BaseLoader):
has_source_access = False
- def __init__(self, path):
+ def __init__(
+ self, path: t.Union[str, os.PathLike, t.Sequence[t.Union[str, os.PathLike]]]
+ ) -> None:
package_name = f"_jinja2_module_templates_{id(self):x}"
# create a fake module that looks for the templates in the
@@ -526,7 +593,7 @@ class ModuleLoader(BaseLoader):
if not isinstance(path, abc.Iterable) or isinstance(path, str):
path = [path]
- mod.__path__ = [os.fspath(p) for p in path]
+ mod.__path__ = [os.fspath(p) for p in path] # type: ignore
sys.modules[package_name] = weakref.proxy(
mod, lambda x: sys.modules.pop(package_name, None)
@@ -539,18 +606,24 @@ class ModuleLoader(BaseLoader):
self.package_name = package_name
@staticmethod
- def get_template_key(name):
+ def get_template_key(name: str) -> str:
return "tmpl_" + sha1(name.encode("utf-8")).hexdigest()
@staticmethod
- def get_module_filename(name):
+ def get_module_filename(name: str) -> str:
return ModuleLoader.get_template_key(name) + ".py"
@internalcode
- def load(self, environment, name, globals=None):
+ def load(
+ self,
+ environment: "Environment",
+ name: str,
+ globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
+ ) -> "Template":
key = self.get_template_key(name)
module = f"{self.package_name}.{key}"
mod = getattr(self.module, module, None)
+
if mod is None:
try:
mod = __import__(module, None, None, ["root"])
@@ -561,6 +634,9 @@ class ModuleLoader(BaseLoader):
# on the module object we have stored on the loader.
sys.modules.pop(module, None)
+ if globals is None:
+ globals = {}
+
return environment.template_class.from_module_dict(
environment, mod.__dict__, globals
)
diff --git a/src/jinja2/meta.py b/src/jinja2/meta.py
index 899e179..0057d6e 100644
--- a/src/jinja2/meta.py
+++ b/src/jinja2/meta.py
@@ -1,29 +1,36 @@
"""Functions that expose information about templates that might be
interesting for introspection.
"""
+import typing as t
+
from . import nodes
from .compiler import CodeGenerator
+from .compiler import Frame
+
+if t.TYPE_CHECKING:
+ from .environment import Environment
class TrackingCodeGenerator(CodeGenerator):
"""We abuse the code generator for introspection."""
- def __init__(self, environment):
- CodeGenerator.__init__(self, environment, "<introspection>", "<introspection>")
- self.undeclared_identifiers = set()
+ def __init__(self, environment: "Environment") -> None:
+ super().__init__(environment, "<introspection>", "<introspection>")
+ self.undeclared_identifiers: t.Set[str] = set()
- def write(self, x):
+ def write(self, x: str) -> None:
"""Don't write."""
- def enter_frame(self, frame):
+ def enter_frame(self, frame: Frame) -> None:
"""Remember all undeclared identifiers."""
- CodeGenerator.enter_frame(self, frame)
+ super().enter_frame(frame)
+
for _, (action, param) in frame.symbols.loads.items():
if action == "resolve" and param not in self.environment.globals:
self.undeclared_identifiers.add(param)
-def find_undeclared_variables(ast):
+def find_undeclared_variables(ast: nodes.Template) -> t.Set[str]:
"""Returns a set of all variables in the AST that will be looked up from
the context at runtime. Because at compile time it's not known which
variables will be used depending on the path the execution takes at
@@ -42,12 +49,16 @@ def find_undeclared_variables(ast):
:exc:`TemplateAssertionError` during compilation and as a matter of
fact this function can currently raise that exception as well.
"""
- codegen = TrackingCodeGenerator(ast.environment)
+ codegen = TrackingCodeGenerator(ast.environment) # type: ignore
codegen.visit(ast)
return codegen.undeclared_identifiers
-def find_referenced_templates(ast):
+_ref_types = (nodes.Extends, nodes.FromImport, nodes.Import, nodes.Include)
+_RefType = t.Union[nodes.Extends, nodes.FromImport, nodes.Import, nodes.Include]
+
+
+def find_referenced_templates(ast: nodes.Template) -> t.Iterator[t.Optional[str]]:
"""Finds all the referenced templates from the AST. This will return an
iterator over all the hardcoded template extensions, inclusions and
imports. If dynamic inheritance or inclusion is used, `None` will be
@@ -62,13 +73,15 @@ def find_referenced_templates(ast):
This function is useful for dependency tracking. For example if you want
to rebuild parts of the website after a layout template has changed.
"""
- for node in ast.find_all(
- (nodes.Extends, nodes.FromImport, nodes.Import, nodes.Include)
- ):
- if not isinstance(node.template, nodes.Const):
+ template_name: t.Any
+
+ for node in ast.find_all(_ref_types):
+ template: nodes.Expr = node.template # type: ignore
+
+ if not isinstance(template, nodes.Const):
# a tuple with some non consts in there
- if isinstance(node.template, (nodes.Tuple, nodes.List)):
- for template_name in node.template.items:
+ if isinstance(template, (nodes.Tuple, nodes.List)):
+ for template_name in template.items:
# something const, only yield the strings and ignore
# non-string consts that really just make no sense
if isinstance(template_name, nodes.Const):
@@ -82,15 +95,15 @@ def find_referenced_templates(ast):
yield None
continue
# constant is a basestring, direct template name
- if isinstance(node.template.value, str):
- yield node.template.value
+ if isinstance(template.value, str):
+ yield template.value
# a tuple or list (latter *should* not happen) made of consts,
# yield the consts that are strings. We could warn here for
# non string values
elif isinstance(node, nodes.Include) and isinstance(
- node.template.value, (tuple, list)
+ template.value, (tuple, list)
):
- for template_name in node.template.value:
+ for template_name in template.value:
if isinstance(template_name, str):
yield template_name
# something else we don't care about, we could warn here
diff --git a/src/jinja2/nativetypes.py b/src/jinja2/nativetypes.py
index 6cca518..88eeecc 100644
--- a/src/jinja2/nativetypes.py
+++ b/src/jinja2/nativetypes.py
@@ -1,25 +1,26 @@
+import typing as t
from ast import literal_eval
from itertools import chain
from itertools import islice
-from typing import Any
from . import nodes
from .compiler import CodeGenerator
+from .compiler import Frame
from .compiler import has_safe_repr
from .environment import Environment
from .environment import Template
-def native_concat(nodes):
+def native_concat(values: t.Iterable[t.Any]) -> t.Optional[t.Any]:
"""Return a native Python type from the list of compiled nodes. If
the result is a single node, its value is returned. Otherwise, the
nodes are concatenated as strings. If the result can be parsed with
:func:`ast.literal_eval`, the parsed value is returned. Otherwise,
the string is returned.
- :param nodes: Iterable of nodes to concatenate.
+ :param values: Iterable of outputs to concatenate.
"""
- head = list(islice(nodes, 2))
+ head = list(islice(values, 2))
if not head:
return None
@@ -29,7 +30,7 @@ def native_concat(nodes):
if not isinstance(raw, str):
return raw
else:
- raw = "".join([str(v) for v in chain(head, nodes)])
+ raw = "".join([str(v) for v in chain(head, values)])
try:
return literal_eval(raw)
@@ -43,13 +44,15 @@ class NativeCodeGenerator(CodeGenerator):
"""
@staticmethod
- def _default_finalize(value):
+ def _default_finalize(value: t.Any) -> t.Any:
return value
- def _output_const_repr(self, group):
+ def _output_const_repr(self, group: t.Iterable[t.Any]) -> str:
return repr("".join([str(v) for v in group]))
- def _output_child_to_const(self, node, frame, finalize):
+ def _output_child_to_const(
+ self, node: nodes.Expr, frame: Frame, finalize: CodeGenerator._FinalizeInfo
+ ) -> t.Any:
const = node.as_const(frame.eval_ctx)
if not has_safe_repr(const):
@@ -58,13 +61,17 @@ class NativeCodeGenerator(CodeGenerator):
if isinstance(node, nodes.TemplateData):
return const
- return finalize.const(const)
+ return finalize.const(const) # type: ignore
- def _output_child_pre(self, node, frame, finalize):
+ def _output_child_pre(
+ self, node: nodes.Expr, frame: Frame, finalize: CodeGenerator._FinalizeInfo
+ ) -> None:
if finalize.src is not None:
self.write(finalize.src)
- def _output_child_post(self, node, frame, finalize):
+ def _output_child_post(
+ self, node: nodes.Expr, frame: Frame, finalize: CodeGenerator._FinalizeInfo
+ ) -> None:
if finalize.src is not None:
self.write(")")
@@ -73,13 +80,12 @@ class NativeEnvironment(Environment):
"""An environment that renders templates to native Python types."""
code_generator_class = NativeCodeGenerator
- template_class: Any
class NativeTemplate(Template):
environment_class = NativeEnvironment
- def render(self, *args, **kwargs):
+ def render(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Render the template to produce a native Python type. If the
result is a single node, its value is returned. Otherwise, the
nodes are concatenated as strings. If the result can be parsed
@@ -89,11 +95,11 @@ class NativeTemplate(Template):
ctx = self.new_context(dict(*args, **kwargs))
try:
- return native_concat(self.root_render_func(ctx))
+ return native_concat(self.root_render_func(ctx)) # type: ignore
except Exception:
return self.environment.handle_exception()
- async def render_async(self, *args, **kwargs):
+ async def render_async(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not self.environment.is_async:
raise RuntimeError(
"The environment was not created with async mode enabled."
@@ -102,7 +108,9 @@ class NativeTemplate(Template):
ctx = self.new_context(dict(*args, **kwargs))
try:
- return native_concat([n async for n in self.root_render_func(ctx)])
+ return native_concat(
+ [n async for n in self.root_render_func(ctx)] # type: ignore
+ )
except Exception:
return self.environment.handle_exception()
diff --git a/src/jinja2/nodes.py b/src/jinja2/nodes.py
index bbe1ab7..d867c9b 100644
--- a/src/jinja2/nodes.py
+++ b/src/jinja2/nodes.py
@@ -4,15 +4,19 @@ to normalize nodes.
"""
import inspect
import operator
+import typing as t
from collections import deque
-from typing import Any
-from typing import Tuple as TupleType
from markupsafe import Markup
from .utils import _PassArg
-_binop_to_func = {
+if t.TYPE_CHECKING:
+ from .environment import Environment
+
+_NodeBound = t.TypeVar("_NodeBound", bound="Node")
+
+_binop_to_func: t.Dict[str, t.Callable[[t.Any, t.Any], t.Any]] = {
"*": operator.mul,
"/": operator.truediv,
"//": operator.floordiv,
@@ -22,13 +26,13 @@ _binop_to_func = {
"-": operator.sub,
}
-_uaop_to_func = {
+_uaop_to_func: t.Dict[str, t.Callable[[t.Any], t.Any]] = {
"not": operator.not_,
"+": operator.pos,
"-": operator.neg,
}
-_cmpop_to_func = {
+_cmpop_to_func: t.Dict[str, t.Callable[[t.Any, t.Any], t.Any]] = {
"eq": operator.eq,
"ne": operator.ne,
"gt": operator.gt,
@@ -49,7 +53,7 @@ class NodeType(type):
inheritance. fields and attributes from the parent class are
automatically forwarded to the child."""
- def __new__(mcs, name, bases, d):
+ def __new__(mcs, name, bases, d): # type: ignore
for attr in "fields", "attributes":
storage = []
storage.extend(getattr(bases[0] if bases else object, attr, ()))
@@ -66,7 +70,9 @@ class EvalContext:
to it in extensions.
"""
- def __init__(self, environment, template_name=None):
+ def __init__(
+ self, environment: "Environment", template_name: t.Optional[str] = None
+ ) -> None:
self.environment = environment
if callable(environment.autoescape):
self.autoescape = environment.autoescape(template_name)
@@ -74,15 +80,15 @@ class EvalContext:
self.autoescape = environment.autoescape
self.volatile = False
- def save(self):
+ def save(self) -> t.Mapping[str, t.Any]:
return self.__dict__.copy()
- def revert(self, old):
+ def revert(self, old: t.Mapping[str, t.Any]) -> None:
self.__dict__.clear()
self.__dict__.update(old)
-def get_eval_context(node, ctx):
+def get_eval_context(node: "Node", ctx: t.Optional[EvalContext]) -> EvalContext:
if ctx is None:
if node.environment is None:
raise RuntimeError(
@@ -110,19 +116,22 @@ class Node(metaclass=NodeType):
all nodes automatically.
"""
- fields: TupleType = ()
- attributes = ("lineno", "environment")
+ fields: t.Tuple[str, ...] = ()
+ attributes: t.Tuple[str, ...] = ("lineno", "environment")
abstract = True
- def __init__(self, *fields, **attributes):
+ lineno: int
+ environment: t.Optional["Environment"]
+
+ def __init__(self, *fields: t.Any, **attributes: t.Any) -> None:
if self.abstract:
raise TypeError("abstract nodes are not instantiable")
if fields:
if len(fields) != len(self.fields):
if not self.fields:
- raise TypeError(f"{self.__class__.__name__!r} takes 0 arguments")
+ raise TypeError(f"{type(self).__name__!r} takes 0 arguments")
raise TypeError(
- f"{self.__class__.__name__!r} takes 0 or {len(self.fields)}"
+ f"{type(self).__name__!r} takes 0 or {len(self.fields)}"
f" argument{'s' if len(self.fields) != 1 else ''}"
)
for name, arg in zip(self.fields, fields):
@@ -132,7 +141,11 @@ class Node(metaclass=NodeType):
if attributes:
raise TypeError(f"unknown attribute {next(iter(attributes))!r}")
- def iter_fields(self, exclude=None, only=None):
+ def iter_fields(
+ self,
+ exclude: t.Optional[t.Container[str]] = None,
+ only: t.Optional[t.Container[str]] = None,
+ ) -> t.Iterator[t.Tuple[str, t.Any]]:
"""This method iterates over all fields that are defined and yields
``(key, value)`` tuples. Per default all fields are returned, but
it's possible to limit that to some fields by providing the `only`
@@ -141,7 +154,7 @@ class Node(metaclass=NodeType):
"""
for name in self.fields:
if (
- (exclude is only is None)
+ (exclude is None and only is None)
or (exclude is not None and name not in exclude)
or (only is not None and name in only)
):
@@ -150,7 +163,11 @@ class Node(metaclass=NodeType):
except AttributeError:
pass
- def iter_child_nodes(self, exclude=None, only=None):
+ def iter_child_nodes(
+ self,
+ exclude: t.Optional[t.Container[str]] = None,
+ only: t.Optional[t.Container[str]] = None,
+ ) -> t.Iterator["Node"]:
"""Iterates over all direct child nodes of the node. This iterates
over all fields and yields the values of they are nodes. If the value
of a field is a list all the nodes in that list are returned.
@@ -163,23 +180,27 @@ class Node(metaclass=NodeType):
elif isinstance(item, Node):
yield item
- def find(self, node_type):
+ def find(self, node_type: t.Type[_NodeBound]) -> t.Optional[_NodeBound]:
"""Find the first node of a given type. If no such node exists the
return value is `None`.
"""
for result in self.find_all(node_type):
return result
- def find_all(self, node_type):
+ return None
+
+ def find_all(
+ self, node_type: t.Union[t.Type[_NodeBound], t.Tuple[t.Type[_NodeBound], ...]]
+ ) -> t.Iterator[_NodeBound]:
"""Find all the nodes of a given type. If the type is a tuple,
the check is performed for any of the tuple items.
"""
for child in self.iter_child_nodes():
if isinstance(child, node_type):
- yield child
+ yield child # type: ignore
yield from child.find_all(node_type)
- def set_ctx(self, ctx):
+ def set_ctx(self, ctx: str) -> "Node":
"""Reset the context of a node and all child nodes. Per default the
parser will all generate nodes that have a 'load' context as it's the
most common one. This method is used in the parser to set assignment
@@ -189,11 +210,11 @@ class Node(metaclass=NodeType):
while todo:
node = todo.popleft()
if "ctx" in node.fields:
- node.ctx = ctx
+ node.ctx = ctx # type: ignore
todo.extend(node.iter_child_nodes())
return self
- def set_lineno(self, lineno, override=False):
+ def set_lineno(self, lineno: int, override: bool = False) -> "Node":
"""Set the line numbers of the node and children."""
todo = deque([self])
while todo:
@@ -204,7 +225,7 @@ class Node(metaclass=NodeType):
todo.extend(node.iter_child_nodes())
return self
- def set_environment(self, environment):
+ def set_environment(self, environment: "Environment") -> "Node":
"""Set the environment for all nodes."""
todo = deque([self])
while todo:
@@ -213,26 +234,26 @@ class Node(metaclass=NodeType):
todo.extend(node.iter_child_nodes())
return self
- def __eq__(self, other):
+ def __eq__(self, other: t.Any) -> bool:
if type(self) is not type(other):
return NotImplemented
return tuple(self.iter_fields()) == tuple(other.iter_fields())
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(tuple(self.iter_fields()))
- def __repr__(self):
+ def __repr__(self) -> str:
args_str = ", ".join(f"{a}={getattr(self, a, None)!r}" for a in self.fields)
- return f"{self.__class__.__name__}({args_str})"
+ return f"{type(self).__name__}({args_str})"
- def dump(self):
- def _dump(node):
+ def dump(self) -> str:
+ def _dump(node: t.Union[Node, t.Any]) -> None:
if not isinstance(node, Node):
buf.append(repr(node))
return
- buf.append(f"nodes.{node.__class__.__name__}(")
+ buf.append(f"nodes.{type(node).__name__}(")
if not node.fields:
buf.append(")")
return
@@ -251,7 +272,7 @@ class Node(metaclass=NodeType):
_dump(value)
buf.append(")")
- buf = []
+ buf: t.List[str] = []
_dump(self)
return "".join(buf)
@@ -274,6 +295,7 @@ class Template(Node):
"""
fields = ("body",)
+ body: t.List[Node]
class Output(Stmt):
@@ -282,12 +304,14 @@ class Output(Stmt):
"""
fields = ("nodes",)
+ nodes: t.List["Expr"]
class Extends(Stmt):
"""Represents an extends statement."""
fields = ("template",)
+ template: "Expr"
class For(Stmt):
@@ -300,12 +324,22 @@ class For(Stmt):
"""
fields = ("target", "iter", "body", "else_", "test", "recursive")
+ target: Node
+ iter: Node
+ body: t.List[Node]
+ else_: t.List[Node]
+ test: t.Optional[Node]
+ recursive: bool
class If(Stmt):
"""If `test` is true, `body` is rendered, else `else_`."""
fields = ("test", "body", "elif_", "else_")
+ test: Node
+ body: t.List[Node]
+ elif_: t.List["If"]
+ else_: t.List[Node]
class Macro(Stmt):
@@ -315,6 +349,10 @@ class Macro(Stmt):
"""
fields = ("name", "args", "defaults", "body")
+ name: str
+ args: t.List["Name"]
+ defaults: t.List["Expr"]
+ body: t.List[Node]
class CallBlock(Stmt):
@@ -323,12 +361,18 @@ class CallBlock(Stmt):
"""
fields = ("call", "args", "defaults", "body")
+ call: "Call"
+ args: t.List["Name"]
+ defaults: t.List["Expr"]
+ body: t.List[Node]
class FilterBlock(Stmt):
"""Node for filter sections."""
fields = ("body", "filter")
+ body: t.List[Node]
+ filter: "Filter"
class With(Stmt):
@@ -339,6 +383,9 @@ class With(Stmt):
"""
fields = ("targets", "values", "body")
+ targets: t.List["Expr"]
+ values: t.List["Expr"]
+ body: t.List[Node]
class Block(Stmt):
@@ -349,18 +396,28 @@ class Block(Stmt):
"""
fields = ("name", "body", "scoped", "required")
+ name: str
+ body: t.List[Node]
+ scoped: bool
+ required: bool
class Include(Stmt):
"""A node that represents the include tag."""
fields = ("template", "with_context", "ignore_missing")
+ template: "Expr"
+ with_context: bool
+ ignore_missing: bool
class Import(Stmt):
"""A node that represents the import tag."""
fields = ("template", "target", "with_context")
+ template: "Expr"
+ target: str
+ with_context: bool
class FromImport(Stmt):
@@ -376,24 +433,33 @@ class FromImport(Stmt):
"""
fields = ("template", "names", "with_context")
+ template: "Expr"
+ names: t.List[t.Union[str, t.Tuple[str, str]]]
+ with_context: bool
class ExprStmt(Stmt):
"""A statement that evaluates an expression and discards the result."""
fields = ("node",)
+ node: Node
class Assign(Stmt):
"""Assigns an expression to a target."""
fields = ("target", "node")
+ target: "Expr"
+ node: Node
class AssignBlock(Stmt):
"""Assigns a block to a target."""
fields = ("target", "filter", "body")
+ target: "Expr"
+ filter: t.Optional["Filter"]
+ body: t.List[Node]
class Expr(Node):
@@ -401,7 +467,7 @@ class Expr(Node):
abstract = True
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
"""Return the value of the expression as constant or raise
:exc:`Impossible` if this was not possible.
@@ -414,7 +480,7 @@ class Expr(Node):
"""
raise Impossible()
- def can_assign(self):
+ def can_assign(self) -> bool:
"""Check if it's possible to assign something to this node."""
return False
@@ -423,15 +489,18 @@ class BinExpr(Expr):
"""Baseclass for all binary expressions."""
fields = ("left", "right")
- operator: Any = None
+ left: Expr
+ right: Expr
+ operator: str
abstract = True
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
eval_ctx = get_eval_context(self, eval_ctx)
+
# intercepted operators cannot be folded at compile time
if (
- self.environment.sandboxed
- and self.operator in self.environment.intercepted_binops
+ eval_ctx.environment.sandboxed
+ and self.operator in eval_ctx.environment.intercepted_binops # type: ignore
):
raise Impossible()
f = _binop_to_func[self.operator]
@@ -445,15 +514,17 @@ class UnaryExpr(Expr):
"""Baseclass for all unary expressions."""
fields = ("node",)
- operator: Any = None
+ node: Expr
+ operator: str
abstract = True
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
eval_ctx = get_eval_context(self, eval_ctx)
+
# intercepted operators cannot be folded at compile time
if (
- self.environment.sandboxed
- and self.operator in self.environment.intercepted_unops
+ eval_ctx.environment.sandboxed
+ and self.operator in eval_ctx.environment.intercepted_unops # type: ignore
):
raise Impossible()
f = _uaop_to_func[self.operator]
@@ -473,17 +544,21 @@ class Name(Expr):
"""
fields = ("name", "ctx")
+ name: str
+ ctx: str
- def can_assign(self):
- return self.name not in ("true", "false", "none", "True", "False", "None")
+ def can_assign(self) -> bool:
+ return self.name not in {"true", "false", "none", "True", "False", "None"}
class NSRef(Expr):
"""Reference to a namespace value assignment"""
fields = ("name", "attr")
+ name: str
+ attr: str
- def can_assign(self):
+ def can_assign(self) -> bool:
# We don't need any special checks here; NSRef assignments have a
# runtime check to ensure the target is a namespace object which will
# have been checked already as it is created using a normal assignment
@@ -505,12 +580,18 @@ class Const(Literal):
"""
fields = ("value",)
+ value: t.Any
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
return self.value
@classmethod
- def from_untrusted(cls, value, lineno=None, environment=None):
+ def from_untrusted(
+ cls,
+ value: t.Any,
+ lineno: t.Optional[int] = None,
+ environment: "t.Optional[Environment]" = None,
+ ) -> "Const":
"""Return a const object if the value is representable as
constant value in the generated code, otherwise it will raise
an `Impossible` exception.
@@ -526,8 +607,9 @@ class TemplateData(Literal):
"""A constant template string."""
fields = ("data",)
+ data: str
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> str:
eval_ctx = get_eval_context(self, eval_ctx)
if eval_ctx.volatile:
raise Impossible()
@@ -543,12 +625,14 @@ class Tuple(Literal):
"""
fields = ("items", "ctx")
+ items: t.List[Expr]
+ ctx: str
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Tuple[t.Any, ...]:
eval_ctx = get_eval_context(self, eval_ctx)
return tuple(x.as_const(eval_ctx) for x in self.items)
- def can_assign(self):
+ def can_assign(self) -> bool:
for item in self.items:
if not item.can_assign():
return False
@@ -559,8 +643,9 @@ class List(Literal):
"""Any list literal such as ``[1, 2, 3]``"""
fields = ("items",)
+ items: t.List[Expr]
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.List[t.Any]:
eval_ctx = get_eval_context(self, eval_ctx)
return [x.as_const(eval_ctx) for x in self.items]
@@ -571,8 +656,11 @@ class Dict(Literal):
"""
fields = ("items",)
+ items: t.List["Pair"]
- def as_const(self, eval_ctx=None):
+ def as_const(
+ self, eval_ctx: t.Optional[EvalContext] = None
+ ) -> t.Dict[t.Any, t.Any]:
eval_ctx = get_eval_context(self, eval_ctx)
return dict(x.as_const(eval_ctx) for x in self.items)
@@ -581,8 +669,12 @@ class Pair(Helper):
"""A key, value pair for dicts."""
fields = ("key", "value")
+ key: Expr
+ value: Expr
- def as_const(self, eval_ctx=None):
+ def as_const(
+ self, eval_ctx: t.Optional[EvalContext] = None
+ ) -> t.Tuple[t.Any, t.Any]:
eval_ctx = get_eval_context(self, eval_ctx)
return self.key.as_const(eval_ctx), self.value.as_const(eval_ctx)
@@ -591,8 +683,10 @@ class Keyword(Helper):
"""A key, value pair for keyword arguments where key is a string."""
fields = ("key", "value")
+ key: str
+ value: Expr
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Tuple[str, t.Any]:
eval_ctx = get_eval_context(self, eval_ctx)
return self.key, self.value.as_const(eval_ctx)
@@ -603,8 +697,11 @@ class CondExpr(Expr):
"""
fields = ("test", "expr1", "expr2")
+ test: Expr
+ expr1: Expr
+ expr2: t.Optional[Expr]
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
eval_ctx = get_eval_context(self, eval_ctx)
if self.test.as_const(eval_ctx):
return self.expr1.as_const(eval_ctx)
@@ -616,7 +713,9 @@ class CondExpr(Expr):
return self.expr2.as_const(eval_ctx)
-def args_as_const(node, eval_ctx):
+def args_as_const(
+ node: t.Union["_FilterTestCommon", "Call"], eval_ctx: t.Optional[EvalContext]
+) -> t.Tuple[t.List[t.Any], t.Dict[t.Any, t.Any]]:
args = [x.as_const(eval_ctx) for x in node.args]
kwargs = dict(x.as_const(eval_ctx) for x in node.kwargs)
@@ -637,10 +736,16 @@ def args_as_const(node, eval_ctx):
class _FilterTestCommon(Expr):
fields = ("node", "name", "args", "kwargs", "dyn_args", "dyn_kwargs")
+ node: Expr
+ name: str
+ args: t.List[Expr]
+ kwargs: t.List[Pair]
+ dyn_args: t.Optional[Expr]
+ dyn_kwargs: t.Optional[Expr]
abstract = True
_is_filter = True
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
eval_ctx = get_eval_context(self, eval_ctx)
if eval_ctx.volatile:
@@ -652,7 +757,7 @@ class _FilterTestCommon(Expr):
env_map = eval_ctx.environment.tests
func = env_map.get(self.name)
- pass_arg = _PassArg.from_obj(func)
+ pass_arg = _PassArg.from_obj(func) # type: ignore
if func is None or pass_arg is _PassArg.context:
raise Impossible()
@@ -685,7 +790,9 @@ class Filter(_FilterTestCommon):
and is applied to the content of the block.
"""
- def as_const(self, eval_ctx=None):
+ node: t.Optional[Expr] # type: ignore
+
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
if self.node is None:
raise Impossible()
@@ -714,27 +821,34 @@ class Call(Expr):
"""
fields = ("node", "args", "kwargs", "dyn_args", "dyn_kwargs")
+ node: Expr
+ args: t.List[Expr]
+ kwargs: t.List[Keyword]
+ dyn_args: t.Optional[Expr]
+ dyn_kwargs: t.Optional[Expr]
class Getitem(Expr):
"""Get an attribute or item from an expression and prefer the item."""
fields = ("node", "arg", "ctx")
+ node: Expr
+ arg: Expr
+ ctx: str
- def as_const(self, eval_ctx=None):
- eval_ctx = get_eval_context(self, eval_ctx)
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
if self.ctx != "load":
raise Impossible()
+
+ eval_ctx = get_eval_context(self, eval_ctx)
+
try:
- return self.environment.getitem(
+ return eval_ctx.environment.getitem(
self.node.as_const(eval_ctx), self.arg.as_const(eval_ctx)
)
except Exception:
raise Impossible()
- def can_assign(self):
- return False
-
class Getattr(Expr):
"""Get an attribute or item from an expression that is a ascii-only
@@ -742,19 +856,21 @@ class Getattr(Expr):
"""
fields = ("node", "attr", "ctx")
+ node: Expr
+ attr: str
+ ctx: str
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
if self.ctx != "load":
raise Impossible()
+
+ eval_ctx = get_eval_context(self, eval_ctx)
+
try:
- eval_ctx = get_eval_context(self, eval_ctx)
- return self.environment.getattr(self.node.as_const(eval_ctx), self.attr)
+ return eval_ctx.environment.getattr(self.node.as_const(eval_ctx), self.attr)
except Exception:
raise Impossible()
- def can_assign(self):
- return False
-
class Slice(Expr):
"""Represents a slice object. This must only be used as argument for
@@ -762,11 +878,14 @@ class Slice(Expr):
"""
fields = ("start", "stop", "step")
+ start: t.Optional[Expr]
+ stop: t.Optional[Expr]
+ step: t.Optional[Expr]
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> slice:
eval_ctx = get_eval_context(self, eval_ctx)
- def const(obj):
+ def const(obj: t.Optional[Expr]) -> t.Optional[t.Any]:
if obj is None:
return None
return obj.as_const(eval_ctx)
@@ -780,8 +899,9 @@ class Concat(Expr):
"""
fields = ("nodes",)
+ nodes: t.List[Expr]
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> str:
eval_ctx = get_eval_context(self, eval_ctx)
return "".join(str(x.as_const(eval_ctx)) for x in self.nodes)
@@ -792,8 +912,10 @@ class Compare(Expr):
"""
fields = ("expr", "ops")
+ expr: Expr
+ ops: t.List["Operand"]
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
eval_ctx = get_eval_context(self, eval_ctx)
result = value = self.expr.as_const(eval_ctx)
@@ -816,6 +938,8 @@ class Operand(Helper):
"""Holds an operator and an expression."""
fields = ("op", "expr")
+ op: str
+ expr: Expr
class Mul(BinExpr):
@@ -867,7 +991,7 @@ class And(BinExpr):
operator = "and"
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
eval_ctx = get_eval_context(self, eval_ctx)
return self.left.as_const(eval_ctx) and self.right.as_const(eval_ctx)
@@ -877,7 +1001,7 @@ class Or(BinExpr):
operator = "or"
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Any:
eval_ctx = get_eval_context(self, eval_ctx)
return self.left.as_const(eval_ctx) or self.right.as_const(eval_ctx)
@@ -909,6 +1033,7 @@ class EnvironmentAttribute(Expr):
"""
fields = ("name",)
+ name: str
class ExtensionAttribute(Expr):
@@ -920,6 +1045,8 @@ class ExtensionAttribute(Expr):
"""
fields = ("identifier", "name")
+ identifier: str
+ name: str
class ImportedName(Expr):
@@ -930,6 +1057,7 @@ class ImportedName(Expr):
"""
fields = ("importname",)
+ importname: str
class InternalName(Expr):
@@ -937,12 +1065,13 @@ class InternalName(Expr):
yourself but the parser provides a
:meth:`~jinja2.parser.Parser.free_identifier` method that creates
a new identifier for you. This identifier is not available from the
- template and is not threated specially by the compiler.
+ template and is not treated specially by the compiler.
"""
fields = ("name",)
+ name: str
- def __init__(self):
+ def __init__(self) -> None:
raise TypeError(
"Can't create internal names. Use the "
"`free_identifier` method on a parser."
@@ -953,8 +1082,9 @@ class MarkSafe(Expr):
"""Mark the wrapped expression as safe (wrap it as `Markup`)."""
fields = ("expr",)
+ expr: Expr
- def as_const(self, eval_ctx=None):
+ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> Markup:
eval_ctx = get_eval_context(self, eval_ctx)
return Markup(self.expr.as_const(eval_ctx))
@@ -967,8 +1097,11 @@ class MarkSafeIfAutoescape(Expr):
"""
fields = ("expr",)
+ expr: Expr
- def as_const(self, eval_ctx=None):
+ def as_const(
+ self, eval_ctx: t.Optional[EvalContext] = None
+ ) -> t.Union[Markup, t.Any]:
eval_ctx = get_eval_context(self, eval_ctx)
if eval_ctx.volatile:
raise Impossible()
@@ -1017,6 +1150,7 @@ class Scope(Stmt):
"""An artificial scope."""
fields = ("body",)
+ body: t.List[Node]
class OverlayScope(Stmt):
@@ -1034,6 +1168,8 @@ class OverlayScope(Stmt):
"""
fields = ("context", "body")
+ context: Expr
+ body: t.List[Node]
class EvalContextModifier(Stmt):
@@ -1046,6 +1182,7 @@ class EvalContextModifier(Stmt):
"""
fields = ("options",)
+ options: t.List[Keyword]
class ScopedEvalContextModifier(EvalContextModifier):
@@ -1055,10 +1192,11 @@ class ScopedEvalContextModifier(EvalContextModifier):
"""
fields = ("body",)
+ body: t.List[Node]
# make sure nobody creates custom nodes
-def _failing_new(*args, **kwargs):
+def _failing_new(*args: t.Any, **kwargs: t.Any) -> t.NoReturn:
raise TypeError("can't create custom node types")
diff --git a/src/jinja2/optimizer.py b/src/jinja2/optimizer.py
index 39d059f..fe10107 100644
--- a/src/jinja2/optimizer.py
+++ b/src/jinja2/optimizer.py
@@ -7,22 +7,29 @@ want. For example, loop unrolling doesn't work because unrolled loops
would have a different scope. The solution would be a second syntax tree
that stored the scoping rules.
"""
+import typing as t
+
from . import nodes
from .visitor import NodeTransformer
+if t.TYPE_CHECKING:
+ from .environment import Environment
+
-def optimize(node, environment):
+def optimize(node: nodes.Node, environment: "Environment") -> nodes.Node:
"""The context hint can be used to perform an static optimization
based on the context given."""
optimizer = Optimizer(environment)
- return optimizer.visit(node)
+ return t.cast(nodes.Node, optimizer.visit(node))
class Optimizer(NodeTransformer):
- def __init__(self, environment):
+ def __init__(self, environment: "t.Optional[Environment]") -> None:
self.environment = environment
- def generic_visit(self, node, *args, **kwargs):
+ def generic_visit(
+ self, node: nodes.Node, *args: t.Any, **kwargs: t.Any
+ ) -> nodes.Node:
node = super().generic_visit(node, *args, **kwargs)
# Do constant folding. Some other nodes besides Expr have
diff --git a/src/jinja2/parser.py b/src/jinja2/parser.py
index 9bdecba..408864c 100644
--- a/src/jinja2/parser.py
+++ b/src/jinja2/parser.py
@@ -1,10 +1,20 @@
"""Parse tokens from the lexer into nodes for the compiler."""
+import typing
+import typing as t
+
from . import nodes
from .exceptions import TemplateAssertionError
from .exceptions import TemplateSyntaxError
from .lexer import describe_token
from .lexer import describe_token_expr
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+ from .environment import Environment
+
+_ImportInclude = t.TypeVar("_ImportInclude", nodes.Import, nodes.Include)
+_MacroCall = t.TypeVar("_MacroCall", nodes.Macro, nodes.CallBlock)
+
_statement_keywords = frozenset(
[
"for",
@@ -23,7 +33,7 @@ _statement_keywords = frozenset(
)
_compare_operators = frozenset(["eq", "ne", "lt", "lteq", "gt", "gteq"])
-_math_nodes = {
+_math_nodes: t.Dict[str, t.Type[nodes.Expr]] = {
"add": nodes.Add,
"sub": nodes.Sub,
"mul": nodes.Mul,
@@ -38,21 +48,35 @@ class Parser:
extensions and can be used to parse expressions or statements.
"""
- def __init__(self, environment, source, name=None, filename=None, state=None):
+ def __init__(
+ self,
+ environment: "Environment",
+ source: str,
+ name: t.Optional[str] = None,
+ filename: t.Optional[str] = None,
+ state: t.Optional[str] = None,
+ ) -> None:
self.environment = environment
self.stream = environment._tokenize(source, name, filename, state)
self.name = name
self.filename = filename
self.closed = False
- self.extensions = {}
+ self.extensions: t.Dict[
+ str, t.Callable[["Parser"], t.Union[nodes.Node, t.List[nodes.Node]]]
+ ] = {}
for extension in environment.iter_extensions():
for tag in extension.tags:
self.extensions[tag] = extension.parse
self._last_identifier = 0
- self._tag_stack = []
- self._end_token_stack = []
+ self._tag_stack: t.List[str] = []
+ self._end_token_stack: t.List[t.Tuple[str, ...]] = []
- def fail(self, msg, lineno=None, exc=TemplateSyntaxError):
+ def fail(
+ self,
+ msg: str,
+ lineno: t.Optional[int] = None,
+ exc: t.Type[TemplateSyntaxError] = TemplateSyntaxError,
+ ) -> t.NoReturn:
"""Convenience method that raises `exc` with the message, passed
line number or last line number as well as the current name and
filename.
@@ -61,12 +85,17 @@ class Parser:
lineno = self.stream.current.lineno
raise exc(msg, lineno, self.name, self.filename)
- def _fail_ut_eof(self, name, end_token_stack, lineno):
- expected = []
+ def _fail_ut_eof(
+ self,
+ name: t.Optional[str],
+ end_token_stack: t.List[t.Tuple[str, ...]],
+ lineno: t.Optional[int],
+ ) -> t.NoReturn:
+ expected: t.Set[str] = set()
for exprs in end_token_stack:
- expected.extend(map(describe_token_expr, exprs))
+ expected.update(map(describe_token_expr, exprs))
if end_token_stack:
- currently_looking = " or ".join(
+ currently_looking: t.Optional[str] = " or ".join(
map(repr, map(describe_token_expr, end_token_stack[-1]))
)
else:
@@ -96,36 +125,42 @@ class Parser:
self.fail(" ".join(message), lineno)
- def fail_unknown_tag(self, name, lineno=None):
+ def fail_unknown_tag(self, name: str, lineno: t.Optional[int] = None) -> t.NoReturn:
"""Called if the parser encounters an unknown tag. Tries to fail
with a human readable error message that could help to identify
the problem.
"""
- return self._fail_ut_eof(name, self._end_token_stack, lineno)
+ self._fail_ut_eof(name, self._end_token_stack, lineno)
- def fail_eof(self, end_tokens=None, lineno=None):
+ def fail_eof(
+ self,
+ end_tokens: t.Optional[t.Tuple[str, ...]] = None,
+ lineno: t.Optional[int] = None,
+ ) -> t.NoReturn:
"""Like fail_unknown_tag but for end of template situations."""
stack = list(self._end_token_stack)
if end_tokens is not None:
stack.append(end_tokens)
- return self._fail_ut_eof(None, stack, lineno)
+ self._fail_ut_eof(None, stack, lineno)
- def is_tuple_end(self, extra_end_rules=None):
+ def is_tuple_end(
+ self, extra_end_rules: t.Optional[t.Tuple[str, ...]] = None
+ ) -> bool:
"""Are we at the end of a tuple?"""
if self.stream.current.type in ("variable_end", "block_end", "rparen"):
return True
elif extra_end_rules is not None:
- return self.stream.current.test_any(extra_end_rules)
+ return self.stream.current.test_any(extra_end_rules) # type: ignore
return False
- def free_identifier(self, lineno=None):
+ def free_identifier(self, lineno: t.Optional[int] = None) -> nodes.InternalName:
"""Return a new free identifier as :class:`~jinja2.nodes.InternalName`."""
self._last_identifier += 1
rv = object.__new__(nodes.InternalName)
nodes.Node.__init__(rv, f"fi{self._last_identifier}", lineno=lineno)
- return rv
+ return rv # type: ignore
- def parse_statement(self):
+ def parse_statement(self) -> t.Union[nodes.Node, t.List[nodes.Node]]:
"""Parse a single statement."""
token = self.stream.current
if token.type != "name":
@@ -134,7 +169,8 @@ class Parser:
pop_tag = True
try:
if token.value in _statement_keywords:
- return getattr(self, "parse_" + self.stream.current.value)()
+ f = getattr(self, f"parse_{self.stream.current.value}")
+ return f() # type: ignore
if token.value == "call":
return self.parse_call_block()
if token.value == "filter":
@@ -153,7 +189,9 @@ class Parser:
if pop_tag:
self._tag_stack.pop()
- def parse_statements(self, end_tokens, drop_needle=False):
+ def parse_statements(
+ self, end_tokens: t.Tuple[str, ...], drop_needle: bool = False
+ ) -> t.List[nodes.Node]:
"""Parse multiple statements into a list until one of the end tokens
is reached. This is used to parse the body of statements as it also
parses template data if appropriate. The parser checks first if the
@@ -180,7 +218,7 @@ class Parser:
next(self.stream)
return result
- def parse_set(self):
+ def parse_set(self) -> t.Union[nodes.Assign, nodes.AssignBlock]:
"""Parse an assign statement."""
lineno = next(self.stream).lineno
target = self.parse_assign_target(with_namespace=True)
@@ -191,7 +229,7 @@ class Parser:
body = self.parse_statements(("name:endset",), drop_needle=True)
return nodes.AssignBlock(target, filter_node, body, lineno=lineno)
- def parse_for(self):
+ def parse_for(self) -> nodes.For:
"""Parse a for loop."""
lineno = self.stream.expect("name:for").lineno
target = self.parse_assign_target(extra_end_rules=("name:in",))
@@ -210,10 +248,10 @@ class Parser:
else_ = self.parse_statements(("name:endfor",), drop_needle=True)
return nodes.For(target, iter, body, else_, test, recursive, lineno=lineno)
- def parse_if(self):
+ def parse_if(self) -> nodes.If:
"""Parse an if construct."""
node = result = nodes.If(lineno=self.stream.expect("name:if").lineno)
- while 1:
+ while True:
node.test = self.parse_tuple(with_condexpr=False)
node.body = self.parse_statements(("name:elif", "name:else", "name:endif"))
node.elif_ = []
@@ -228,10 +266,10 @@ class Parser:
break
return result
- def parse_with(self):
+ def parse_with(self) -> nodes.With:
node = nodes.With(lineno=next(self.stream).lineno)
- targets = []
- values = []
+ targets: t.List[nodes.Expr] = []
+ values: t.List[nodes.Expr] = []
while self.stream.current.type != "block_end":
if targets:
self.stream.expect("comma")
@@ -245,13 +283,13 @@ class Parser:
node.body = self.parse_statements(("name:endwith",), drop_needle=True)
return node
- def parse_autoescape(self):
+ def parse_autoescape(self) -> nodes.Scope:
node = nodes.ScopedEvalContextModifier(lineno=next(self.stream).lineno)
node.options = [nodes.Keyword("autoescape", self.parse_expression())]
node.body = self.parse_statements(("name:endautoescape",), drop_needle=True)
return nodes.Scope([node])
- def parse_block(self):
+ def parse_block(self) -> nodes.Block:
node = nodes.Block(lineno=next(self.stream).lineno)
node.name = self.stream.expect("name").value
node.scoped = self.stream.skip_if("name:scoped")
@@ -274,19 +312,21 @@ class Parser:
if node.required and not all(
isinstance(child, nodes.TemplateData) and child.data.isspace()
for body in node.body
- for child in body.nodes
+ for child in body.nodes # type: ignore
):
self.fail("Required blocks can only contain comments or whitespace")
self.stream.skip_if("name:" + node.name)
return node
- def parse_extends(self):
+ def parse_extends(self) -> nodes.Extends:
node = nodes.Extends(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
return node
- def parse_import_context(self, node, default):
+ def parse_import_context(
+ self, node: _ImportInclude, default: bool
+ ) -> _ImportInclude:
if self.stream.current.test_any(
"name:with", "name:without"
) and self.stream.look().test("name:context"):
@@ -296,7 +336,7 @@ class Parser:
node.with_context = default
return node
- def parse_include(self):
+ def parse_include(self) -> nodes.Include:
node = nodes.Include(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
if self.stream.current.test("name:ignore") and self.stream.look().test(
@@ -308,20 +348,20 @@ class Parser:
node.ignore_missing = False
return self.parse_import_context(node, True)
- def parse_import(self):
+ def parse_import(self) -> nodes.Import:
node = nodes.Import(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
self.stream.expect("name:as")
node.target = self.parse_assign_target(name_only=True).name
return self.parse_import_context(node, False)
- def parse_from(self):
+ def parse_from(self) -> nodes.FromImport:
node = nodes.FromImport(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
self.stream.expect("name:import")
node.names = []
- def parse_context():
+ def parse_context() -> bool:
if (
self.stream.current.value
in {
@@ -335,7 +375,7 @@ class Parser:
return True
return False
- while 1:
+ while True:
if node.names:
self.stream.expect("comma")
if self.stream.current.type == "name":
@@ -361,9 +401,9 @@ class Parser:
node.with_context = False
return node
- def parse_signature(self, node):
- node.args = args = []
- node.defaults = defaults = []
+ def parse_signature(self, node: _MacroCall) -> None:
+ args = node.args = []
+ defaults = node.defaults = []
self.stream.expect("lparen")
while self.stream.current.type != "rparen":
if args:
@@ -377,7 +417,7 @@ class Parser:
args.append(arg)
self.stream.expect("rparen")
- def parse_call_block(self):
+ def parse_call_block(self) -> nodes.CallBlock:
node = nodes.CallBlock(lineno=next(self.stream).lineno)
if self.stream.current.type == "lparen":
self.parse_signature(node)
@@ -385,26 +425,27 @@ class Parser:
node.args = []
node.defaults = []
- node.call = self.parse_expression()
- if not isinstance(node.call, nodes.Call):
+ call_node = self.parse_expression()
+ if not isinstance(call_node, nodes.Call):
self.fail("expected call", node.lineno)
+ node.call = call_node
node.body = self.parse_statements(("name:endcall",), drop_needle=True)
return node
- def parse_filter_block(self):
+ def parse_filter_block(self) -> nodes.FilterBlock:
node = nodes.FilterBlock(lineno=next(self.stream).lineno)
- node.filter = self.parse_filter(None, start_inline=True)
+ node.filter = self.parse_filter(None, start_inline=True) # type: ignore
node.body = self.parse_statements(("name:endfilter",), drop_needle=True)
return node
- def parse_macro(self):
+ def parse_macro(self) -> nodes.Macro:
node = nodes.Macro(lineno=next(self.stream).lineno)
node.name = self.parse_assign_target(name_only=True).name
self.parse_signature(node)
node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
return node
- def parse_print(self):
+ def parse_print(self) -> nodes.Output:
node = nodes.Output(lineno=next(self.stream).lineno)
node.nodes = []
while self.stream.current.type != "block_end":
@@ -413,13 +454,29 @@ class Parser:
node.nodes.append(self.parse_expression())
return node
+ @typing.overload
+ def parse_assign_target(
+ self, with_tuple: bool = ..., name_only: "te.Literal[True]" = ...
+ ) -> nodes.Name:
+ ...
+
+ @typing.overload
def parse_assign_target(
self,
- with_tuple=True,
- name_only=False,
- extra_end_rules=None,
- with_namespace=False,
- ):
+ with_tuple: bool = True,
+ name_only: bool = False,
+ extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
+ with_namespace: bool = False,
+ ) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]:
+ ...
+
+ def parse_assign_target(
+ self,
+ with_tuple: bool = True,
+ name_only: bool = False,
+ extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
+ with_namespace: bool = False,
+ ) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]:
"""Parse an assignment target. As Jinja allows assignments to
tuples, this function can parse all allowed assignment targets. Per
default assignments to tuples are parsed, that can be disable however
@@ -428,6 +485,8 @@ class Parser:
parameter is forwarded to the tuple parsing function. If
`with_namespace` is enabled, a namespace assignment may be parsed.
"""
+ target: nodes.Expr
+
if with_namespace and self.stream.look().type == "dot":
token = self.stream.expect("name")
next(self.stream) # dot
@@ -443,14 +502,17 @@ class Parser:
)
else:
target = self.parse_primary()
+
target.set_ctx("store")
+
if not target.can_assign():
self.fail(
- f"can't assign to {target.__class__.__name__.lower()!r}", target.lineno
+ f"can't assign to {type(target).__name__.lower()!r}", target.lineno
)
- return target
- def parse_expression(self, with_condexpr=True):
+ return target # type: ignore
+
+ def parse_expression(self, with_condexpr: bool = True) -> nodes.Expr:
"""Parse an expression. Per default all expressions are parsed, if
the optional `with_condexpr` parameter is set to `False` conditional
expressions are not parsed.
@@ -459,9 +521,11 @@ class Parser:
return self.parse_condexpr()
return self.parse_or()
- def parse_condexpr(self):
+ def parse_condexpr(self) -> nodes.Expr:
lineno = self.stream.current.lineno
expr1 = self.parse_or()
+ expr3: t.Optional[nodes.Expr]
+
while self.stream.skip_if("name:if"):
expr2 = self.parse_or()
if self.stream.skip_if("name:else"):
@@ -472,7 +536,7 @@ class Parser:
lineno = self.stream.current.lineno
return expr1
- def parse_or(self):
+ def parse_or(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_and()
while self.stream.skip_if("name:or"):
@@ -481,7 +545,7 @@ class Parser:
lineno = self.stream.current.lineno
return left
- def parse_and(self):
+ def parse_and(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_not()
while self.stream.skip_if("name:and"):
@@ -490,17 +554,17 @@ class Parser:
lineno = self.stream.current.lineno
return left
- def parse_not(self):
+ def parse_not(self) -> nodes.Expr:
if self.stream.current.test("name:not"):
lineno = next(self.stream).lineno
return nodes.Not(self.parse_not(), lineno=lineno)
return self.parse_compare()
- def parse_compare(self):
+ def parse_compare(self) -> nodes.Expr:
lineno = self.stream.current.lineno
expr = self.parse_math1()
ops = []
- while 1:
+ while True:
token_type = self.stream.current.type
if token_type in _compare_operators:
next(self.stream)
@@ -519,7 +583,7 @@ class Parser:
return expr
return nodes.Compare(expr, ops, lineno=lineno)
- def parse_math1(self):
+ def parse_math1(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_concat()
while self.stream.current.type in ("add", "sub"):
@@ -530,7 +594,7 @@ class Parser:
lineno = self.stream.current.lineno
return left
- def parse_concat(self):
+ def parse_concat(self) -> nodes.Expr:
lineno = self.stream.current.lineno
args = [self.parse_math2()]
while self.stream.current.type == "tilde":
@@ -540,7 +604,7 @@ class Parser:
return args[0]
return nodes.Concat(args, lineno=lineno)
- def parse_math2(self):
+ def parse_math2(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_pow()
while self.stream.current.type in ("mul", "div", "floordiv", "mod"):
@@ -551,7 +615,7 @@ class Parser:
lineno = self.stream.current.lineno
return left
- def parse_pow(self):
+ def parse_pow(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_unary()
while self.stream.current.type == "pow":
@@ -561,9 +625,11 @@ class Parser:
lineno = self.stream.current.lineno
return left
- def parse_unary(self, with_filter=True):
+ def parse_unary(self, with_filter: bool = True) -> nodes.Expr:
token_type = self.stream.current.type
lineno = self.stream.current.lineno
+ node: nodes.Expr
+
if token_type == "sub":
next(self.stream)
node = nodes.Neg(self.parse_unary(False), lineno=lineno)
@@ -577,8 +643,9 @@ class Parser:
node = self.parse_filter_expr(node)
return node
- def parse_primary(self):
+ def parse_primary(self) -> nodes.Expr:
token = self.stream.current
+ node: nodes.Expr
if token.type == "name":
if token.value in ("true", "false", "True", "False"):
node = nodes.Const(token.value in ("true", "True"), lineno=token.lineno)
@@ -612,11 +679,11 @@ class Parser:
def parse_tuple(
self,
- simplified=False,
- with_condexpr=True,
- extra_end_rules=None,
- explicit_parentheses=False,
- ):
+ simplified: bool = False,
+ with_condexpr: bool = True,
+ extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
+ explicit_parentheses: bool = False,
+ ) -> t.Union[nodes.Tuple, nodes.Expr]:
"""Works like `parse_expression` but if multiple expressions are
delimited by a comma a :class:`~jinja2.nodes.Tuple` node is created.
This method could also return a regular expression instead of a tuple
@@ -642,12 +709,13 @@ class Parser:
parse = self.parse_expression
else:
- def parse():
+ def parse() -> nodes.Expr:
return self.parse_expression(with_condexpr=False)
- args = []
+ args: t.List[nodes.Expr] = []
is_tuple = False
- while 1:
+
+ while True:
if args:
self.stream.expect("comma")
if self.is_tuple_end(extra_end_rules):
@@ -675,9 +743,9 @@ class Parser:
return nodes.Tuple(args, "load", lineno=lineno)
- def parse_list(self):
+ def parse_list(self) -> nodes.List:
token = self.stream.expect("lbracket")
- items = []
+ items: t.List[nodes.Expr] = []
while self.stream.current.type != "rbracket":
if items:
self.stream.expect("comma")
@@ -687,9 +755,9 @@ class Parser:
self.stream.expect("rbracket")
return nodes.List(items, lineno=token.lineno)
- def parse_dict(self):
+ def parse_dict(self) -> nodes.Dict:
token = self.stream.expect("lbrace")
- items = []
+ items: t.List[nodes.Pair] = []
while self.stream.current.type != "rbrace":
if items:
self.stream.expect("comma")
@@ -702,8 +770,8 @@ class Parser:
self.stream.expect("rbrace")
return nodes.Dict(items, lineno=token.lineno)
- def parse_postfix(self, node):
- while 1:
+ def parse_postfix(self, node: nodes.Expr) -> nodes.Expr:
+ while True:
token_type = self.stream.current.type
if token_type == "dot" or token_type == "lbracket":
node = self.parse_subscript(node)
@@ -715,11 +783,11 @@ class Parser:
break
return node
- def parse_filter_expr(self, node):
- while 1:
+ def parse_filter_expr(self, node: nodes.Expr) -> nodes.Expr:
+ while True:
token_type = self.stream.current.type
if token_type == "pipe":
- node = self.parse_filter(node)
+ node = self.parse_filter(node) # type: ignore
elif token_type == "name" and self.stream.current.value == "is":
node = self.parse_test(node)
# calls are valid both after postfix expressions (getattr
@@ -730,8 +798,12 @@ class Parser:
break
return node
- def parse_subscript(self, node):
+ def parse_subscript(
+ self, node: nodes.Expr
+ ) -> t.Union[nodes.Getattr, nodes.Getitem]:
token = next(self.stream)
+ arg: nodes.Expr
+
if token.type == "dot":
attr_token = self.stream.current
next(self.stream)
@@ -744,7 +816,7 @@ class Parser:
arg = nodes.Const(attr_token.value, lineno=attr_token.lineno)
return nodes.Getitem(node, arg, "load", lineno=token.lineno)
if token.type == "lbracket":
- args = []
+ args: t.List[nodes.Expr] = []
while self.stream.current.type != "rbracket":
if args:
self.stream.expect("comma")
@@ -757,8 +829,9 @@ class Parser:
return nodes.Getitem(node, arg, "load", lineno=token.lineno)
self.fail("expected subscript expression", token.lineno)
- def parse_subscribed(self):
+ def parse_subscribed(self) -> nodes.Expr:
lineno = self.stream.current.lineno
+ args: t.List[t.Optional[nodes.Expr]]
if self.stream.current.type == "colon":
next(self.stream)
@@ -788,23 +861,26 @@ class Parser:
return nodes.Slice(lineno=lineno, *args)
- def parse_call(self, node):
+ def parse_call_args(self) -> t.Tuple:
token = self.stream.expect("lparen")
args = []
kwargs = []
- dyn_args = dyn_kwargs = None
+ dyn_args = None
+ dyn_kwargs = None
require_comma = False
- def ensure(expr):
+ def ensure(expr: bool) -> None:
if not expr:
self.fail("invalid syntax for function call expression", token.lineno)
while self.stream.current.type != "rparen":
if require_comma:
self.stream.expect("comma")
+
# support for trailing comma
if self.stream.current.type == "rparen":
break
+
if self.stream.current.type == "mul":
ensure(dyn_args is None and dyn_kwargs is None)
next(self.stream)
@@ -830,13 +906,20 @@ class Parser:
args.append(self.parse_expression())
require_comma = True
+
self.stream.expect("rparen")
+ return args, kwargs, dyn_args, dyn_kwargs
- if node is None:
- return args, kwargs, dyn_args, dyn_kwargs
+ def parse_call(self, node: nodes.Expr) -> nodes.Call:
+ # The lparen will be expected in parse_call_args, but the lineno
+ # needs to be recorded before the stream is advanced.
+ token = self.stream.current
+ args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args()
return nodes.Call(node, args, kwargs, dyn_args, dyn_kwargs, lineno=token.lineno)
- def parse_filter(self, node, start_inline=False):
+ def parse_filter(
+ self, node: t.Optional[nodes.Expr], start_inline: bool = False
+ ) -> t.Optional[nodes.Expr]:
while self.stream.current.type == "pipe" or start_inline:
if not start_inline:
next(self.stream)
@@ -846,7 +929,7 @@ class Parser:
next(self.stream)
name += "." + self.stream.expect("name").value
if self.stream.current.type == "lparen":
- args, kwargs, dyn_args, dyn_kwargs = self.parse_call(None)
+ args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args()
else:
args = []
kwargs = []
@@ -857,7 +940,7 @@ class Parser:
start_inline = False
return node
- def parse_test(self, node):
+ def parse_test(self, node: nodes.Expr) -> nodes.Expr:
token = next(self.stream)
if self.stream.current.test("name:not"):
next(self.stream)
@@ -871,7 +954,7 @@ class Parser:
dyn_args = dyn_kwargs = None
kwargs = []
if self.stream.current.type == "lparen":
- args, kwargs, dyn_args, dyn_kwargs = self.parse_call(None)
+ args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args()
elif (
self.stream.current.type
in {
@@ -899,15 +982,17 @@ class Parser:
node = nodes.Not(node, lineno=token.lineno)
return node
- def subparse(self, end_tokens=None):
- body = []
- data_buffer = []
+ def subparse(
+ self, end_tokens: t.Optional[t.Tuple[str, ...]] = None
+ ) -> t.List[nodes.Node]:
+ body: t.List[nodes.Node] = []
+ data_buffer: t.List[nodes.Node] = []
add_data = data_buffer.append
if end_tokens is not None:
self._end_token_stack.append(end_tokens)
- def flush_data():
+ def flush_data() -> None:
if data_buffer:
lineno = data_buffer[0].lineno
body.append(nodes.Output(data_buffer[:], lineno=lineno))
@@ -946,7 +1031,7 @@ class Parser:
self._end_token_stack.pop()
return body
- def parse(self):
+ def parse(self) -> nodes.Template:
"""Parse the whole template into a `Template` node."""
result = nodes.Template(self.subparse(), lineno=1)
result.set_environment(self.environment)
diff --git a/src/jinja2/py.typed b/src/jinja2/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/jinja2/py.typed
diff --git a/src/jinja2/runtime.py b/src/jinja2/runtime.py
index 1cdc858..883c2f7 100644
--- a/src/jinja2/runtime.py
+++ b/src/jinja2/runtime.py
@@ -1,4 +1,5 @@
"""The runtime functions and state used by compiled templates."""
+import functools
import sys
import typing as t
from collections import abc
@@ -22,9 +23,24 @@ from .utils import Namespace # noqa: F401
from .utils import object_type_repr
from .utils import pass_eval_context
+V = t.TypeVar("V")
+F = t.TypeVar("F", bound=t.Callable[..., t.Any])
+
if t.TYPE_CHECKING:
+ import logging
+ import typing_extensions as te
from .environment import Environment
+ class LoopRenderFunc(te.Protocol):
+ def __call__(
+ self,
+ reciter: t.Iterable[V],
+ loop_render_func: "LoopRenderFunc",
+ depth: int = 0,
+ ) -> str:
+ ...
+
+
# these variables are exported to the template runtime
exported = [
"LoopContext",
@@ -50,14 +66,14 @@ async_exported = [
]
-def identity(x):
+def identity(x: V) -> V:
"""Returns its argument. Useful for certain things in the
environment.
"""
return x
-def markup_join(seq):
+def markup_join(seq: t.Iterable[t.Any]) -> str:
"""Concatenation that escapes if necessary and converts to string."""
buf = []
iterator = map(soft_str, seq)
@@ -68,12 +84,12 @@ def markup_join(seq):
return concat(buf)
-def str_join(seq):
+def str_join(seq: t.Iterable[t.Any]) -> str:
"""Simple args to string conversion and concatenation."""
return concat(map(str, seq))
-def unicode_join(seq):
+def unicode_join(seq: t.Iterable[t.Any]) -> str:
import warnings
warnings.warn(
@@ -86,14 +102,14 @@ def unicode_join(seq):
def new_context(
- environment,
- template_name,
- blocks,
- vars=None,
- shared=None,
- globals=None,
- locals=None,
-):
+ environment: "Environment",
+ template_name: t.Optional[str],
+ blocks: t.Dict[str, t.Callable[["Context"], t.Iterator[str]]],
+ vars: t.Optional[t.Dict[str, t.Any]] = None,
+ shared: bool = False,
+ globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
+ locals: t.Optional[t.Mapping[str, t.Any]] = None,
+) -> "Context":
"""Internal helper for context creation."""
if vars is None:
vars = {}
@@ -117,47 +133,27 @@ def new_context(
class TemplateReference:
"""The `self` in templates."""
- def __init__(self, context):
+ def __init__(self, context: "Context") -> None:
self.__context = context
- def __getitem__(self, name):
+ def __getitem__(self, name: str) -> t.Any:
blocks = self.__context.blocks[name]
return BlockReference(name, self.__context, blocks, 0)
- def __repr__(self):
- return f"<{self.__class__.__name__} {self.__context.name!r}>"
-
-
-class ContextMeta(type):
- def __new__(mcs, name, bases, d):
- rv = type.__new__(mcs, name, bases, d)
-
- if not bases:
- return rv
+ def __repr__(self) -> str:
+ return f"<{type(self).__name__} {self.__context.name!r}>"
- if "resolve_or_missing" in d:
- # If the subclass overrides resolve_or_missing it opts in to
- # modern mode no matter what.
- rv._legacy_resolve_mode = False
- elif "resolve" in d or rv._legacy_resolve_mode:
- # If the subclass overrides resolve, or if its base is
- # already in legacy mode, warn about legacy behavior.
- import warnings
- warnings.warn(
- "Overriding 'resolve' is deprecated and will not have"
- " the expected behavior in Jinja 3.1. Override"
- " 'resolve_or_missing' instead ",
- DeprecationWarning,
- stacklevel=2,
- )
- rv._legacy_resolve_mode = True
+def _dict_method_all(dict_method: F) -> F:
+ @functools.wraps(dict_method)
+ def f_all(self: "Context") -> t.Any:
+ return dict_method(self.get_all())
- return rv
+ return t.cast(F, f_all)
@abc.Mapping.register
-class Context(metaclass=ContextMeta):
+class Context:
"""The template context holds the variables of a template. It stores the
values passed to the template and also the names the template exports.
Creating instances is neither supported nor useful as it's created
@@ -177,14 +173,40 @@ class Context(metaclass=ContextMeta):
:class:`Undefined` object for missing variables.
"""
- _legacy_resolve_mode = False
+ _legacy_resolve_mode: t.ClassVar[bool] = False
+
+ def __init_subclass__(cls) -> None:
+ if "resolve_or_missing" in cls.__dict__:
+ # If the subclass overrides resolve_or_missing it opts in to
+ # modern mode no matter what.
+ cls._legacy_resolve_mode = False
+ elif "resolve" in cls.__dict__ or cls._legacy_resolve_mode:
+ # If the subclass overrides resolve, or if its base is
+ # already in legacy mode, warn about legacy behavior.
+ import warnings
- def __init__(self, environment, parent, name, blocks, globals=None):
+ warnings.warn(
+ "Overriding 'resolve' is deprecated and will not have"
+ " the expected behavior in Jinja 3.1. Override"
+ " 'resolve_or_missing' instead ",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ cls._legacy_resolve_mode = True
+
+ def __init__(
+ self,
+ environment: "Environment",
+ parent: t.Dict[str, t.Any],
+ name: t.Optional[str],
+ blocks: t.Dict[str, t.Callable[["Context"], t.Iterator[str]]],
+ globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
+ ):
self.parent = parent
- self.vars = {}
+ self.vars: t.Dict[str, t.Any] = {}
self.environment: "Environment" = environment
self.eval_ctx = EvalContext(self.environment, name)
- self.exported_vars = set()
+ self.exported_vars: t.Set[str] = set()
self.name = name
self.globals_keys = set() if globals is None else set(globals)
@@ -193,7 +215,9 @@ class Context(metaclass=ContextMeta):
# from the template.
self.blocks = {k: [v] for k, v in blocks.items()}
- def super(self, name, current):
+ def super(
+ self, name: str, current: t.Callable[["Context"], t.Iterator[str]]
+ ) -> t.Union["BlockReference", "Undefined"]:
"""Render a parent block."""
try:
blocks = self.blocks[name]
@@ -205,7 +229,7 @@ class Context(metaclass=ContextMeta):
)
return BlockReference(name, self, blocks, index)
- def get(self, key, default=None):
+ def get(self, key: str, default: t.Any = None) -> t.Any:
"""Look up a variable by name, or return a default if the key is
not found.
@@ -217,7 +241,7 @@ class Context(metaclass=ContextMeta):
except KeyError:
return default
- def resolve(self, key):
+ def resolve(self, key: str) -> t.Union[t.Any, "Undefined"]:
"""Look up a variable by name, or return an :class:`Undefined`
object if the key is not found.
@@ -243,7 +267,7 @@ class Context(metaclass=ContextMeta):
return rv
- def resolve_or_missing(self, key):
+ def resolve_or_missing(self, key: str) -> t.Any:
"""Look up a variable by name, or return a ``missing`` sentinel
if the key is not found.
@@ -269,11 +293,11 @@ class Context(metaclass=ContextMeta):
return missing
- def get_exported(self):
+ def get_exported(self) -> t.Dict[str, t.Any]:
"""Get a new dict with the exported variables."""
return {k: self.vars[k] for k in self.exported_vars}
- def get_all(self):
+ def get_all(self) -> t.Dict[str, t.Any]:
"""Return the complete context as dict including the exported
variables. For optimizations reasons this might not return an
actual copy so be careful with using it.
@@ -285,7 +309,9 @@ class Context(metaclass=ContextMeta):
return dict(self.parent, **self.vars)
@internalcode
- def call(__self, __obj, *args, **kwargs): # noqa: B902
+ def call(
+ __self, __obj: t.Callable, *args: t.Any, **kwargs: t.Any # noqa: B902
+ ) -> t.Union[t.Any, "Undefined"]:
"""Call the callable with the arguments and keyword arguments
provided but inject the active context or environment as first
argument if the callable has :func:`pass_context` or
@@ -297,28 +323,28 @@ class Context(metaclass=ContextMeta):
# Allow callable classes to take a context
if (
hasattr(__obj, "__call__") # noqa: B004
- and _PassArg.from_obj(__obj.__call__) is not None
+ and _PassArg.from_obj(__obj.__call__) is not None # type: ignore
):
- __obj = __obj.__call__
-
- if callable(__obj):
- pass_arg = _PassArg.from_obj(__obj)
-
- if pass_arg is _PassArg.context:
- # the active context should have access to variables set in
- # loops and blocks without mutating the context itself
- if kwargs.get("_loop_vars"):
- __self = __self.derived(kwargs["_loop_vars"])
- if kwargs.get("_block_vars"):
- __self = __self.derived(kwargs["_block_vars"])
- args = (__self,) + args
- elif pass_arg is _PassArg.eval_context:
- args = (__self.eval_ctx,) + args
- elif pass_arg is _PassArg.environment:
- args = (__self.environment,) + args
+ __obj = __obj.__call__ # type: ignore
+
+ pass_arg = _PassArg.from_obj(__obj)
+
+ if pass_arg is _PassArg.context:
+ # the active context should have access to variables set in
+ # loops and blocks without mutating the context itself
+ if kwargs.get("_loop_vars"):
+ __self = __self.derived(kwargs["_loop_vars"])
+ if kwargs.get("_block_vars"):
+ __self = __self.derived(kwargs["_block_vars"])
+ args = (__self,) + args
+ elif pass_arg is _PassArg.eval_context:
+ args = (__self.eval_ctx,) + args
+ elif pass_arg is _PassArg.environment:
+ args = (__self.environment,) + args
kwargs.pop("_block_vars", None)
kwargs.pop("_loop_vars", None)
+
try:
return __obj(*args, **kwargs)
except StopIteration:
@@ -327,7 +353,7 @@ class Context(metaclass=ContextMeta):
" StopIteration exception"
)
- def derived(self, locals=None):
+ def derived(self, locals: t.Optional[t.Dict[str, t.Any]] = None) -> "Context":
"""Internal helper function to create a derived context. This is
used in situations where the system needs a new context in the same
template that is independent.
@@ -339,24 +365,14 @@ class Context(metaclass=ContextMeta):
context.blocks.update((k, list(v)) for k, v in self.blocks.items())
return context
- # ignore: true
- def _all(meth): # noqa: B902
- def proxy(self):
- return getattr(self.get_all(), meth)()
-
- proxy.__doc__ = getattr(dict, meth).__doc__
- proxy.__name__ = meth
- return proxy
+ keys = _dict_method_all(dict.keys)
+ values = _dict_method_all(dict.values)
+ items = _dict_method_all(dict.items)
- keys = _all("keys") # type:ignore
- values = _all("values") # type:ignore
- items = _all("items") # type:ignore
- del _all
-
- def __contains__(self, name):
+ def __contains__(self, name: str) -> bool:
return name in self.vars or name in self.parent
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> t.Any:
"""Look up a variable by name with ``[]`` syntax, or raise a
``KeyError`` if the key is not found.
"""
@@ -367,21 +383,27 @@ class Context(metaclass=ContextMeta):
return item
- def __repr__(self):
- return f"<{self.__class__.__name__} {self.get_all()!r} of {self.name!r}>"
+ def __repr__(self) -> str:
+ return f"<{type(self).__name__} {self.get_all()!r} of {self.name!r}>"
class BlockReference:
"""One block on a template reference."""
- def __init__(self, name, context, stack, depth):
+ def __init__(
+ self,
+ name: str,
+ context: "Context",
+ stack: t.List[t.Callable[["Context"], t.Iterator[str]]],
+ depth: int,
+ ) -> None:
self.name = name
self._context = context
self._stack = stack
self._depth = depth
@property
- def super(self):
+ def super(self) -> t.Union["BlockReference", "Undefined"]:
"""Super the block."""
if self._depth + 1 >= len(self._stack):
return self._context.environment.undefined(
@@ -390,8 +412,10 @@ class BlockReference:
return BlockReference(self.name, self._context, self._stack, self._depth + 1)
@internalcode
- async def _async_call(self):
- rv = concat([x async for x in self._stack[self._depth](self._context)])
+ async def _async_call(self) -> str:
+ rv = concat(
+ [x async for x in self._stack[self._depth](self._context)] # type: ignore
+ )
if self._context.eval_ctx.autoescape:
return Markup(rv)
@@ -399,9 +423,9 @@ class BlockReference:
return rv
@internalcode
- def __call__(self):
+ def __call__(self) -> str:
if self._context.environment.is_async:
- return self._async_call()
+ return self._async_call() # type: ignore
rv = concat(self._stack[self._depth](self._context))
@@ -420,12 +444,18 @@ class LoopContext:
index0 = -1
_length: t.Optional[int] = None
- _after = missing
- _current = missing
- _before = missing
- _last_changed_value = missing
+ _after: t.Any = missing
+ _current: t.Any = missing
+ _before: t.Any = missing
+ _last_changed_value: t.Any = missing
- def __init__(self, iterable, undefined, recurse=None, depth0=0):
+ def __init__(
+ self,
+ iterable: t.Iterable[V],
+ undefined: t.Type["Undefined"],
+ recurse: t.Optional["LoopRenderFunc"] = None,
+ depth0: int = 0,
+ ) -> None:
"""
:param iterable: Iterable to wrap.
:param undefined: :class:`Undefined` class to use for next and
@@ -442,11 +472,11 @@ class LoopContext:
self.depth0 = depth0
@staticmethod
- def _to_iterator(iterable):
+ def _to_iterator(iterable: t.Iterable[V]) -> t.Iterator[V]:
return iter(iterable)
@property
- def length(self):
+ def length(self) -> int:
"""Length of the iterable.
If the iterable is a generator or otherwise does not have a
@@ -456,7 +486,7 @@ class LoopContext:
return self._length
try:
- self._length = len(self._iterable)
+ self._length = len(self._iterable) # type: ignore
except TypeError:
iterable = list(self._iterator)
self._iterator = self._to_iterator(iterable)
@@ -464,21 +494,21 @@ class LoopContext:
return self._length
- def __len__(self):
+ def __len__(self) -> int:
return self.length
@property
- def depth(self):
+ def depth(self) -> int:
"""How many levels deep a recursive loop currently is, starting at 1."""
return self.depth0 + 1
@property
- def index(self):
+ def index(self) -> int:
"""Current iteration of the loop, starting at 1."""
return self.index0 + 1
@property
- def revindex0(self):
+ def revindex0(self) -> int:
"""Number of iterations from the end of the loop, ending at 0.
Requires calculating :attr:`length`.
@@ -486,7 +516,7 @@ class LoopContext:
return self.length - self.index
@property
- def revindex(self):
+ def revindex(self) -> int:
"""Number of iterations from the end of the loop, ending at 1.
Requires calculating :attr:`length`.
@@ -494,11 +524,11 @@ class LoopContext:
return self.length - self.index0
@property
- def first(self):
+ def first(self) -> bool:
"""Whether this is the first iteration of the loop."""
return self.index0 == 0
- def _peek_next(self):
+ def _peek_next(self) -> t.Any:
"""Return the next element in the iterable, or :data:`missing`
if the iterable is exhausted. Only peeks one item ahead, caching
the result in :attr:`_last` for use in subsequent checks. The
@@ -511,7 +541,7 @@ class LoopContext:
return self._after
@property
- def last(self):
+ def last(self) -> bool:
"""Whether this is the last iteration of the loop.
Causes the iterable to advance early. See
@@ -521,7 +551,7 @@ class LoopContext:
return self._peek_next() is missing
@property
- def previtem(self):
+ def previtem(self) -> t.Union[t.Any, "Undefined"]:
"""The item in the previous iteration. Undefined during the
first iteration.
"""
@@ -531,13 +561,13 @@ class LoopContext:
return self._before
@property
- def nextitem(self):
+ def nextitem(self) -> t.Union[t.Any, "Undefined"]:
"""The item in the next iteration. Undefined during the last
iteration.
Causes the iterable to advance early. See
:func:`itertools.groupby` for issues this can cause.
- The :func:`groupby` filter avoids that issue.
+ The :func:`jinja-filters.groupby` filter avoids that issue.
"""
rv = self._peek_next()
@@ -546,7 +576,7 @@ class LoopContext:
return rv
- def cycle(self, *args):
+ def cycle(self, *args: V) -> V:
"""Return a value from the given args, cycling through based on
the current :attr:`index0`.
@@ -557,7 +587,7 @@ class LoopContext:
return args[self.index0 % len(args)]
- def changed(self, *value):
+ def changed(self, *value: t.Any) -> bool:
"""Return ``True`` if previously called with a different value
(including when called for the first time).
@@ -569,10 +599,10 @@ class LoopContext:
return False
- def __iter__(self):
+ def __iter__(self) -> "LoopContext":
return self
- def __next__(self):
+ def __next__(self) -> t.Tuple[t.Any, "LoopContext"]:
if self._after is not missing:
rv = self._after
self._after = missing
@@ -585,7 +615,7 @@ class LoopContext:
return rv, self
@internalcode
- def __call__(self, iterable):
+ def __call__(self, iterable: t.Iterable[V]) -> str:
"""When iterating over nested data, render the body of the loop
recursively with the given inner iterable data.
@@ -598,22 +628,26 @@ class LoopContext:
return self._recurse(iterable, self._recurse, depth=self.depth)
- def __repr__(self):
- return f"<{self.__class__.__name__} {self.index}/{self.length}>"
+ def __repr__(self) -> str:
+ return f"<{type(self).__name__} {self.index}/{self.length}>"
class AsyncLoopContext(LoopContext):
+ _iterator: t.AsyncIterator[t.Any] # type: ignore
+
@staticmethod
- def _to_iterator(iterable):
+ def _to_iterator( # type: ignore
+ iterable: t.Union[t.Iterable[V], t.AsyncIterable[V]]
+ ) -> t.AsyncIterator[V]:
return auto_aiter(iterable)
@property
- async def length(self):
+ async def length(self) -> int: # type: ignore
if self._length is not None:
return self._length
try:
- self._length = len(self._iterable)
+ self._length = len(self._iterable) # type: ignore
except TypeError:
iterable = [x async for x in self._iterator]
self._iterator = self._to_iterator(iterable)
@@ -622,14 +656,14 @@ class AsyncLoopContext(LoopContext):
return self._length
@property
- async def revindex0(self):
+ async def revindex0(self) -> int: # type: ignore
return await self.length - self.index
@property
- async def revindex(self):
+ async def revindex(self) -> int: # type: ignore
return await self.length - self.index0
- async def _peek_next(self):
+ async def _peek_next(self) -> t.Any:
if self._after is not missing:
return self._after
@@ -641,11 +675,11 @@ class AsyncLoopContext(LoopContext):
return self._after
@property
- async def last(self):
+ async def last(self) -> bool: # type: ignore
return await self._peek_next() is missing
@property
- async def nextitem(self):
+ async def nextitem(self) -> t.Union[t.Any, "Undefined"]:
rv = await self._peek_next()
if rv is missing:
@@ -653,10 +687,10 @@ class AsyncLoopContext(LoopContext):
return rv
- def __aiter__(self):
+ def __aiter__(self) -> "AsyncLoopContext":
return self
- async def __anext__(self):
+ async def __anext__(self) -> t.Tuple[t.Any, "AsyncLoopContext"]:
if self._after is not missing:
rv = self._after
self._after = missing
@@ -674,14 +708,14 @@ class Macro:
def __init__(
self,
- environment,
- func,
- name,
- arguments,
- catch_kwargs,
- catch_varargs,
- caller,
- default_autoescape=None,
+ environment: "Environment",
+ func: t.Callable[..., str],
+ name: str,
+ arguments: t.List[str],
+ catch_kwargs: bool,
+ catch_varargs: bool,
+ caller: bool,
+ default_autoescape: t.Optional[bool] = None,
):
self._environment = environment
self._func = func
@@ -692,13 +726,18 @@ class Macro:
self.catch_varargs = catch_varargs
self.caller = caller
self.explicit_caller = "caller" in arguments
+
if default_autoescape is None:
- default_autoescape = environment.autoescape
+ if callable(environment.autoescape):
+ default_autoescape = environment.autoescape(None)
+ else:
+ default_autoescape = environment.autoescape
+
self._default_autoescape = default_autoescape
@internalcode
@pass_eval_context
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args: t.Any, **kwargs: t.Any) -> str:
# This requires a bit of explanation, In the past we used to
# decide largely based on compile-time information if a macro is
# safe or unsafe. While there was a volatile mode it was largely
@@ -774,17 +813,17 @@ class Macro:
return self._invoke(arguments, autoescape)
- async def _async_invoke(self, arguments, autoescape):
- rv = await self._func(*arguments)
+ async def _async_invoke(self, arguments: t.List[t.Any], autoescape: bool) -> str:
+ rv = await self._func(*arguments) # type: ignore
if autoescape:
return Markup(rv)
- return rv
+ return rv # type: ignore
- def _invoke(self, arguments, autoescape):
+ def _invoke(self, arguments: t.List[t.Any], autoescape: bool) -> str:
if self._environment.is_async:
- return self._async_invoke(arguments, autoescape)
+ return self._async_invoke(arguments, autoescape) # type: ignore
rv = self._func(*arguments)
@@ -793,9 +832,9 @@ class Macro:
return rv
- def __repr__(self):
+ def __repr__(self) -> str:
name = "anonymous" if self.name is None else repr(self.name)
- return f"<{self.__class__.__name__} {name}>"
+ return f"<{type(self).__name__} {name}>"
class Undefined:
@@ -820,14 +859,20 @@ class Undefined:
"_undefined_exception",
)
- def __init__(self, hint=None, obj=missing, name=None, exc=UndefinedError):
+ def __init__(
+ self,
+ hint: t.Optional[str] = None,
+ obj: t.Any = missing,
+ name: t.Optional[str] = None,
+ exc: t.Type[TemplateRuntimeError] = UndefinedError,
+ ) -> None:
self._undefined_hint = hint
self._undefined_obj = obj
self._undefined_name = name
self._undefined_exception = exc
@property
- def _undefined_message(self):
+ def _undefined_message(self) -> str:
"""Build a message about the undefined value based on how it was
accessed.
"""
@@ -849,16 +894,17 @@ class Undefined:
)
@internalcode
- def _fail_with_undefined_error(self, *args, **kwargs):
+ def _fail_with_undefined_error(self, *args: t.Any, **kwargs: t.Any) -> t.NoReturn:
"""Raise an :exc:`UndefinedError` when operations are performed
on the undefined value.
"""
raise self._undefined_exception(self._undefined_message)
@internalcode
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> t.Any:
if name[:2] == "__":
raise AttributeError(name)
+
return self._fail_with_undefined_error()
__add__ = __radd__ = __sub__ = __rsub__ = _fail_with_undefined_error
@@ -872,36 +918,38 @@ class Undefined:
__int__ = __float__ = __complex__ = _fail_with_undefined_error
__pow__ = __rpow__ = _fail_with_undefined_error
- def __eq__(self, other):
+ def __eq__(self, other: t.Any) -> bool:
return type(self) is type(other)
- def __ne__(self, other):
+ def __ne__(self, other: t.Any) -> bool:
return not self.__eq__(other)
- def __hash__(self):
+ def __hash__(self) -> int:
return id(type(self))
- def __str__(self):
+ def __str__(self) -> str:
return ""
- def __len__(self):
+ def __len__(self) -> int:
return 0
- def __iter__(self):
+ def __iter__(self) -> t.Iterator[t.Any]:
yield from ()
- async def __aiter__(self):
+ async def __aiter__(self) -> t.AsyncIterator[t.Any]:
for _ in ():
yield
- def __bool__(self):
+ def __bool__(self) -> bool:
return False
- def __repr__(self):
+ def __repr__(self) -> str:
return "Undefined"
-def make_logging_undefined(logger=None, base=None):
+def make_logging_undefined(
+ logger: t.Optional["logging.Logger"] = None, base: t.Type[Undefined] = Undefined
+) -> t.Type[Undefined]:
"""Given a logger object this returns a new undefined class that will
log certain failures. It will log iterations and printing. If no
logger is given a default logger is created.
@@ -926,31 +974,35 @@ def make_logging_undefined(logger=None, base=None):
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(sys.stderr))
- if base is None:
- base = Undefined
- def _log_message(undef):
- logger.warning("Template variable warning: %s", undef._undefined_message)
+ def _log_message(undef: Undefined) -> None:
+ logger.warning( # type: ignore
+ "Template variable warning: %s", undef._undefined_message
+ )
+
+ class LoggingUndefined(base): # type: ignore
+ __slots__ = ()
- class LoggingUndefined(base):
- def _fail_with_undefined_error(self, *args, **kwargs):
+ def _fail_with_undefined_error( # type: ignore
+ self, *args: t.Any, **kwargs: t.Any
+ ) -> t.NoReturn:
try:
- return super()._fail_with_undefined_error(*args, **kwargs)
+ super()._fail_with_undefined_error(*args, **kwargs)
except self._undefined_exception as e:
- logger.error("Template variable error: %s", e)
+ logger.error("Template variable error: %s", e) # type: ignore
raise e
- def __str__(self):
+ def __str__(self) -> str:
_log_message(self)
- return super().__str__()
+ return super().__str__() # type: ignore
- def __iter__(self):
+ def __iter__(self) -> t.Iterator[t.Any]:
_log_message(self)
- return super().__iter__()
+ return super().__iter__() # type: ignore
- def __bool__(self):
+ def __bool__(self) -> bool:
_log_message(self)
- return super().__bool__()
+ return super().__bool__() # type: ignore
return LoggingUndefined
@@ -973,13 +1025,13 @@ class ChainableUndefined(Undefined):
__slots__ = ()
- def __html__(self):
- return self.__str__()
+ def __html__(self) -> str:
+ return str(self)
- def __getattr__(self, _):
+ def __getattr__(self, _: str) -> "ChainableUndefined":
return self
- __getitem__ = __getattr__
+ __getitem__ = __getattr__ # type: ignore
class DebugUndefined(Undefined):
@@ -998,12 +1050,12 @@ class DebugUndefined(Undefined):
__slots__ = ()
- def __str__(self):
+ def __str__(self) -> str:
if self._undefined_hint:
message = f"undefined value printed: {self._undefined_hint}"
elif self._undefined_obj is missing:
- message = self._undefined_name
+ message = self._undefined_name # type: ignore
else:
message = (
diff --git a/src/jinja2/sandbox.py b/src/jinja2/sandbox.py
index 6311a5d..4294884 100644
--- a/src/jinja2/sandbox.py
+++ b/src/jinja2/sandbox.py
@@ -3,27 +3,30 @@ Useful when the template itself comes from an untrusted source.
"""
import operator
import types
+import typing as t
from _string import formatter_field_name_split # type: ignore
from collections import abc
from collections import deque
from string import Formatter
-from typing import FrozenSet
-from typing import Set
from markupsafe import EscapeFormatter
from markupsafe import Markup
from .environment import Environment
from .exceptions import SecurityError
+from .runtime import Context
+from .runtime import Undefined
+
+F = t.TypeVar("F", bound=t.Callable[..., t.Any])
#: maximum number of items a range may produce
MAX_RANGE = 100000
#: Unsafe function attributes.
-UNSAFE_FUNCTION_ATTRIBUTES: Set = set()
+UNSAFE_FUNCTION_ATTRIBUTES: t.Set[str] = set()
#: Unsafe method attributes. Function attributes are unsafe for methods too.
-UNSAFE_METHOD_ATTRIBUTES: Set = set()
+UNSAFE_METHOD_ATTRIBUTES: t.Set[str] = set()
#: unsafe generator attributes.
UNSAFE_GENERATOR_ATTRIBUTES = {"gi_frame", "gi_code"}
@@ -34,7 +37,7 @@ UNSAFE_COROUTINE_ATTRIBUTES = {"cr_frame", "cr_code"}
#: unsafe attributes on async generators
UNSAFE_ASYNC_GENERATOR_ATTRIBUTES = {"ag_code", "ag_frame"}
-_mutable_spec = (
+_mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = (
(
abc.MutableSet,
frozenset(
@@ -77,17 +80,21 @@ _mutable_spec = (
)
-def inspect_format_method(callable):
+def inspect_format_method(callable: t.Callable) -> t.Optional[str]:
if not isinstance(
callable, (types.MethodType, types.BuiltinMethodType)
) or callable.__name__ not in ("format", "format_map"):
return None
+
obj = callable.__self__
+
if isinstance(obj, str):
return obj
+ return None
+
-def safe_range(*args):
+def safe_range(*args: int) -> range:
"""A range that can't generate ranges with a length of more than
MAX_RANGE items.
"""
@@ -102,7 +109,7 @@ def safe_range(*args):
return rng
-def unsafe(f):
+def unsafe(f: F) -> F:
"""Marks a function or method as unsafe.
.. code-block: python
@@ -111,11 +118,11 @@ def unsafe(f):
def delete(self):
pass
"""
- f.unsafe_callable = True
+ f.unsafe_callable = True # type: ignore
return f
-def is_internal_attribute(obj, attr):
+def is_internal_attribute(obj: t.Any, attr: str) -> bool:
"""Test if the attribute given is an internal python attribute. For
example this function returns `True` for the `func_code` attribute of
python objects. This is useful if the environment method
@@ -152,7 +159,7 @@ def is_internal_attribute(obj, attr):
return attr.startswith("__")
-def modifies_known_mutable(obj, attr):
+def modifies_known_mutable(obj: t.Any, attr: str) -> bool:
"""This function checks if an attribute on a builtin mutable object
(list, dict, set or deque) or the corresponding ABCs would modify it
if called.
@@ -193,7 +200,7 @@ class SandboxedEnvironment(Environment):
#: default callback table for the binary operators. A copy of this is
#: available on each instance of a sandboxed environment as
#: :attr:`binop_table`
- default_binop_table = {
+ default_binop_table: t.Dict[str, t.Callable[[t.Any, t.Any], t.Any]] = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
@@ -206,7 +213,10 @@ class SandboxedEnvironment(Environment):
#: default callback table for the unary operators. A copy of this is
#: available on each instance of a sandboxed environment as
#: :attr:`unop_table`
- default_unop_table = {"+": operator.pos, "-": operator.neg}
+ default_unop_table: t.Dict[str, t.Callable[[t.Any], t.Any]] = {
+ "+": operator.pos,
+ "-": operator.neg,
+ }
#: a set of binary operators that should be intercepted. Each operator
#: that is added to this set (empty by default) is delegated to the
@@ -222,7 +232,7 @@ class SandboxedEnvironment(Environment):
#: interested in.
#:
#: .. versionadded:: 2.6
- intercepted_binops: FrozenSet = frozenset()
+ intercepted_binops: t.FrozenSet[str] = frozenset()
#: a set of unary operators that should be intercepted. Each operator
#: that is added to this set (empty by default) is delegated to the
@@ -237,32 +247,15 @@ class SandboxedEnvironment(Environment):
#: interested in.
#:
#: .. versionadded:: 2.6
- intercepted_unops: FrozenSet = frozenset()
-
- def intercept_unop(self, operator):
- """Called during template compilation with the name of a unary
- operator to check if it should be intercepted at runtime. If this
- method returns `True`, :meth:`call_unop` is executed for this unary
- operator. The default implementation of :meth:`call_unop` will use
- the :attr:`unop_table` dictionary to perform the operator with the
- same logic as the builtin one.
-
- The following unary operators are interceptable: ``+`` and ``-``
-
- Intercepted calls are always slower than the native operator call,
- so make sure only to intercept the ones you are interested in.
+ intercepted_unops: t.FrozenSet[str] = frozenset()
- .. versionadded:: 2.6
- """
- return False
-
- def __init__(self, *args, **kwargs):
- Environment.__init__(self, *args, **kwargs)
+ def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
+ super().__init__(*args, **kwargs)
self.globals["range"] = safe_range
self.binop_table = self.default_binop_table.copy()
self.unop_table = self.default_unop_table.copy()
- def is_safe_attribute(self, obj, attr, value):
+ def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool:
"""The sandboxed environment will call this method to check if the
attribute of an object is safe to access. Per default all attributes
starting with an underscore are considered private as well as the
@@ -271,17 +264,20 @@ class SandboxedEnvironment(Environment):
"""
return not (attr.startswith("_") or is_internal_attribute(obj, attr))
- def is_safe_callable(self, obj):
- """Check if an object is safely callable. Per default a function is
- considered safe unless the `unsafe_callable` attribute exists and is
- True. Override this method to alter the behavior, but this won't
- affect the `unsafe` decorator from this module.
+ def is_safe_callable(self, obj: t.Any) -> bool:
+ """Check if an object is safely callable. By default callables
+ are considered safe unless decorated with :func:`unsafe`.
+
+ This also recognizes the Django convention of setting
+ ``func.alters_data = True``.
"""
return not (
getattr(obj, "unsafe_callable", False) or getattr(obj, "alters_data", False)
)
- def call_binop(self, context, operator, left, right):
+ def call_binop(
+ self, context: Context, operator: str, left: t.Any, right: t.Any
+ ) -> t.Any:
"""For intercepted binary operator calls (:meth:`intercepted_binops`)
this function is executed instead of the builtin operator. This can
be used to fine tune the behavior of certain operators.
@@ -290,7 +286,7 @@ class SandboxedEnvironment(Environment):
"""
return self.binop_table[operator](left, right)
- def call_unop(self, context, operator, arg):
+ def call_unop(self, context: Context, operator: str, arg: t.Any) -> t.Any:
"""For intercepted unary operator calls (:meth:`intercepted_unops`)
this function is executed instead of the builtin operator. This can
be used to fine tune the behavior of certain operators.
@@ -299,7 +295,9 @@ class SandboxedEnvironment(Environment):
"""
return self.unop_table[operator](arg)
- def getitem(self, obj, argument):
+ def getitem(
+ self, obj: t.Any, argument: t.Union[str, t.Any]
+ ) -> t.Union[t.Any, Undefined]:
"""Subscribe an object from sandboxed code."""
try:
return obj[argument]
@@ -320,7 +318,7 @@ class SandboxedEnvironment(Environment):
return self.unsafe_undefined(obj, argument)
return self.undefined(obj=obj, name=argument)
- def getattr(self, obj, attribute):
+ def getattr(self, obj: t.Any, attribute: str) -> t.Union[t.Any, Undefined]:
"""Subscribe an object from sandboxed code and prefer the
attribute. The attribute passed *must* be a bytestring.
"""
@@ -337,22 +335,29 @@ class SandboxedEnvironment(Environment):
return self.unsafe_undefined(obj, attribute)
return self.undefined(obj=obj, name=attribute)
- def unsafe_undefined(self, obj, attribute):
+ def unsafe_undefined(self, obj: t.Any, attribute: str) -> Undefined:
"""Return an undefined object for unsafe attributes."""
return self.undefined(
f"access to attribute {attribute!r} of"
- f" {obj.__class__.__name__!r} object is unsafe.",
+ f" {type(obj).__name__!r} object is unsafe.",
name=attribute,
obj=obj,
exc=SecurityError,
)
- def format_string(self, s, args, kwargs, format_func=None):
+ def format_string(
+ self,
+ s: str,
+ args: t.Tuple[t.Any, ...],
+ kwargs: t.Dict[str, t.Any],
+ format_func: t.Optional[t.Callable] = None,
+ ) -> str:
"""If a format call is detected, then this is routed through this
method so that our safety sandbox can be used for it.
"""
+ formatter: SandboxedFormatter
if isinstance(s, Markup):
- formatter = SandboxedEscapeFormatter(self, s.escape)
+ formatter = SandboxedEscapeFormatter(self, escape=s.escape)
else:
formatter = SandboxedFormatter(self)
@@ -364,12 +369,18 @@ class SandboxedEnvironment(Environment):
)
kwargs = args[0]
- args = None
+ args = ()
rv = formatter.vformat(s, args, kwargs)
return type(s)(rv)
- def call(__self, __context, __obj, *args, **kwargs): # noqa: B902
+ def call(
+ __self, # noqa: B902
+ __context: Context,
+ __obj: t.Any,
+ *args: t.Any,
+ **kwargs: t.Any,
+ ) -> t.Any:
"""Call an object from sandboxed code."""
fmt = inspect_format_method(__obj)
if fmt is not None:
@@ -388,17 +399,21 @@ class ImmutableSandboxedEnvironment(SandboxedEnvironment):
`dict` by using the :func:`modifies_known_mutable` function.
"""
- def is_safe_attribute(self, obj, attr, value):
- if not SandboxedEnvironment.is_safe_attribute(self, obj, attr, value):
+ def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool:
+ if not super().is_safe_attribute(obj, attr, value):
return False
+
return not modifies_known_mutable(obj, attr)
-class SandboxedFormatterMixin:
- def __init__(self, env):
+class SandboxedFormatter(Formatter):
+ def __init__(self, env: Environment, **kwargs: t.Any) -> None:
self._env = env
+ super().__init__(**kwargs) # type: ignore
- def get_field(self, field_name, args, kwargs):
+ def get_field(
+ self, field_name: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
+ ) -> t.Tuple[t.Any, str]:
first, rest = formatter_field_name_split(field_name)
obj = self.get_value(first, args, kwargs)
for is_attr, i in rest:
@@ -409,13 +424,5 @@ class SandboxedFormatterMixin:
return obj, first
-class SandboxedFormatter(SandboxedFormatterMixin, Formatter):
- def __init__(self, env):
- SandboxedFormatterMixin.__init__(self, env)
- Formatter.__init__(self)
-
-
-class SandboxedEscapeFormatter(SandboxedFormatterMixin, EscapeFormatter):
- def __init__(self, env, escape):
- SandboxedFormatterMixin.__init__(self, env)
- EscapeFormatter.__init__(self, escape)
+class SandboxedEscapeFormatter(SandboxedFormatter, EscapeFormatter):
+ pass
diff --git a/src/jinja2/utils.py b/src/jinja2/utils.py
index 51e03d8..d06c1e4 100644
--- a/src/jinja2/utils.py
+++ b/src/jinja2/utils.py
@@ -14,18 +14,17 @@ from urllib.parse import quote_from_bytes
import markupsafe
-if t.TYPE_CHECKING:
- F = t.TypeVar("F", bound=t.Callable[..., t.Any])
+F = t.TypeVar("F", bound=t.Callable[..., t.Any])
# special singleton representing missing values for the runtime
-missing = type("MissingType", (), {"__repr__": lambda x: "missing"})()
+missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})()
internal_code: t.MutableSet[CodeType] = set()
concat = "".join
-def pass_context(f: "F") -> "F":
+def pass_context(f: F) -> F:
"""Pass the :class:`~jinja2.runtime.Context` as the first argument
to the decorated function when called while rendering a template.
@@ -42,7 +41,7 @@ def pass_context(f: "F") -> "F":
return f
-def pass_eval_context(f: "F") -> "F":
+def pass_eval_context(f: F) -> F:
"""Pass the :class:`~jinja2.nodes.EvalContext` as the first argument
to the decorated function when called while rendering a template.
See :ref:`eval-context`.
@@ -59,7 +58,7 @@ def pass_eval_context(f: "F") -> "F":
return f
-def pass_environment(f: "F") -> "F":
+def pass_environment(f: F) -> F:
"""Pass the :class:`~jinja2.Environment` as the first argument to
the decorated function when called while rendering a template.
@@ -78,9 +77,9 @@ class _PassArg(enum.Enum):
environment = enum.auto()
@classmethod
- def from_obj(cls, obj):
+ def from_obj(cls, obj: F) -> t.Optional["_PassArg"]:
if hasattr(obj, "jinja_pass_arg"):
- return obj.jinja_pass_arg
+ return obj.jinja_pass_arg # type: ignore
for prefix in "context", "eval_context", "environment":
squashed = prefix.replace("_", "")
@@ -95,8 +94,10 @@ class _PassArg(enum.Enum):
)
return cls[prefix]
+ return None
+
-def contextfunction(f):
+def contextfunction(f: F) -> F:
"""Pass the context as the first argument to the decorated function.
.. deprecated:: 3.0
@@ -112,7 +113,7 @@ def contextfunction(f):
return pass_context(f)
-def evalcontextfunction(f):
+def evalcontextfunction(f: F) -> F:
"""Pass the eval context as the first argument to the decorated
function.
@@ -131,7 +132,7 @@ def evalcontextfunction(f):
return pass_eval_context(f)
-def environmentfunction(f):
+def environmentfunction(f: F) -> F:
"""Pass the environment as the first argument to the decorated
function.
@@ -148,13 +149,13 @@ def environmentfunction(f):
return pass_environment(f)
-def internalcode(f):
+def internalcode(f: F) -> F:
"""Marks the function as internally used"""
internal_code.add(f.__code__)
return f
-def is_undefined(obj):
+def is_undefined(obj: t.Any) -> bool:
"""Check if the object passed is undefined. This does nothing more than
performing an instance check against :class:`Undefined` but looks nicer.
This can be used for custom filters or tests that want to react to
@@ -171,26 +172,26 @@ def is_undefined(obj):
return isinstance(obj, Undefined)
-def consume(iterable):
+def consume(iterable: t.Iterable[t.Any]) -> None:
"""Consumes an iterable without doing anything with it."""
for _ in iterable:
pass
-def clear_caches():
+def clear_caches() -> None:
"""Jinja keeps internal caches for environments and lexers. These are
used so that Jinja doesn't have to recreate environments and lexers all
the time. Normally you don't have to care about that but if you are
measuring memory consumption you may want to clean the caches.
"""
- from .environment import _spontaneous_environments
+ from .environment import get_spontaneous_environment
from .lexer import _lexer_cache
- _spontaneous_environments.clear()
+ get_spontaneous_environment.cache_clear()
_lexer_cache.clear()
-def import_string(import_name, silent=False):
+def import_string(import_name: str, silent: bool = False) -> t.Any:
"""Imports an object based on a string. This is useful if you want to
use import paths as endpoints or something similar. An import path can
be specified either in dotted notation (``xml.sax.saxutils.escape``)
@@ -214,7 +215,7 @@ def import_string(import_name, silent=False):
raise
-def open_if_exists(filename, mode="rb"):
+def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO]:
"""Returns a file descriptor for the filename if that file exists,
otherwise ``None``.
"""
@@ -224,7 +225,7 @@ def open_if_exists(filename, mode="rb"):
return open(filename, mode)
-def object_type_repr(obj):
+def object_type_repr(obj: t.Any) -> str:
"""Returns the name of the object's type. For some recognized
singletons the name of the object is returned instead. (For
example for `None` and `Ellipsis`).
@@ -242,9 +243,9 @@ def object_type_repr(obj):
return f"{cls.__module__}.{cls.__name__} object"
-def pformat(obj):
+def pformat(obj: t.Any) -> str:
"""Format an object using :func:`pprint.pformat`."""
- from pprint import pformat
+ from pprint import pformat # type: ignore
return pformat(obj)
@@ -320,15 +321,15 @@ def urlize(
"""
if trim_url_limit is not None:
- def trim_url(x):
- if len(x) > trim_url_limit:
+ def trim_url(x: str) -> str:
+ if len(x) > trim_url_limit: # type: ignore
return f"{x[:trim_url_limit]}..."
return x
else:
- def trim_url(x):
+ def trim_url(x: str) -> str:
return x
words = re.split(r"(\s+)", str(markupsafe.escape(text)))
@@ -401,7 +402,9 @@ def urlize(
return "".join(words)
-def generate_lorem_ipsum(n=5, html=True, min=20, max=100):
+def generate_lorem_ipsum(
+ n: int = 5, html: bool = True, min: int = 20, max: int = 100
+) -> str:
"""Generate some lorem ipsum for the template."""
from .constants import LOREM_IPSUM_WORDS
@@ -438,12 +441,14 @@ def generate_lorem_ipsum(n=5, html=True, min=20, max=100):
p.append(word)
# ensure that the paragraph ends with a dot.
- p = " ".join(p)
- if p.endswith(","):
- p = p[:-1] + "."
- elif not p.endswith("."):
- p += "."
- result.append(p)
+ p_str = " ".join(p)
+
+ if p_str.endswith(","):
+ p_str = p_str[:-1] + "."
+ elif not p_str.endswith("."):
+ p_str += "."
+
+ result.append(p_str)
if not html:
return "\n\n".join(result)
@@ -475,7 +480,7 @@ def url_quote(obj: t.Any, charset: str = "utf-8", for_qs: bool = False) -> str:
return rv
-def unicode_urlencode(obj, charset="utf-8", for_qs=False):
+def unicode_urlencode(obj: t.Any, charset: str = "utf-8", for_qs: bool = False) -> str:
import warnings
warnings.warn(
@@ -495,13 +500,13 @@ class LRUCache:
# scale. But as long as it's only used as storage for templates this
# won't do any harm.
- def __init__(self, capacity):
+ def __init__(self, capacity: int) -> None:
self.capacity = capacity
- self._mapping = {}
- self._queue = deque()
+ self._mapping: t.Dict[t.Any, t.Any] = {}
+ self._queue: t.Deque[t.Any] = deque()
self._postinit()
- def _postinit(self):
+ def _postinit(self) -> None:
# alias all queue methods for faster lookup
self._popleft = self._queue.popleft
self._pop = self._queue.pop
@@ -509,35 +514,35 @@ class LRUCache:
self._wlock = Lock()
self._append = self._queue.append
- def __getstate__(self):
+ def __getstate__(self) -> t.Mapping[str, t.Any]:
return {
"capacity": self.capacity,
"_mapping": self._mapping,
"_queue": self._queue,
}
- def __setstate__(self, d):
+ def __setstate__(self, d: t.Mapping[str, t.Any]) -> None:
self.__dict__.update(d)
self._postinit()
- def __getnewargs__(self):
+ def __getnewargs__(self) -> t.Tuple:
return (self.capacity,)
- def copy(self):
+ def copy(self) -> "LRUCache":
"""Return a shallow copy of the instance."""
rv = self.__class__(self.capacity)
rv._mapping.update(self._mapping)
rv._queue.extend(self._queue)
return rv
- def get(self, key, default=None):
+ def get(self, key: t.Any, default: t.Any = None) -> t.Any:
"""Return an item from the cache dict or `default`"""
try:
return self[key]
except KeyError:
return default
- def setdefault(self, key, default=None):
+ def setdefault(self, key: t.Any, default: t.Any = None) -> t.Any:
"""Set `default` if the key is not in the cache otherwise
leave unchanged. Return the value of this key.
"""
@@ -547,35 +552,32 @@ class LRUCache:
self[key] = default
return default
- def clear(self):
+ def clear(self) -> None:
"""Clear the cache."""
- self._wlock.acquire()
- try:
+ with self._wlock:
self._mapping.clear()
self._queue.clear()
- finally:
- self._wlock.release()
- def __contains__(self, key):
+ def __contains__(self, key: t.Any) -> bool:
"""Check if a key exists in this cache."""
return key in self._mapping
- def __len__(self):
+ def __len__(self) -> int:
"""Return the current size of the cache."""
return len(self._mapping)
- def __repr__(self):
- return f"<{self.__class__.__name__} {self._mapping!r}>"
+ def __repr__(self) -> str:
+ return f"<{type(self).__name__} {self._mapping!r}>"
- def __getitem__(self, key):
+ def __getitem__(self, key: t.Any) -> t.Any:
"""Get an item from the cache. Moves the item up so that it has the
highest priority then.
Raise a `KeyError` if it does not exist.
"""
- self._wlock.acquire()
- try:
+ with self._wlock:
rv = self._mapping[key]
+
if self._queue[-1] != key:
try:
self._remove(key)
@@ -584,58 +586,54 @@ class LRUCache:
# when we read, ignore the ValueError that we would
# get otherwise.
pass
+
self._append(key)
+
return rv
- finally:
- self._wlock.release()
- def __setitem__(self, key, value):
+ def __setitem__(self, key: t.Any, value: t.Any) -> None:
"""Sets the value for an item. Moves the item up so that it
has the highest priority then.
"""
- self._wlock.acquire()
- try:
+ with self._wlock:
if key in self._mapping:
self._remove(key)
elif len(self._mapping) == self.capacity:
del self._mapping[self._popleft()]
+
self._append(key)
self._mapping[key] = value
- finally:
- self._wlock.release()
- def __delitem__(self, key):
+ def __delitem__(self, key: t.Any) -> None:
"""Remove an item from the cache dict.
Raise a `KeyError` if it does not exist.
"""
- self._wlock.acquire()
- try:
+ with self._wlock:
del self._mapping[key]
+
try:
self._remove(key)
except ValueError:
pass
- finally:
- self._wlock.release()
- def items(self):
+ def items(self) -> t.Iterable[t.Tuple[t.Any, t.Any]]:
"""Return a list of items."""
result = [(key, self._mapping[key]) for key in list(self._queue)]
result.reverse()
return result
- def values(self):
+ def values(self) -> t.Iterable[t.Any]:
"""Return a list of all values."""
return [x[1] for x in self.items()]
- def keys(self):
+ def keys(self) -> t.Iterable[t.Any]:
"""Return a list of all keys ordered by most recent usage."""
return list(self)
- def __iter__(self):
+ def __iter__(self) -> t.Iterator[t.Any]:
return reversed(tuple(self._queue))
- def __reversed__(self):
+ def __reversed__(self) -> t.Iterator[t.Any]:
"""Iterate over the keys in the cache dict, oldest items
coming first.
"""
@@ -645,11 +643,11 @@ class LRUCache:
def select_autoescape(
- enabled_extensions=("html", "htm", "xml"),
- disabled_extensions=(),
- default_for_string=True,
- default=False,
-):
+ enabled_extensions: t.Collection[str] = ("html", "htm", "xml"),
+ disabled_extensions: t.Collection[str] = (),
+ default_for_string: bool = True,
+ default: bool = False,
+) -> t.Callable[[t.Optional[str]], bool]:
"""Intelligently sets the initial value of autoescaping based on the
filename of the template. This is the recommended way to configure
autoescaping if you do not want to write a custom function yourself.
@@ -687,7 +685,7 @@ def select_autoescape(
enabled_patterns = tuple(f".{x.lstrip('.').lower()}" for x in enabled_extensions)
disabled_patterns = tuple(f".{x.lstrip('.').lower()}" for x in disabled_extensions)
- def autoescape(template_name):
+ def autoescape(template_name: t.Optional[str]) -> bool:
if template_name is None:
return default_for_string
template_name = template_name.lower()
@@ -766,24 +764,24 @@ class Cycler:
.. versionadded:: 2.1
"""
- def __init__(self, *items):
+ def __init__(self, *items: t.Any) -> None:
if not items:
raise RuntimeError("at least one item has to be provided")
self.items = items
self.pos = 0
- def reset(self):
+ def reset(self) -> None:
"""Resets the current item to the first item."""
self.pos = 0
@property
- def current(self):
+ def current(self) -> t.Any:
"""Return the current item. Equivalent to the item that will be
returned next time :meth:`next` is called.
"""
return self.items[self.pos]
- def next(self):
+ def next(self) -> t.Any:
"""Return the current item, then advance :attr:`current` to the
next item.
"""
@@ -797,11 +795,11 @@ class Cycler:
class Joiner:
"""A joining helper for templates."""
- def __init__(self, sep=", "):
+ def __init__(self, sep: str = ", ") -> None:
self.sep = sep
self.used = False
- def __call__(self):
+ def __call__(self) -> str:
if not self.used:
self.used = True
return ""
@@ -812,11 +810,11 @@ class Namespace:
"""A namespace object that can hold arbitrary attributes. It may be
initialized from a dictionary or with keyword arguments."""
- def __init__(*args, **kwargs): # noqa: B902
+ def __init__(*args: t.Any, **kwargs: t.Any) -> None: # noqa: B902
self, args = args[0], args[1:]
self.__attrs = dict(*args, **kwargs)
- def __getattribute__(self, name):
+ def __getattribute__(self, name: str) -> t.Any:
# __class__ is needed for the awaitable check in async mode
if name in {"_Namespace__attrs", "__class__"}:
return object.__getattribute__(self, name)
@@ -825,15 +823,15 @@ class Namespace:
except KeyError:
raise AttributeError(name)
- def __setitem__(self, name, value):
+ def __setitem__(self, name: str, value: t.Any) -> None:
self.__attrs[name] = value
- def __repr__(self):
+ def __repr__(self) -> str:
return f"<Namespace {self.__attrs!r}>"
class Markup(markupsafe.Markup):
- def __new__(cls, base, encoding=None, errors="strict"):
+ def __new__(cls, base, encoding=None, errors="strict"): # type: ignore
warnings.warn(
"'jinja2.Markup' is deprecated and will be removed in Jinja"
" 3.1. Import 'markupsafe.Markup' instead.",
@@ -843,7 +841,7 @@ class Markup(markupsafe.Markup):
return super().__new__(cls, base, encoding, errors)
-def escape(s):
+def escape(s: t.Any) -> str:
warnings.warn(
"'jinja2.escape' is deprecated and will be removed in Jinja"
" 3.1. Import 'markupsafe.escape' instead.",
diff --git a/src/jinja2/visitor.py b/src/jinja2/visitor.py
index 590fa9e..b150e57 100644
--- a/src/jinja2/visitor.py
+++ b/src/jinja2/visitor.py
@@ -1,8 +1,17 @@
"""API for traversing the AST nodes. Implemented by the compiler and
meta introspection.
"""
+import typing as t
+
from .nodes import Node
+if t.TYPE_CHECKING:
+ import typing_extensions as te
+
+ class VisitCallable(te.Protocol):
+ def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
+ ...
+
class NodeVisitor:
"""Walks the abstract syntax tree and call visitor functions for every
@@ -16,21 +25,23 @@ class NodeVisitor:
(return value `None`) the `generic_visit` visitor is used instead.
"""
- def get_visitor(self, node):
+ def get_visitor(self, node: Node) -> "t.Optional[VisitCallable]":
"""Return the visitor function for this node or `None` if no visitor
exists for this node. In that case the generic visit function is
used instead.
"""
- return getattr(self, f"visit_{node.__class__.__name__}", None)
+ return getattr(self, f"visit_{type(node).__name__}", None) # type: ignore
- def visit(self, node, *args, **kwargs):
+ def visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Visit a node."""
f = self.get_visitor(node)
+
if f is not None:
return f(node, *args, **kwargs)
+
return self.generic_visit(node, *args, **kwargs)
- def generic_visit(self, node, *args, **kwargs):
+ def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Called if no explicit visitor function exists for a node."""
for node in node.iter_child_nodes():
self.visit(node, *args, **kwargs)
@@ -47,7 +58,7 @@ class NodeTransformer(NodeVisitor):
replacement takes place.
"""
- def generic_visit(self, node, *args, **kwargs):
+ def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
for field, old_value in node.iter_fields():
if isinstance(old_value, list):
new_values = []
@@ -69,11 +80,13 @@ class NodeTransformer(NodeVisitor):
setattr(node, field, new_node)
return node
- def visit_list(self, node, *args, **kwargs):
+ def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]:
"""As transformers may return lists in some places this method
can be used to enforce a list as return value.
"""
rv = self.visit(node, *args, **kwargs)
+
if not isinstance(rv, list):
- rv = [rv]
+ return [rv]
+
return rv
diff --git a/tests/test_ext.py b/tests/test_ext.py
index 20b19d8..238d95e 100644
--- a/tests/test_ext.py
+++ b/tests/test_ext.py
@@ -193,7 +193,7 @@ class StreamFilterExtension(Extension):
pos = 0
end = len(token.value)
lineno = token.lineno
- while 1:
+ while True:
match = _gettext_re.search(token.value, pos)
if match is None:
break