summaryrefslogtreecommitdiff
path: root/Lib/typing.py
diff options
context:
space:
mode:
authorYurii Karabas <1998uriyyo@gmail.com>2020-11-17 04:23:19 +0200
committerGitHub <noreply@github.com>2020-11-16 18:23:19 -0800
commitf03d318ca42578e45405717aedd4ac26ea52aaed (patch)
tree97d64b427c5f171138a10391e088e77d241b9182 /Lib/typing.py
parentb0aba1fcdc3da952698d99aec2334faa79a8b68c (diff)
downloadcpython-git-f03d318ca42578e45405717aedd4ac26ea52aaed.tar.gz
bpo-42345: Fix three issues with typing.Literal parameters (GH-23294)
Literal equality no longer depends on the order of arguments. Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function. Add deduplication of `typing.Literal` arguments.
Diffstat (limited to 'Lib/typing.py')
-rw-r--r--Lib/typing.py99
1 files changed, 76 insertions, 23 deletions
diff --git a/Lib/typing.py b/Lib/typing.py
index 3fa97a4a15..d310b3dd58 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
f" actual {alen}, expected {elen}")
+def _deduplicate(params):
+ # Weed out strict duplicates, preserving the first of each occurrence.
+ all_params = set(params)
+ if len(all_params) < len(params):
+ new_params = []
+ for t in params:
+ if t in all_params:
+ new_params.append(t)
+ all_params.remove(t)
+ params = new_params
+ assert not all_params, all_params
+ return params
+
+
def _remove_dups_flatten(parameters):
"""An internal helper for Union creation and substitution: flatten Unions
among parameters, then remove duplicates.
@@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
params.extend(p[1:])
else:
params.append(p)
- # Weed out strict duplicates, preserving the first of each occurrence.
- all_params = set(params)
- if len(all_params) < len(params):
- new_params = []
- for t in params:
- if t in all_params:
- new_params.append(t)
- all_params.remove(t)
- params = new_params
- assert not all_params, all_params
+
+ return tuple(_deduplicate(params))
+
+
+def _flatten_literal_params(parameters):
+ """An internal helper for Literal creation: flatten Literals among parameters"""
+ params = []
+ for p in parameters:
+ if isinstance(p, _LiteralGenericAlias):
+ params.extend(p.__args__)
+ else:
+ params.append(p)
return tuple(params)
_cleanups = []
-def _tp_cache(func):
+def _tp_cache(func=None, /, *, typed=False):
"""Internal wrapper caching __getitem__ of generic types with a fallback to
original function for non-hashable arguments.
"""
- cached = functools.lru_cache()(func)
- _cleanups.append(cached.cache_clear)
+ def decorator(func):
+ cached = functools.lru_cache(typed=typed)(func)
+ _cleanups.append(cached.cache_clear)
- @functools.wraps(func)
- def inner(*args, **kwds):
- try:
- return cached(*args, **kwds)
- except TypeError:
- pass # All real errors (not unhashable args) are raised below.
- return func(*args, **kwds)
- return inner
+ @functools.wraps(func)
+ def inner(*args, **kwds):
+ try:
+ return cached(*args, **kwds)
+ except TypeError:
+ pass # All real errors (not unhashable args) are raised below.
+ return func(*args, **kwds)
+ return inner
+ if func is not None:
+ return decorator(func)
+
+ return decorator
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t.
@@ -319,6 +340,13 @@ class _SpecialForm(_Final, _root=True):
def __getitem__(self, parameters):
return self._getitem(self, parameters)
+
+class _LiteralSpecialForm(_SpecialForm, _root=True):
+ @_tp_cache(typed=True)
+ def __getitem__(self, parameters):
+ return self._getitem(self, parameters)
+
+
@_SpecialForm
def Any(self, parameters):
"""Special type indicating an unconstrained type.
@@ -436,7 +464,7 @@ def Optional(self, parameters):
arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(None)]
-@_SpecialForm
+@_LiteralSpecialForm
def Literal(self, parameters):
"""Special typing form to define literal types (a.k.a. value types).
@@ -460,7 +488,17 @@ def Literal(self, parameters):
"""
# There is no '_type_check' call because arguments to Literal[...] are
# values, not types.
- return _GenericAlias(self, parameters)
+ if not isinstance(parameters, tuple):
+ parameters = (parameters,)
+
+ parameters = _flatten_literal_params(parameters)
+
+ try:
+ parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
+ except TypeError: # unhashable parameters
+ pass
+
+ return _LiteralGenericAlias(self, parameters)
@_SpecialForm
@@ -930,6 +968,21 @@ class _UnionGenericAlias(_GenericAlias, _root=True):
return True
+def _value_and_type_iter(parameters):
+ return ((p, type(p)) for p in parameters)
+
+
+class _LiteralGenericAlias(_GenericAlias, _root=True):
+
+ def __eq__(self, other):
+ if not isinstance(other, _LiteralGenericAlias):
+ return NotImplemented
+
+ return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
+
+ def __hash__(self):
+ return hash(tuple(_value_and_type_iter(self.__args__)))
+
class Generic:
"""Abstract base class for generic types.