diff options
author | ?ukasz Langa <lukasz@langa.pl> | 2013-07-02 11:03:31 +0200 |
---|---|---|
committer | ?ukasz Langa <lukasz@langa.pl> | 2013-07-02 11:03:31 +0200 |
commit | 86a2569645566c00e2ef36dff084315c75d573ca (patch) | |
tree | 4dc2068fd8c894e1d73575d00535d61295a5e2f6 /test_singledispatch.py | |
parent | 26495c234aa9ff61a5134a696329d22521cb2666 (diff) | |
parent | 2841142ce382e70b664a3e2071427a3e11cdeb44 (diff) | |
download | singledispatch-86a2569645566c00e2ef36dff084315c75d573ca.tar.gz |
Merged with upstream and made compatible with Python 2.6 - 3.33.4.0.2
Diffstat (limited to 'test_singledispatch.py')
-rw-r--r-- | test_singledispatch.py | 239 |
1 files changed, 194 insertions, 45 deletions
diff --git a/test_singledispatch.py b/test_singledispatch.py index 60872f8..779cb6d 100644 --- a/test_singledispatch.py +++ b/test_singledispatch.py @@ -10,13 +10,33 @@ import collections import decimal from itertools import permutations import singledispatch as functools -from singledispatch_helpers import ChainMap, OrderedDict +from singledispatch_helpers import Support +try: + from collections import ChainMap +except ImportError: + from singledispatch_helpers import ChainMap + collections.ChainMap = ChainMap +try: + from collections import OrderedDict +except ImportError: + from singledispatch_helpers import OrderedDict + collections.OrderedDict = OrderedDict try: import unittest2 as unittest except ImportError: import unittest +support = Support() +for _prefix in ('collections.abc', '_abcoll'): + if _prefix in repr(collections.Container): + abcoll_prefix = _prefix + break +else: + abcoll_prefix = '?' +del _prefix + + class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @functools.singledispatch @@ -33,29 +53,24 @@ class TestSingleDispatch(unittest.TestCase): @functools.singledispatch def g(obj): return "base" - class C(object): + class A(object): pass - class D(C): + class C(A): pass - def g_C(c): - return "C" - g.register(C, g_C) - self.assertEqual(g(C()), "C") - self.assertEqual(g(D()), "C") - - def test_classic_classes(self): - @functools.singledispatch - def g(obj): - return "base" - class C(object): + class B(A): pass - class D(C): + class D(C, B): pass - def g_C(c): - return "C" - g.register(C, g_C) - self.assertEqual(g(C()), "C") - self.assertEqual(g(D()), "C") + def g_A(a): + return "A" + def g_B(b): + return "B" + g.register(A, g_A) + g.register(B, g_B) + self.assertEqual(g(A()), "A") + self.assertEqual(g(B()), "B") + self.assertEqual(g(C()), "A") + self.assertEqual(g(D()), "B") def test_register_decorator(self): @functools.singledispatch @@ -80,6 +95,7 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g.__doc__, "Simple test") @unittest.skipUnless(decimal, 'requires _decimal') + @support.cpython_only def test_c_classes(self): @functools.singledispatch def g(obj): @@ -98,22 +114,55 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(rnd), ("Number got rounded",)) def test_compose_mro(self): + # None of the examples in this test depend on haystack ordering. c = collections mro = functools._compose_mro bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] for haystack in permutations(bases): m = mro(dict, haystack) - self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object]) - bases = [c.Container, c.Mapping, c.MutableMapping, OrderedDict] + self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized, + c.Iterable, c.Container, object]) + bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] for haystack in permutations(bases): - m = mro(ChainMap, haystack) - self.assertEqual(m, [ChainMap, c.MutableMapping, c.Mapping, + m = mro(c.ChainMap, haystack) + self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) + + # If there's a generic function with implementations registered for + # both Sized and Container, passing a defaultdict to it results in an + # ambiguous dispatch which will cause a RuntimeError (see + # test_mro_conflicts). + bases = [c.Container, c.Sized, str] + for haystack in permutations(bases): + m = mro(c.defaultdict, [c.Sized, c.Container, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, + object]) + + # MutableSequence below is registered directly on D. In other words, it + # preceeds MutableMapping which means single dispatch will always + # choose MutableSequence here. + class D(c.defaultdict): + pass + c.MutableSequence.register(D) + bases = [c.MutableSequence, c.MutableMapping] + for haystack in permutations(bases): + m = mro(D, bases) + self.assertEqual(m, [D, c.MutableSequence, c.Sequence, + c.defaultdict, dict, c.MutableMapping, + c.Mapping, c.Sized, c.Iterable, c.Container, + object]) + + # Container and Callable are registered on different base classes and + # a generic function supporting both should always pick the Callable + # implementation if a C instance is passed. + class C(c.defaultdict): + def __call__(self): + pass + bases = [c.Sized, c.Callable, c.Container, c.Mapping] + for haystack in permutations(bases): + m = mro(C, haystack) + self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, c.Sized, c.Iterable, c.Container, object]) - # Note: The MRO order below depends on haystack ordering. - m = mro(c.defaultdict, [c.Sized, c.Container, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object]) - m = mro(c.defaultdict, [c.Container, c.Sized, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object]) def test_register_abc(self): c = collections @@ -142,7 +191,7 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") - g.register(ChainMap, lambda obj: "chainmap") + g.register(c.ChainMap, lambda obj: "chainmap") self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") @@ -209,17 +258,38 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") - def test_mro_conflicts(self): + def test_c3_abc(self): c = collections + mro = functools._c3_mro + class A(object): + pass + class B(A): + def __len__(self): + return 0 # implies Sized + #@c.Container.register + class C(object): + pass + c.Container.register(C) + class D(object): + pass # unrelated + class X(D, C, B): + def __call__(self): + pass # implies Callable + expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] + for abcs in permutations([c.Sized, c.Callable, c.Container]): + self.assertEqual(mro(X, abcs=abcs), expected) + # unrelated ABCs don't appear in the resulting MRO + many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] + self.assertEqual(mro(X, abcs=many_abcs), expected) + def test_mro_conflicts(self): + c = collections @functools.singledispatch def g(arg): return "base" - class O(c.Sized): def __len__(self): return 0 - o = O() self.assertEqual(g(o), "base") g.register(c.Iterable, lambda arg: "iterable") @@ -231,35 +301,114 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ c.Container.register(O) self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ - + c.Set.register(O) + self.assertEqual(g(o), "set") # because c.Set is a subclass of + # c.Sized and c.Container class P(object): pass - p = P() self.assertEqual(g(p), "base") c.Iterable.register(P) self.assertEqual(g(p), "iterable") c.Container.register(P) - with self.assertRaises(RuntimeError) as re: + with self.assertRaises(RuntimeError) as re_one: g(p) - self.assertEqual( - str(re), - ("Ambiguous dispatch: <class 'collections.abc.Container'> " - "or <class 'collections.abc.Iterable'>"), - ) - + self.assertIn( + str(re_one.exception), + (("Ambiguous dispatch: <class '{prefix}.Container'> " + "or <class '{prefix}.Iterable'>").format(prefix=abcoll_prefix), + ("Ambiguous dispatch: <class '{prefix}.Iterable'> " + "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)), + ) class Q(c.Sized): def __len__(self): return 0 - q = Q() self.assertEqual(g(q), "sized") c.Iterable.register(Q) self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ c.Set.register(Q) self.assertEqual(g(q), "set") # because c.Set is a subclass of - # c.Sized which is explicitly in - # __mro__ + # c.Sized and c.Iterable + @functools.singledispatch + def h(arg): + return "base" + @h.register(c.Sized) + def _(arg): + return "sized" + @h.register(c.Container) + def _(arg): + return "container" + # Even though Sized and Container are explicit bases of MutableMapping, + # this ABC is implicitly registered on defaultdict which makes all of + # MutableMapping's bases implicit as well from defaultdict's + # perspective. + with self.assertRaises(RuntimeError) as re_two: + h(c.defaultdict(lambda: 0)) + self.assertIn( + str(re_two.exception), + (("Ambiguous dispatch: <class '{prefix}.Container'> " + "or <class '{prefix}.Sized'>").format(prefix=abcoll_prefix), + ("Ambiguous dispatch: <class '{prefix}.Sized'> " + "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)), + ) + class R(c.defaultdict): + pass + c.MutableSequence.register(R) + @functools.singledispatch + def i(arg): + return "base" + @i.register(c.MutableMapping) + def _(arg): + return "mapping" + @i.register(c.MutableSequence) + def _(arg): + return "sequence" + r = R() + self.assertEqual(i(r), "sequence") + class S(object): + pass + class T(S, c.Sized): + def __len__(self): + return 0 + t = T() + self.assertEqual(h(t), "sized") + c.Container.register(T) + self.assertEqual(h(t), "sized") # because it's explicitly in the MRO + class U(object): + def __len__(self): + return 0 + u = U() + self.assertEqual(h(u), "sized") # implicit Sized subclass inferred + # from the existence of __len__() + c.Container.register(U) + # There is no preference for registered versus inferred ABCs. + with self.assertRaises(RuntimeError) as re_three: + h(u) + self.assertIn( + str(re_three.exception), + (("Ambiguous dispatch: <class '{prefix}.Container'> " + "or <class '{prefix}.Sized'>").format(prefix=abcoll_prefix), + ("Ambiguous dispatch: <class '{prefix}.Sized'> " + "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)), + ) + class V(c.Sized, S): + def __len__(self): + return 0 + @functools.singledispatch + def j(arg): + return "base" + @j.register(S) + def _(arg): + return "s" + @j.register(c.Container) + def _(arg): + return "container" + v = V() + self.assertEqual(j(v), "s") + c.Container.register(V) + self.assertEqual(j(v), "container") # because it ends up right after + # Sized in the MRO def test_cache_invalidation(self): try: |