summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
authorYurii Karabas <1998uriyyo@gmail.com>2021-12-11 01:27:55 +0200
committerGitHub <noreply@github.com>2021-12-11 00:27:55 +0100
commit3cb357a2e6ac18ee98db5d450414e773744e3c76 (patch)
treecfe621315c5651e3b89199a2eb2411029706fbcb /Lib
parent810c1769f1c24ed907bdf3cc1086db4e602a28ae (diff)
downloadcpython-git-3cb357a2e6ac18ee98db5d450414e773744e3c76.tar.gz
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/functools.py35
-rw-r--r--Lib/test/test_functools.py30
2 files changed, 58 insertions, 7 deletions
diff --git a/Lib/functools.py b/Lib/functools.py
index 77ec852805..ccac6f8999 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -837,6 +837,14 @@ def singledispatch(func):
dispatch_cache[cls] = impl
return impl
+ def _is_union_type(cls):
+ from typing import get_origin, Union
+ return get_origin(cls) in {Union, types.UnionType}
+
+ def _is_valid_union_type(cls):
+ from typing import get_args
+ return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
+
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
@@ -845,7 +853,7 @@ def singledispatch(func):
"""
nonlocal cache_token
if func is None:
- if isinstance(cls, type):
+ if isinstance(cls, type) or _is_valid_union_type(cls):
return lambda f: register(cls, f)
ann = getattr(cls, '__annotations__', {})
if not ann:
@@ -859,12 +867,25 @@ def singledispatch(func):
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
- if not isinstance(cls, type):
- raise TypeError(
- f"Invalid annotation for {argname!r}. "
- f"{cls!r} is not a class."
- )
- registry[cls] = func
+ if not isinstance(cls, type) and not _is_valid_union_type(cls):
+ if _is_union_type(cls):
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} not all arguments are classes."
+ )
+ else:
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} is not a class."
+ )
+
+ if _is_union_type(cls):
+ from typing import get_args
+
+ for arg in get_args(cls):
+ registry[arg] = func
+ else:
+ registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 08cf457cc1..755ac03879 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -2684,6 +2684,17 @@ class TestSingleDispatch(unittest.TestCase):
'typing.Iterable[str] is not a class.'
))
+ with self.assertRaises(TypeError) as exc:
+ @i.register
+ def _(arg: typing.Union[int, typing.Iterable[str]]):
+ return "Invalid Union"
+ self.assertTrue(str(exc.exception).startswith(
+ "Invalid annotation for 'arg'."
+ ))
+ self.assertTrue(str(exc.exception).endswith(
+ 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
+ ))
+
def test_invalid_positional_argument(self):
@functools.singledispatch
def f(*args):
@@ -2692,6 +2703,25 @@ class TestSingleDispatch(unittest.TestCase):
with self.assertRaisesRegex(TypeError, msg):
f()
+ def test_union(self):
+ @functools.singledispatch
+ def f(arg):
+ return "default"
+
+ @f.register
+ def _(arg: typing.Union[str, bytes]):
+ return "typing.Union"
+
+ @f.register
+ def _(arg: int | float):
+ return "types.UnionType"
+
+ self.assertEqual(f([]), "default")
+ self.assertEqual(f(""), "typing.Union")
+ self.assertEqual(f(b""), "typing.Union")
+ self.assertEqual(f(1), "types.UnionType")
+ self.assertEqual(f(1.0), "types.UnionType")
+
class CachedCostItem:
_cost = 1