diff options
author | Yurii Karabas <1998uriyyo@gmail.com> | 2021-12-11 01:27:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-11 00:27:55 +0100 |
commit | 3cb357a2e6ac18ee98db5d450414e773744e3c76 (patch) | |
tree | cfe621315c5651e3b89199a2eb2411029706fbcb /Lib | |
parent | 810c1769f1c24ed907bdf3cc1086db4e602a28ae (diff) | |
download | cpython-git-3cb357a2e6ac18ee98db5d450414e773744e3c76.tar.gz |
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/functools.py | 35 | ||||
-rw-r--r-- | Lib/test/test_functools.py | 30 |
2 files changed, 58 insertions, 7 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index 77ec852805..ccac6f8999 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -837,6 +837,14 @@ def singledispatch(func): dispatch_cache[cls] = impl return impl + def _is_union_type(cls): + from typing import get_origin, Union + return get_origin(cls) in {Union, types.UnionType} + + def _is_valid_union_type(cls): + from typing import get_args + return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls)) + def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -845,7 +853,7 @@ def singledispatch(func): """ nonlocal cache_token if func is None: - if isinstance(cls, type): + if isinstance(cls, type) or _is_valid_union_type(cls): return lambda f: register(cls, f) ann = getattr(cls, '__annotations__', {}) if not ann: @@ -859,12 +867,25 @@ def singledispatch(func): # only import typing if annotation parsing is necessary from typing import get_type_hints argname, cls = next(iter(get_type_hints(func).items())) - if not isinstance(cls, type): - raise TypeError( - f"Invalid annotation for {argname!r}. " - f"{cls!r} is not a class." - ) - registry[cls] = func + if not isinstance(cls, type) and not _is_valid_union_type(cls): + if _is_union_type(cls): + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} not all arguments are classes." + ) + else: + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} is not a class." + ) + + if _is_union_type(cls): + from typing import get_args + + for arg in get_args(cls): + registry[arg] = func + else: + registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 08cf457cc1..755ac03879 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2684,6 +2684,17 @@ class TestSingleDispatch(unittest.TestCase): 'typing.Iterable[str] is not a class.' )) + with self.assertRaises(TypeError) as exc: + @i.register + def _(arg: typing.Union[int, typing.Iterable[str]]): + return "Invalid Union" + self.assertTrue(str(exc.exception).startswith( + "Invalid annotation for 'arg'." + )) + self.assertTrue(str(exc.exception).endswith( + 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' + )) + def test_invalid_positional_argument(self): @functools.singledispatch def f(*args): @@ -2692,6 +2703,25 @@ class TestSingleDispatch(unittest.TestCase): with self.assertRaisesRegex(TypeError, msg): f() + def test_union(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | float): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "typing.Union") + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + self.assertEqual(f(1.0), "types.UnionType") + class CachedCostItem: _cost = 1 |