summaryrefslogtreecommitdiff
path: root/Lib/functools.py
diff options
context:
space:
mode:
authorYurii Karabas <1998uriyyo@gmail.com>2021-12-11 01:27:55 +0200
committerGitHub <noreply@github.com>2021-12-11 00:27:55 +0100
commit3cb357a2e6ac18ee98db5d450414e773744e3c76 (patch)
treecfe621315c5651e3b89199a2eb2411029706fbcb /Lib/functools.py
parent810c1769f1c24ed907bdf3cc1086db4e602a28ae (diff)
downloadcpython-git-3cb357a2e6ac18ee98db5d450414e773744e3c76.tar.gz
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
Diffstat (limited to 'Lib/functools.py')
-rw-r--r--Lib/functools.py35
1 files changed, 28 insertions, 7 deletions
diff --git a/Lib/functools.py b/Lib/functools.py
index 77ec852805..ccac6f8999 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -837,6 +837,14 @@ def singledispatch(func):
dispatch_cache[cls] = impl
return impl
+ def _is_union_type(cls):
+ from typing import get_origin, Union
+ return get_origin(cls) in {Union, types.UnionType}
+
+ def _is_valid_union_type(cls):
+ from typing import get_args
+ return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
+
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
@@ -845,7 +853,7 @@ def singledispatch(func):
"""
nonlocal cache_token
if func is None:
- if isinstance(cls, type):
+ if isinstance(cls, type) or _is_valid_union_type(cls):
return lambda f: register(cls, f)
ann = getattr(cls, '__annotations__', {})
if not ann:
@@ -859,12 +867,25 @@ def singledispatch(func):
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
- if not isinstance(cls, type):
- raise TypeError(
- f"Invalid annotation for {argname!r}. "
- f"{cls!r} is not a class."
- )
- registry[cls] = func
+ if not isinstance(cls, type) and not _is_valid_union_type(cls):
+ if _is_union_type(cls):
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} not all arguments are classes."
+ )
+ else:
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} is not a class."
+ )
+
+ if _is_union_type(cls):
+ from typing import get_args
+
+ for arg in get_args(cls):
+ registry[arg] = func
+ else:
+ registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()