diff options
author | Yurii Karabas <1998uriyyo@gmail.com> | 2021-12-11 01:27:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-11 00:27:55 +0100 |
commit | 3cb357a2e6ac18ee98db5d450414e773744e3c76 (patch) | |
tree | cfe621315c5651e3b89199a2eb2411029706fbcb /Lib/functools.py | |
parent | 810c1769f1c24ed907bdf3cc1086db4e602a28ae (diff) | |
download | cpython-git-3cb357a2e6ac18ee98db5d450414e773744e3c76.tar.gz |
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
Diffstat (limited to 'Lib/functools.py')
-rw-r--r-- | Lib/functools.py | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index 77ec852805..ccac6f8999 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -837,6 +837,14 @@ def singledispatch(func): dispatch_cache[cls] = impl return impl + def _is_union_type(cls): + from typing import get_origin, Union + return get_origin(cls) in {Union, types.UnionType} + + def _is_valid_union_type(cls): + from typing import get_args + return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls)) + def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -845,7 +853,7 @@ def singledispatch(func): """ nonlocal cache_token if func is None: - if isinstance(cls, type): + if isinstance(cls, type) or _is_valid_union_type(cls): return lambda f: register(cls, f) ann = getattr(cls, '__annotations__', {}) if not ann: @@ -859,12 +867,25 @@ def singledispatch(func): # only import typing if annotation parsing is necessary from typing import get_type_hints argname, cls = next(iter(get_type_hints(func).items())) - if not isinstance(cls, type): - raise TypeError( - f"Invalid annotation for {argname!r}. " - f"{cls!r} is not a class." - ) - registry[cls] = func + if not isinstance(cls, type) and not _is_valid_union_type(cls): + if _is_union_type(cls): + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} not all arguments are classes." + ) + else: + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} is not a class." + ) + + if _is_union_type(cls): + from typing import get_args + + for arg in get_args(cls): + registry[arg] = func + else: + registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() |