diff options
author | David Lord <davidism@gmail.com> | 2021-04-05 09:25:26 -0700 |
---|---|---|
committer | David Lord <davidism@gmail.com> | 2021-04-05 09:25:26 -0700 |
commit | f0a9f319b7c2a4ebb2f90fec92258719174cc988 (patch) | |
tree | 8f8bfc703edf71102b6a845b719190669f462d5f | |
parent | 43d422893065b3672b4c771edef02c50fd1c3866 (diff) | |
download | jinja2-f0a9f319b7c2a4ebb2f90fec92258719174cc988.tar.gz |
add type annotations to filters and tests
-rw-r--r-- | src/jinja2/asyncfilters.py | 151 | ||||
-rw-r--r-- | src/jinja2/environment.py | 2 | ||||
-rw-r--r-- | src/jinja2/filters.py | 483 | ||||
-rw-r--r-- | src/jinja2/runtime.py | 5 | ||||
-rw-r--r-- | src/jinja2/tests.py | 57 | ||||
-rw-r--r-- | src/jinja2/utils.py | 17 |
6 files changed, 509 insertions, 206 deletions
diff --git a/src/jinja2/asyncfilters.py b/src/jinja2/asyncfilters.py index 0aad12c..11b031a 100644 --- a/src/jinja2/asyncfilters.py +++ b/src/jinja2/asyncfilters.py @@ -1,31 +1,57 @@ +import typing +import typing as t from functools import wraps +from itertools import groupby from . import filters from .asyncsupport import auto_aiter from .asyncsupport import auto_await +if t.TYPE_CHECKING: + from .environment import Environment + from .nodes import EvalContext + from .runtime import Context + from .runtime import Undefined -async def auto_to_seq(value): + V = t.TypeVar("V") + + +async def auto_to_seq( + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", +) -> "t.List[V]": seq = [] + if hasattr(value, "__aiter__"): - async for item in value: + async for item in t.cast(t.AsyncIterable, value): seq.append(item) else: - for item in value: + for item in t.cast(t.Iterable, value): seq.append(item) + return seq -async def async_select_or_reject(args, kwargs, modfunc, lookup_attr): - seq, func = filters.prepare_select_or_reject(args, kwargs, modfunc, lookup_attr) - if seq: - async for item in auto_aiter(seq): +async def async_select_or_reject( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + args: t.Tuple, + kwargs: t.Dict[str, t.Any], + modfunc: t.Callable[[t.Any], t.Any], + lookup_attr: bool, +) -> "t.AsyncIterator[V]": + if value: + func = filters.prepare_select_or_reject( + context, args, kwargs, modfunc, lookup_attr + ) + + async for item in auto_aiter(value): if func(item): yield item def dualfilter(normal_filter, async_filter): wrap_evalctx = False + if getattr(normal_filter, "environmentfilter", False) is True: def is_async(args): @@ -43,17 +69,19 @@ def dualfilter(normal_filter, async_filter): @wraps(normal_filter) def wrapper(*args, **kwargs): b = is_async(args) + if wrap_evalctx: args = args[1:] + if b: return async_filter(*args, **kwargs) + return normal_filter(*args, **kwargs) if wrap_evalctx: wrapper.evalcontextfilter = True wrapper.asyncfiltervariant = True - return wrapper @@ -65,65 +93,123 @@ def asyncfiltervariant(original): @asyncfiltervariant(filters.do_first) -async def do_first(environment, seq): +async def do_first( + environment: "Environment", seq: "t.Union[t.AsyncIterable[V], t.Iterable[V]]" +) -> "t.Union[V, Undefined]": try: - return await auto_aiter(seq).__anext__() + return t.cast("V", await auto_aiter(seq).__anext__()) except StopAsyncIteration: return environment.undefined("No first item, sequence was empty.") @asyncfiltervariant(filters.do_groupby) -async def do_groupby(environment, value, attribute): +async def do_groupby( + environment: "Environment", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + attribute: t.Union[str, int], +) -> "t.List[t.Tuple[t.Any, t.List[V]]]": expr = filters.make_attrgetter(environment, attribute) return [ filters._GroupTuple(key, await auto_to_seq(values)) - for key, values in filters.groupby( - sorted(await auto_to_seq(value), key=expr), expr - ) + for key, values in groupby(sorted(await auto_to_seq(value), key=expr), expr) ] @asyncfiltervariant(filters.do_join) -async def do_join(eval_ctx, value, d="", attribute=None): +async def do_join( + eval_ctx: "EvalContext", + value: t.Union[t.AsyncIterable, t.Iterable], + d: str = "", + attribute: t.Optional[t.Union[str, int]] = None, +) -> str: return filters.do_join(eval_ctx, await auto_to_seq(value), d, attribute) @asyncfiltervariant(filters.do_list) -async def do_list(value): +async def do_list(value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]") -> "t.List[V]": return await auto_to_seq(value) @asyncfiltervariant(filters.do_reject) -async def do_reject(*args, **kwargs): - return async_select_or_reject(args, kwargs, lambda x: not x, False) +async def do_reject( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: not x, False) @asyncfiltervariant(filters.do_rejectattr) -async def do_rejectattr(*args, **kwargs): - return async_select_or_reject(args, kwargs, lambda x: not x, True) +async def do_rejectattr( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: not x, True) @asyncfiltervariant(filters.do_select) -async def do_select(*args, **kwargs): - return async_select_or_reject(args, kwargs, lambda x: x, False) +async def do_select( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: x, False) @asyncfiltervariant(filters.do_selectattr) -async def do_selectattr(*args, **kwargs): - return async_select_or_reject(args, kwargs, lambda x: x, True) +async def do_selectattr( + context: "Context", + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + *args: t.Any, + **kwargs: t.Any, +) -> "t.AsyncIterator[V]": + return async_select_or_reject(context, value, args, kwargs, lambda x: x, True) + + +@typing.overload +def do_map( + context: "Context", + value: t.Union[t.AsyncIterable, t.Iterable], + name: str, + *args: t.Any, + **kwargs: t.Any, +) -> t.Iterable: + ... + + +@typing.overload +def do_map( + context: "Context", + value: t.Union[t.AsyncIterable, t.Iterable], + *, + attribute: str = ..., + default: t.Optional[t.Any] = None, +) -> t.Iterable: + ... @asyncfiltervariant(filters.do_map) -async def do_map(*args, **kwargs): - seq, func = filters.prepare_map(args, kwargs) - if seq: - async for item in auto_aiter(seq): +async def do_map(context, value, *args, **kwargs): + if value: + func = filters.prepare_map(context, args, kwargs) + + async for item in auto_aiter(value): yield await auto_await(func(item)) @asyncfiltervariant(filters.do_sum) -async def do_sum(environment, iterable, attribute=None, start=0): +async def do_sum( + environment: "Environment", + iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + attribute: t.Optional[t.Union[str, int]] = None, + start: "V" = 0, # type: ignore +) -> "V": rv = start + if attribute is not None: func = filters.make_attrgetter(environment, attribute) else: @@ -133,11 +219,16 @@ async def do_sum(environment, iterable, attribute=None, start=0): async for item in auto_aiter(iterable): rv += func(item) + return rv @asyncfiltervariant(filters.do_slice) -async def do_slice(value, slices, fill_with=None): +async def do_slice( + value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", + slices: int, + fill_with: t.Optional[t.Any] = None, +) -> "t.Iterator[t.List[V]]": return filters.do_slice(await auto_to_seq(value), slices, fill_with) diff --git a/src/jinja2/environment.py b/src/jinja2/environment.py index ad7190f..6211340 100644 --- a/src/jinja2/environment.py +++ b/src/jinja2/environment.py @@ -319,7 +319,7 @@ class Environment: self.keep_trailing_newline = keep_trailing_newline # runtime information - self.undefined = undefined + self.undefined: t.Type[Undefined] = undefined self.optimized = optimized self.finalize = finalize self.autoescape = autoescape diff --git a/src/jinja2/filters.py b/src/jinja2/filters.py index 7c95dce..356248a 100644 --- a/src/jinja2/filters.py +++ b/src/jinja2/filters.py @@ -2,8 +2,9 @@ import math import random import re +import typing +import typing as t from collections import abc -from collections import namedtuple from itertools import chain from itertools import groupby @@ -18,53 +19,73 @@ from .utils import pformat from .utils import url_quote from .utils import urlize -_word_re = re.compile(r"\w+") -_word_beginning_split_re = re.compile(r"([-\s({\[<]+)") +if t.TYPE_CHECKING: + import typing_extensions as te + from .environment import Environment + from .nodes import EvalContext + from .runtime import Context + from .sandbox import SandboxedEnvironment # noqa: F401 + + K = t.TypeVar("K") + V = t.TypeVar("V") + F = t.TypeVar("F", bound=t.Callable[..., t.Any]) + + class HasHTML(te.Protocol): + def __html__(self) -> str: + pass -def contextfilter(f): +def contextfilter(f: "F") -> "F": """Decorator for marking context dependent filters. The current :class:`Context` will be passed as first argument. """ - f.contextfilter = True + f.contextfilter = True # type: ignore return f -def evalcontextfilter(f): +def evalcontextfilter(f: "F") -> "F": """Decorator for marking eval-context dependent filters. An eval context object is passed as first argument. For more information about the eval context, see :ref:`eval-context`. .. versionadded:: 2.4 """ - f.evalcontextfilter = True + f.evalcontextfilter = True # type: ignore return f -def environmentfilter(f): +def environmentfilter(f: "F") -> "F": """Decorator for marking environment dependent filters. The current :class:`Environment` is passed to the filter as first argument. """ - f.environmentfilter = True + f.environmentfilter = True # type: ignore return f -def ignore_case(value): +def ignore_case(value: "V") -> "V": """For use as a postprocessor for :func:`make_attrgetter`. Converts strings to lowercase and returns other types as-is.""" - return value.lower() if isinstance(value, str) else value + if isinstance(value, str): + return t.cast("V", value.lower()) + + return value -def make_attrgetter(environment, attribute, postprocess=None, default=None): +def make_attrgetter( + environment: "Environment", + attribute: t.Optional[t.Union[str, int]], + postprocess: t.Optional[t.Callable[[t.Any], t.Any]] = None, + default: t.Optional[t.Any] = None, +) -> t.Callable[[t.Any], t.Any]: """Returns a callable that looks up the given attribute from a passed object with the rules of the environment. Dots are allowed to access attributes of attributes. Integer parts in paths are looked up as integers. """ - attribute = _prepare_attribute_parts(attribute) + parts = _prepare_attribute_parts(attribute) - def attrgetter(item): - for part in attribute: + def attrgetter(item: t.Any) -> t.Any: + for part in parts: item = environment.getitem(item, part) if default is not None and isinstance(item, Undefined): @@ -78,7 +99,11 @@ def make_attrgetter(environment, attribute, postprocess=None, default=None): return attrgetter -def make_multi_attrgetter(environment, attribute, postprocess=None): +def make_multi_attrgetter( + environment: "Environment", + attribute: t.Optional[t.Union[str, int]], + postprocess: t.Optional[t.Callable[[t.Any], t.Any]] = None, +) -> t.Callable[[t.Any], t.List[t.Any]]: """Returns a callable that looks up the given comma separated attributes from a passed object with the rules of the environment. Dots are allowed to access attributes of each attribute. Integer @@ -89,17 +114,19 @@ def make_multi_attrgetter(environment, attribute, postprocess=None): Examples of attribute: "attr1,attr2", "attr1.inner1.0,attr2.inner2.0", etc. """ - attribute_parts = ( - attribute.split(",") if isinstance(attribute, str) else [attribute] - ) - attribute = [ - _prepare_attribute_parts(attribute_part) for attribute_part in attribute_parts - ] + if isinstance(attribute, str): + split: t.Sequence[t.Union[str, int, None]] = attribute.split(",") + else: + split = [attribute] + + parts = [_prepare_attribute_parts(item) for item in split] + + def attrgetter(item: t.Any) -> t.List[t.Any]: + items = [None] * len(parts) - def attrgetter(item): - items = [None] * len(attribute) - for i, attribute_part in enumerate(attribute): + for i, attribute_part in enumerate(parts): item_i = item + for part in attribute_part: item_i = environment.getitem(item_i, part) @@ -107,28 +134,35 @@ def make_multi_attrgetter(environment, attribute, postprocess=None): item_i = postprocess(item_i) items[i] = item_i + return items return attrgetter -def _prepare_attribute_parts(attr): +def _prepare_attribute_parts( + attr: t.Optional[t.Union[str, int]] +) -> t.List[t.Union[str, int]]: if attr is None: return [] - elif isinstance(attr, str): + + if isinstance(attr, str): return [int(x) if x.isdigit() else x for x in attr.split(".")] - else: - return [attr] + + return [attr] -def do_forceescape(value): +def do_forceescape(value: "t.Union[str, HasHTML]") -> Markup: """Enforce HTML escaping. This will probably double escape variables.""" if hasattr(value, "__html__"): - value = value.__html__() + value = t.cast("HasHTML", value).__html__() + return escape(str(value)) -def do_urlencode(value): +def do_urlencode( + value: t.Union[str, t.Mapping[str, t.Any], t.Iterable[t.Tuple[str, t.Any]]] +) -> str: """Quote data for use in a URL path or query using UTF-8. Basic wrapper around :func:`urllib.parse.quote` when given a @@ -148,9 +182,9 @@ def do_urlencode(value): return url_quote(value) if isinstance(value, dict): - items = value.items() + items: t.Iterable[t.Tuple[str, t.Any]] = value.items() else: - items = iter(value) + items = value # type: ignore return "&".join( f"{url_quote(k, for_qs=True)}={url_quote(v, for_qs=True)}" for k, v in items @@ -158,7 +192,9 @@ def do_urlencode(value): @evalcontextfilter -def do_replace(eval_ctx, s, old, new, count=None): +def do_replace( + eval_ctx: "EvalContext", s: str, old: str, new: str, count: t.Optional[int] = None +) -> str: """Return a copy of the value with all occurrences of a substring replaced with a new one. The first argument is the substring that should be replaced, the second is the replacement string. @@ -175,8 +211,10 @@ def do_replace(eval_ctx, s, old, new, count=None): """ if count is None: count = -1 + if not eval_ctx.autoescape: return str(s).replace(str(old), str(new), count) + if ( hasattr(old, "__html__") or hasattr(new, "__html__") @@ -185,21 +223,24 @@ def do_replace(eval_ctx, s, old, new, count=None): s = escape(s) else: s = soft_str(s) + return s.replace(soft_str(old), soft_str(new), count) -def do_upper(s): +def do_upper(s: str) -> str: """Convert a value to uppercase.""" return soft_str(s).upper() -def do_lower(s): +def do_lower(s: str) -> str: """Convert a value to lowercase.""" return soft_str(s).lower() @evalcontextfilter -def do_xmlattr(_eval_ctx, d, autospace=True): +def do_xmlattr( + eval_ctx: "EvalContext", d: t.Mapping[str, t.Any], autospace: bool = True +) -> str: """Create an SGML/XML attribute string based on the items in a dict. All values that are neither `none` nor `undefined` are automatically escaped: @@ -227,21 +268,27 @@ def do_xmlattr(_eval_ctx, d, autospace=True): for key, value in d.items() if value is not None and not isinstance(value, Undefined) ) + if autospace and rv: rv = " " + rv - if _eval_ctx.autoescape: + + if eval_ctx.autoescape: rv = Markup(rv) + return rv -def do_capitalize(s): +def do_capitalize(s: str) -> str: """Capitalize a value. The first character will be uppercase, all others lowercase. """ return soft_str(s).capitalize() -def do_title(s): +_word_beginning_split_re = re.compile(r"([-\s({\[<]+)") + + +def do_title(s: str) -> str: """Return a titlecased version of the value. I.e. words will start with uppercase letters, all remaining characters are lowercase. """ @@ -254,7 +301,12 @@ def do_title(s): ) -def do_dictsort(value, case_sensitive=False, by="key", reverse=False): +def do_dictsort( + value: "t.Mapping[K, V]", + case_sensitive: bool = False, + by: 'te.Literal["key", "value"]' = "key", + reverse: bool = False, +) -> "t.List[t.Tuple[K, V]]": """Sort a dict and yield (key, value) pairs. Because python dicts are unsorted you may want to use this function to order them by either key or value: @@ -292,7 +344,13 @@ def do_dictsort(value, case_sensitive=False, by="key", reverse=False): @environmentfilter -def do_sort(environment, value, reverse=False, case_sensitive=False, attribute=None): +def do_sort( + environment: "Environment", + value: "t.Iterable[V]", + reverse: bool = False, + case_sensitive: bool = False, + attribute: t.Optional[t.Union[str, int]] = None, +) -> "t.List[V]": """Sort an iterable using Python's :func:`sorted`. .. sourcecode:: jinja @@ -342,7 +400,12 @@ def do_sort(environment, value, reverse=False, case_sensitive=False, attribute=N @environmentfilter -def do_unique(environment, value, case_sensitive=False, attribute=None): +def do_unique( + environment: "Environment", + value: "t.Iterable[V]", + case_sensitive: bool = False, + attribute: t.Optional[t.Union[str, int]] = None, +) -> "t.Iterator[V]": """Returns a list of unique items from the given iterable. .. sourcecode:: jinja @@ -369,7 +432,13 @@ def do_unique(environment, value, case_sensitive=False, attribute=None): yield item -def _min_or_max(environment, value, func, case_sensitive, attribute): +def _min_or_max( + environment: "Environment", + value: "t.Iterable[V]", + func: "t.Callable[..., V]", + case_sensitive: bool, + attribute: t.Optional[t.Union[str, int]], +) -> "t.Union[V, Undefined]": it = iter(value) try: @@ -384,7 +453,12 @@ def _min_or_max(environment, value, func, case_sensitive, attribute): @environmentfilter -def do_min(environment, value, case_sensitive=False, attribute=None): +def do_min( + environment: "Environment", + value: "t.Iterable[V]", + case_sensitive: bool = False, + attribute: t.Optional[t.Union[str, int]] = None, +) -> "t.Union[V, Undefined]": """Return the smallest item from the sequence. .. sourcecode:: jinja @@ -399,7 +473,12 @@ def do_min(environment, value, case_sensitive=False, attribute=None): @environmentfilter -def do_max(environment, value, case_sensitive=False, attribute=None): +def do_max( + environment: "Environment", + value: "t.Iterable[V]", + case_sensitive: bool = False, + attribute: t.Optional[t.Union[str, int]] = None, +) -> "t.Union[V, Undefined]": """Return the largest item from the sequence. .. sourcecode:: jinja @@ -413,7 +492,11 @@ def do_max(environment, value, case_sensitive=False, attribute=None): return _min_or_max(environment, value, max, case_sensitive, attribute) -def do_default(value, default_value="", boolean=False): +def do_default( + value: "V", + default_value: "V" = "", # type: ignore + boolean: bool = False, +) -> "V": """If the value is undefined it will return the passed default value, otherwise the value of the variable: @@ -438,11 +521,17 @@ def do_default(value, default_value="", boolean=False): """ if isinstance(value, Undefined) or (boolean and not value): return default_value + return value @evalcontextfilter -def do_join(eval_ctx, value, d="", attribute=None): +def do_join( + eval_ctx: "EvalContext", + value: t.Iterable, + d: str = "", + attribute: t.Optional[t.Union[str, int]] = None, +) -> str: """Return a string which is the concatenation of the strings in the sequence. The separator between elements is an empty string per default, you can define it with the optional parameter: @@ -476,28 +565,33 @@ def do_join(eval_ctx, value, d="", attribute=None): if not hasattr(d, "__html__"): value = list(value) do_escape = False + for idx, item in enumerate(value): if hasattr(item, "__html__"): do_escape = True else: value[idx] = str(item) + if do_escape: d = escape(d) else: d = str(d) + return d.join(value) # no html involved, to normal joining return soft_str(d).join(map(soft_str, value)) -def do_center(value, width=80): +def do_center(value: str, width: int = 80) -> str: """Centers the value in a field of a given width.""" - return str(value).center(width) + return soft_str(value).center(width) @environmentfilter -def do_first(environment, seq): +def do_first( + environment: "Environment", seq: "t.Iterable[V]" +) -> "t.Union[V, Undefined]": """Return the first item of a sequence.""" try: return next(iter(seq)) @@ -506,9 +600,10 @@ def do_first(environment, seq): @environmentfilter -def do_last(environment, seq): - """ - Return the last item of a sequence. +def do_last( + environment: "Environment", seq: "t.Reversible[V]" +) -> "t.Union[V, Undefined]": + """Return the last item of a sequence. Note: Does not work with generators. You may want to explicitly convert it to a list: @@ -524,7 +619,7 @@ def do_last(environment, seq): @contextfilter -def do_random(context, seq): +def do_random(context: "Context", seq: "t.Sequence[V]") -> "t.Union[V, Undefined]": """Return a random item from the sequence.""" try: return random.choice(seq) @@ -532,7 +627,7 @@ def do_random(context, seq): return context.environment.undefined("No random item, sequence was empty.") -def do_filesizeformat(value, binary=False): +def do_filesizeformat(value: t.Union[str, float, int], binary: bool = False) -> str: """Format the value like a 'human-readable' file size (i.e. 13 kB, 4.1 MB, 102 Bytes, etc). Per default decimal prefixes are used (Mega, Giga, etc.), if the second parameter is set to `True` the binary @@ -550,6 +645,7 @@ def do_filesizeformat(value, binary=False): ("ZiB" if binary else "ZB"), ("YiB" if binary else "YB"), ] + if bytes == 1: return "1 Byte" elif bytes < base: @@ -557,14 +653,16 @@ def do_filesizeformat(value, binary=False): else: for i, prefix in enumerate(prefixes): unit = base ** (i + 2) + if bytes < unit: return f"{base * bytes / unit:.1f} {prefix}" + return f"{base * bytes / unit:.1f} {prefix}" -def do_pprint(value): +def do_pprint(value: t.Any) -> str: """Pretty print a variable. Useful for debugging.""" - return pformat(value) + return t.cast(str, pformat(value)) _uri_scheme_re = re.compile(r"^([\w.+-]{2,}:(/){0,2})$") @@ -572,14 +670,14 @@ _uri_scheme_re = re.compile(r"^([\w.+-]{2,}:(/){0,2})$") @evalcontextfilter def do_urlize( - eval_ctx, - value, - trim_url_limit=None, - nofollow=False, - target=None, - rel=None, - extra_schemes=None, -): + eval_ctx: "EvalContext", + value: str, + trim_url_limit: t.Optional[int] = None, + nofollow: bool = False, + target: t.Optional[str] = None, + rel: t.Optional[str] = None, + extra_schemes: t.Optional[t.Iterable[str]] = None, +) -> str: """Convert URLs in text into clickable links. This may not recognize links in some situations. Usually, a more @@ -650,7 +748,7 @@ def do_urlize( return rv -def do_indent(s, width=4, first=False, blank=False): +def do_indent(s: str, width: int = 4, first: bool = False, blank: bool = False) -> str: """Return a copy of the string with each line indented by 4 spaces. The first line and blank lines are not indented by default. @@ -690,7 +788,14 @@ def do_indent(s, width=4, first=False, blank=False): @environmentfilter -def do_truncate(env, s, length=255, killwords=False, end="...", leeway=None): +def do_truncate( + env: "Environment", + s: str, + length: int = 255, + killwords: bool = False, + end: str = "...", + leeway: t.Optional[int] = None, +) -> str: """Return a truncated copy of the string. The length is specified with the first parameter which defaults to ``255``. If the second parameter is ``true`` the filter will cut the text at length. Otherwise @@ -716,25 +821,29 @@ def do_truncate(env, s, length=255, killwords=False, end="...", leeway=None): """ if leeway is None: leeway = env.policies["truncate.leeway"] + assert length >= len(end), f"expected length >= {len(end)}, got {length}" assert leeway >= 0, f"expected leeway >= 0, got {leeway}" + if len(s) <= length + leeway: return s + if killwords: return s[: length - len(end)] + end + result = s[: length - len(end)].rsplit(" ", 1)[0] return result + end @environmentfilter def do_wordwrap( - environment, - s, - width=79, - break_long_words=True, - wrapstring=None, - break_on_hyphens=True, -): + environment: "Environment", + s: str, + width: int = 79, + break_long_words: bool = True, + wrapstring: t.Optional[str] = None, + break_on_hyphens: bool = True, +) -> str: """Wrap a string to the given width. Existing newlines are treated as paragraphs to be wrapped separately. @@ -756,10 +865,9 @@ def do_wordwrap( .. versionchanged:: 2.7 Added the ``wrapstring`` parameter. """ - import textwrap - if not wrapstring: + if wrapstring is None: wrapstring = environment.newline_sequence # textwrap.wrap doesn't consider existing newlines when wrapping. @@ -783,12 +891,15 @@ def do_wordwrap( ) -def do_wordcount(s): +_word_re = re.compile(r"\w+") + + +def do_wordcount(s: str) -> int: """Count the words in that string.""" return len(_word_re.findall(soft_str(s))) -def do_int(value, default=0, base=10): +def do_int(value: t.Any, default: int = 0, base: int = 10) -> int: """Convert the value into an integer. If the conversion doesn't work it will return ``0``. You can override this default using the first parameter. You @@ -800,6 +911,7 @@ def do_int(value, default=0, base=10): try: if isinstance(value, str): return int(value, base) + return int(value) except (TypeError, ValueError): # this quirk is necessary so that "42.23"|int gives 42. @@ -809,7 +921,7 @@ def do_int(value, default=0, base=10): return default -def do_float(value, default=0.0): +def do_float(value: t.Any, default: float = 0.0) -> float: """Convert the value into a floating point number. If the conversion doesn't work it will return ``0.0``. You can override this default using the first parameter. @@ -820,7 +932,7 @@ def do_float(value, default=0.0): return default -def do_format(value, *args, **kwargs): +def do_format(value: str, *args: t.Any, **kwargs: t.Any) -> str: """Apply the given values to a `printf-style`_ format string, like ``string % values``. @@ -844,22 +956,26 @@ def do_format(value, *args, **kwargs): raise FilterArgumentError( "can't handle positional and keyword arguments at the same time" ) + return soft_str(value) % (kwargs or args) -def do_trim(value, chars=None): +def do_trim(value: str, chars: t.Optional[str] = None) -> str: """Strip leading and trailing characters, by default whitespace.""" return soft_str(value).strip(chars) -def do_striptags(value): +def do_striptags(value: "t.Union[str, HasHTML]") -> str: """Strip SGML/XML tags and replace adjacent whitespace by one space.""" if hasattr(value, "__html__"): - value = value.__html__() + value = t.cast("HasHTML", value).__html__() + return Markup(str(value)).striptags() -def do_slice(value, slices, fill_with=None): +def do_slice( + value: "t.Collection[V]", slices: int, fill_with: "t.Optional[V]" = None +) -> "t.Iterator[t.List[V]]": """Slice an iterator and return a list of lists containing those items. Useful if you want to create a div containing three ul tags that represent columns: @@ -884,18 +1000,25 @@ def do_slice(value, slices, fill_with=None): items_per_slice = length // slices slices_with_extra = length % slices offset = 0 + for slice_number in range(slices): start = offset + slice_number * items_per_slice + if slice_number < slices_with_extra: offset += 1 + end = offset + (slice_number + 1) * items_per_slice tmp = seq[start:end] + if fill_with is not None and slice_number >= slices_with_extra: tmp.append(fill_with) + yield tmp -def do_batch(value, linecount, fill_with=None): +def do_batch( + value: "t.Iterable[V]", linecount: int, fill_with: "t.Optional[V]" = None +) -> "t.Iterator[t.List[V]]": """ A filter that batches items. It works pretty much like `slice` just the other way round. It returns a list of lists with the @@ -914,19 +1037,27 @@ def do_batch(value, linecount, fill_with=None): {%- endfor %} </table> """ - tmp = [] + tmp: "t.List[V]" = [] + for item in value: if len(tmp) == linecount: yield tmp tmp = [] + tmp.append(item) + if tmp: if fill_with is not None and len(tmp) < linecount: tmp += [fill_with] * (linecount - len(tmp)) + yield tmp -def do_round(value, precision=0, method="common"): +def do_round( + value: float, + precision: int = 0, + method: 'te.Literal["common", "ceil", "floor"]' = "common", +) -> float: """Round the number to a given precision. The first parameter specifies the precision (default is ``0``), the second the rounding method: @@ -954,24 +1085,31 @@ def do_round(value, precision=0, method="common"): """ if method not in {"common", "ceil", "floor"}: raise FilterArgumentError("method must be common, ceil or floor") + if method == "common": return round(value, precision) + func = getattr(math, method) - return func(value * (10 ** precision)) / (10 ** precision) + return t.cast(float, func(value * (10 ** precision)) / (10 ** precision)) + +class _GroupTuple(t.NamedTuple): + grouper: t.Any + list: t.List -# Use a regular tuple repr here. This is what we did in the past and we -# really want to hide this custom type as much as possible. In particular -# we do not want to accidentally expose an auto generated repr in case -# people start to print this out in comments or something similar for -# debugging. -_GroupTuple = namedtuple("_GroupTuple", ["grouper", "list"]) -_GroupTuple.__repr__ = tuple.__repr__ # type: ignore -_GroupTuple.__str__ = tuple.__str__ # type: ignore + # Use the regular tuple repr to hide this subclass if users print + # out the value during debugging. + def __repr__(self): + return tuple.__repr__(self) + + def __str__(self): + return tuple.__str__(self) @environmentfilter -def do_groupby(environment, value, attribute): +def do_groupby( + environment: "Environment", value: "t.Iterable[V]", attribute: t.Union[str, int] +) -> "t.List[t.Tuple[t.Any, t.List[V]]]": """Group a sequence of objects by an attribute using Python's :func:`itertools.groupby`. The attribute can use dot notation for nested access, like ``"address.city"``. Unlike Python's ``groupby``, @@ -1013,7 +1151,12 @@ def do_groupby(environment, value, attribute): @environmentfilter -def do_sum(environment, iterable, attribute=None, start=0): +def do_sum( + environment: "Environment", + iterable: "t.Iterable[V]", + attribute: t.Optional[t.Union[str, int]] = None, + start: "V" = 0, # type: ignore +) -> "V": """Returns the sum of a sequence of numbers plus the value of parameter 'start' (which defaults to 0). When the sequence is empty it returns start. @@ -1030,34 +1173,46 @@ def do_sum(environment, iterable, attribute=None, start=0): """ if attribute is not None: iterable = map(make_attrgetter(environment, attribute), iterable) + return sum(iterable, start) -def do_list(value): +def do_list(value: "t.Iterable[V]") -> "t.List[V]": """Convert the value into a list. If it was a string the returned list will be a list of characters. """ return list(value) -def do_mark_safe(value): +def do_mark_safe(value: str) -> Markup: """Mark the value as safe which means that in an environment with automatic escaping enabled this variable will not be escaped. """ return Markup(value) -def do_mark_unsafe(value): +def do_mark_unsafe(value: str) -> str: """Mark a value as unsafe. This is the reverse operation for :func:`safe`.""" return str(value) +@typing.overload +def do_reverse(value: str) -> str: + ... + + +@typing.overload +def do_reverse(value: "t.Iterable[V]") -> "t.Iterable[V]": + ... + + def do_reverse(value): """Reverse the object or return an iterator that iterates over it the other way round. """ if isinstance(value, str): return value[::-1] + try: return reversed(value) except TypeError: @@ -1070,7 +1225,9 @@ def do_reverse(value): @environmentfilter -def do_attr(environment, obj, name): +def do_attr( + environment: "Environment", obj: t.Any, name: str +) -> t.Union[Undefined, t.Any]: """Get an attribute of an object. ``foo|attr("bar")`` works like ``foo.bar`` just that always an attribute is returned and items are not looked up. @@ -1087,16 +1244,37 @@ def do_attr(environment, obj, name): except AttributeError: pass else: - if environment.sandboxed and not environment.is_safe_attribute( - obj, name, value - ): - return environment.unsafe_undefined(obj, name) + if environment.sandboxed: + environment = t.cast("SandboxedEnvironment", environment) + + if not environment.is_safe_attribute(obj, name, value): + return environment.unsafe_undefined(obj, name) + return value + return environment.undefined(obj=obj, name=name) +@typing.overload +def do_map( + context: "Context", value: t.Iterable, name: str, *args: t.Any, **kwargs: t.Any +) -> t.Iterable: + ... + + +@typing.overload +def do_map( + context: "Context", + value: t.Iterable, + *, + attribute: str = ..., + default: t.Optional[t.Any] = None, +) -> t.Iterable: + ... + + @contextfilter -def do_map(*args, **kwargs): +def do_map(context, value, *args, **kwargs): """Applies a filter on a sequence of objects or looks up an attribute. This is useful when dealing with lists of objects but you are really only interested in a certain value of it. @@ -1136,14 +1314,17 @@ def do_map(*args, **kwargs): .. versionadded:: 2.7 """ - seq, func = prepare_map(args, kwargs) - if seq: - for item in seq: + if value: + func = prepare_map(context, args, kwargs) + + for item in value: yield func(item) @contextfilter -def do_select(*args, **kwargs): +def do_select( + context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any +) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to each object, and only selecting the objects with the test succeeding. @@ -1168,11 +1349,13 @@ def do_select(*args, **kwargs): .. versionadded:: 2.7 """ - return select_or_reject(args, kwargs, lambda x: x, False) + return select_or_reject(context, value, args, kwargs, lambda x: x, False) @contextfilter -def do_reject(*args, **kwargs): +def do_reject( + context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any +) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to each object, and rejecting the objects with the test succeeding. @@ -1192,11 +1375,13 @@ def do_reject(*args, **kwargs): .. versionadded:: 2.7 """ - return select_or_reject(args, kwargs, lambda x: not x, False) + return select_or_reject(context, value, args, kwargs, lambda x: not x, False) @contextfilter -def do_selectattr(*args, **kwargs): +def do_selectattr( + context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any +) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to the specified attribute of each object, and only selecting the objects with the test succeeding. @@ -1220,11 +1405,13 @@ def do_selectattr(*args, **kwargs): .. versionadded:: 2.7 """ - return select_or_reject(args, kwargs, lambda x: x, True) + return select_or_reject(context, value, args, kwargs, lambda x: x, True) @contextfilter -def do_rejectattr(*args, **kwargs): +def do_rejectattr( + context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any +) -> "t.Iterator[V]": """Filters a sequence of objects by applying a test to the specified attribute of each object, and rejecting the objects with the test succeeding. @@ -1246,11 +1433,13 @@ def do_rejectattr(*args, **kwargs): .. versionadded:: 2.7 """ - return select_or_reject(args, kwargs, lambda x: not x, True) + return select_or_reject(context, value, args, kwargs, lambda x: not x, True) @evalcontextfilter -def do_tojson(eval_ctx, value, indent=None): +def do_tojson( + eval_ctx: "EvalContext", value: t.Any, indent: t.Optional[int] = None +) -> Markup: """Serialize an object to a string of JSON, and mark it safe to render in HTML. This filter is only for use in HTML documents. @@ -1276,22 +1465,23 @@ def do_tojson(eval_ctx, value, indent=None): return htmlsafe_json_dumps(value, dumps=dumps, **kwargs) -def prepare_map(args, kwargs): - context = args[0] - seq = args[1] - - if len(args) == 2 and "attribute" in kwargs: +def prepare_map( + context: "Context", args: t.Tuple, kwargs: t.Dict[str, t.Any] +) -> t.Callable[[t.Any], t.Any]: + if not args and "attribute" in kwargs: attribute = kwargs.pop("attribute") default = kwargs.pop("default", None) + if kwargs: raise FilterArgumentError( f"Unexpected keyword argument {next(iter(kwargs))!r}" ) + func = make_attrgetter(context.environment, attribute, default=default) else: try: - name = args[2] - args = args[3:] + name = args[0] + args = args[1:] except LookupError: raise FilterArgumentError("map requires a filter argument") @@ -1300,17 +1490,22 @@ def prepare_map(args, kwargs): name, item, args, kwargs, context=context ) - return seq, func + return func -def prepare_select_or_reject(args, kwargs, modfunc, lookup_attr): - context = args[0] - seq = args[1] +def prepare_select_or_reject( + context: "Context", + args: t.Tuple, + kwargs: t.Dict[str, t.Any], + modfunc: t.Callable[[t.Any], t.Any], + lookup_attr: bool, +) -> t.Callable[[t.Any], t.Any]: if lookup_attr: try: - attr = args[2] + attr = args[0] except LookupError: raise FilterArgumentError("Missing parameter for attribute name") + transfunc = make_attrgetter(context.environment, attr) off = 1 else: @@ -1320,8 +1515,8 @@ def prepare_select_or_reject(args, kwargs, modfunc, lookup_attr): return x try: - name = args[2 + off] - args = args[3 + off :] + name = args[off] + args = args[1 + off :] def func(item): return context.environment.call_test(name, item, args, kwargs) @@ -1329,13 +1524,21 @@ def prepare_select_or_reject(args, kwargs, modfunc, lookup_attr): except LookupError: func = bool - return seq, lambda item: modfunc(func(transfunc(item))) + return lambda item: modfunc(func(transfunc(item))) + +def select_or_reject( + context: "Context", + value: "t.Iterable[V]", + args: t.Tuple, + kwargs: t.Dict[str, t.Any], + modfunc: t.Callable[[t.Any], t.Any], + lookup_attr: bool, +) -> "t.Iterator[V]": + if value: + func = prepare_select_or_reject(context, args, kwargs, modfunc, lookup_attr) -def select_or_reject(args, kwargs, modfunc, lookup_attr): - seq, func = prepare_select_or_reject(args, kwargs, modfunc, lookup_attr) - if seq: - for item in seq: + for item in value: if func(item): yield item diff --git a/src/jinja2/runtime.py b/src/jinja2/runtime.py index 9468a37..557460f 100644 --- a/src/jinja2/runtime.py +++ b/src/jinja2/runtime.py @@ -20,6 +20,9 @@ from .utils import missing from .utils import Namespace # noqa: F401 from .utils import object_type_repr +if t.TYPE_CHECKING: + from .environment import Environment + # these variables are exported to the template runtime exported = [ "LoopContext", @@ -186,7 +189,7 @@ class Context(metaclass=ContextMeta): def __init__(self, environment, parent, name, blocks, globals=None): self.parent = parent self.vars = {} - self.environment = environment + self.environment: "Environment" = environment self.eval_ctx = EvalContext(self.environment, name) self.exported_vars = set() self.name = name diff --git a/src/jinja2/tests.py b/src/jinja2/tests.py index 62b8322..229f16a 100644 --- a/src/jinja2/tests.py +++ b/src/jinja2/tests.py @@ -1,33 +1,32 @@ """Built-in template tests used with the ``is`` operator.""" import operator -import re +import typing as t from collections import abc from numbers import Number from .runtime import Undefined from .utils import environmentfunction -number_re = re.compile(r"^-?\d+(\.\d+)?$") -regex_type = type(number_re) -test_callable = callable +if t.TYPE_CHECKING: + from .environment import Environment -def test_odd(value): +def test_odd(value: int) -> bool: """Return true if the variable is odd.""" return value % 2 == 1 -def test_even(value): +def test_even(value: int) -> bool: """Return true if the variable is even.""" return value % 2 == 0 -def test_divisibleby(value, num): +def test_divisibleby(value: int, num: int) -> bool: """Check if a variable is divisible by a number.""" return value % num == 0 -def test_defined(value): +def test_defined(value: t.Any) -> bool: """Return true if the variable is defined: .. sourcecode:: jinja @@ -44,13 +43,13 @@ def test_defined(value): return not isinstance(value, Undefined) -def test_undefined(value): +def test_undefined(value: t.Any) -> bool: """Like :func:`defined` but the other way round.""" return isinstance(value, Undefined) @environmentfunction -def test_filter(env, value): +def test_filter(env: "Environment", value: str) -> bool: """Check if a filter exists by name. Useful if a filter may be optionally available. @@ -68,7 +67,7 @@ def test_filter(env, value): @environmentfunction -def test_test(env, value): +def test_test(env: "Environment", value: str) -> bool: """Check if a test exists by name. Useful if a test may be optionally available. @@ -89,12 +88,12 @@ def test_test(env, value): return value in env.tests -def test_none(value): +def test_none(value: t.Any) -> bool: """Return true if the variable is none.""" return value is None -def test_boolean(value): +def test_boolean(value: t.Any) -> bool: """Return true if the object is a boolean value. .. versionadded:: 2.11 @@ -102,7 +101,7 @@ def test_boolean(value): return value is True or value is False -def test_false(value): +def test_false(value: t.Any) -> bool: """Return true if the object is False. .. versionadded:: 2.11 @@ -110,7 +109,7 @@ def test_false(value): return value is False -def test_true(value): +def test_true(value: t.Any) -> bool: """Return true if the object is True. .. versionadded:: 2.11 @@ -119,7 +118,7 @@ def test_true(value): # NOTE: The existing 'number' test matches booleans and floats -def test_integer(value): +def test_integer(value: t.Any) -> bool: """Return true if the object is an integer. .. versionadded:: 2.11 @@ -128,7 +127,7 @@ def test_integer(value): # NOTE: The existing 'number' test matches booleans and integers -def test_float(value): +def test_float(value: t.Any) -> bool: """Return true if the object is a float. .. versionadded:: 2.11 @@ -136,22 +135,22 @@ def test_float(value): return isinstance(value, float) -def test_lower(value): +def test_lower(value: str) -> bool: """Return true if the variable is lowercased.""" return str(value).islower() -def test_upper(value): +def test_upper(value: str) -> bool: """Return true if the variable is uppercased.""" return str(value).isupper() -def test_string(value): +def test_string(value: t.Any) -> bool: """Return true if the object is a string.""" return isinstance(value, str) -def test_mapping(value): +def test_mapping(value: t.Any) -> bool: """Return true if the object is a mapping (dict etc.). .. versionadded:: 2.6 @@ -159,12 +158,12 @@ def test_mapping(value): return isinstance(value, abc.Mapping) -def test_number(value): +def test_number(value: t.Any) -> bool: """Return true if the variable is a number.""" return isinstance(value, Number) -def test_sequence(value): +def test_sequence(value: t.Any) -> bool: """Return true if the variable is a sequence. Sequences are variables that are iterable. """ @@ -173,10 +172,11 @@ def test_sequence(value): value.__getitem__ except Exception: return False + return True -def test_sameas(value, other): +def test_sameas(value: t.Any, other: t.Any) -> bool: """Check if an object points to the same memory address than another object: @@ -189,21 +189,22 @@ def test_sameas(value, other): return value is other -def test_iterable(value): +def test_iterable(value: t.Any) -> bool: """Check if it's possible to iterate over an object.""" try: iter(value) except TypeError: return False + return True -def test_escaped(value): +def test_escaped(value: t.Any) -> bool: """Check if the value is escaped.""" return hasattr(value, "__html__") -def test_in(value, seq): +def test_in(value: t.Any, seq: t.Container) -> bool: """Check if value is in seq. .. versionadded:: 2.10 @@ -232,7 +233,7 @@ TESTS = { "number": test_number, "sequence": test_sequence, "iterable": test_iterable, - "callable": test_callable, + "callable": callable, "sameas": test_sameas, "escaped": test_escaped, "in": test_in, diff --git a/src/jinja2/utils.py b/src/jinja2/utils.py index 42770b1..842410a 100644 --- a/src/jinja2/utils.py +++ b/src/jinja2/utils.py @@ -195,7 +195,13 @@ _http_re = re.compile( _email_re = re.compile(r"^\S+@\w[\w.-]*\.\w+$") -def urlize(text, trim_url_limit=None, rel=None, target=None, extra_schemes=None): +def urlize( + text: str, + trim_url_limit: t.Optional[int] = None, + rel: t.Optional[str] = None, + target: t.Optional[str] = None, + extra_schemes: t.Optional[t.Iterable[str]] = None, +) -> str: """Convert URLs in text into clickable links. This may not recognize links in some situations. Usually, a more @@ -359,12 +365,9 @@ def generate_lorem_ipsum(n=5, html=True, min=20, max=100): return Markup("\n".join(f"<p>{escape(x)}</p>" for x in result)) -def url_quote(obj, charset="utf-8", for_qs=False): +def url_quote(obj: t.Any, charset: str = "utf-8", for_qs: bool = False) -> str: """Quote a string for use in a URL using the given charset. - This function is misnamed, it is a wrapper around - :func:`urllib.parse.quote`. - :param obj: String or bytes to quote. Other types are converted to string then encoded to bytes using the given charset. :param charset: Encode text to bytes using this charset. @@ -610,7 +613,9 @@ def select_autoescape( return autoescape -def htmlsafe_json_dumps(obj, dumps=None, **kwargs): +def htmlsafe_json_dumps( + obj: t.Any, dumps: t.Optional[t.Callable[..., str]] = None, **kwargs: t.Any +) -> Markup: """Serialize an object to a string of JSON with :func:`json.dumps`, then replace HTML-unsafe characters with Unicode escapes and mark the result safe with :class:`~markupsafe.Markup`. |