summaryrefslogtreecommitdiff
path: root/test_singledispatch.py
diff options
context:
space:
mode:
author?ukasz Langa <lukasz@langa.pl>2013-07-02 11:03:31 +0200
committer?ukasz Langa <lukasz@langa.pl>2013-07-02 11:03:31 +0200
commit86a2569645566c00e2ef36dff084315c75d573ca (patch)
tree4dc2068fd8c894e1d73575d00535d61295a5e2f6 /test_singledispatch.py
parent26495c234aa9ff61a5134a696329d22521cb2666 (diff)
parent2841142ce382e70b664a3e2071427a3e11cdeb44 (diff)
downloadsingledispatch-86a2569645566c00e2ef36dff084315c75d573ca.tar.gz
Merged with upstream and made compatible with Python 2.6 - 3.33.4.0.2
Diffstat (limited to 'test_singledispatch.py')
-rw-r--r--test_singledispatch.py239
1 files changed, 194 insertions, 45 deletions
diff --git a/test_singledispatch.py b/test_singledispatch.py
index 60872f8..779cb6d 100644
--- a/test_singledispatch.py
+++ b/test_singledispatch.py
@@ -10,13 +10,33 @@ import collections
import decimal
from itertools import permutations
import singledispatch as functools
-from singledispatch_helpers import ChainMap, OrderedDict
+from singledispatch_helpers import Support
+try:
+ from collections import ChainMap
+except ImportError:
+ from singledispatch_helpers import ChainMap
+ collections.ChainMap = ChainMap
+try:
+ from collections import OrderedDict
+except ImportError:
+ from singledispatch_helpers import OrderedDict
+ collections.OrderedDict = OrderedDict
try:
import unittest2 as unittest
except ImportError:
import unittest
+support = Support()
+for _prefix in ('collections.abc', '_abcoll'):
+ if _prefix in repr(collections.Container):
+ abcoll_prefix = _prefix
+ break
+else:
+ abcoll_prefix = '?'
+del _prefix
+
+
class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
@functools.singledispatch
@@ -33,29 +53,24 @@ class TestSingleDispatch(unittest.TestCase):
@functools.singledispatch
def g(obj):
return "base"
- class C(object):
+ class A(object):
pass
- class D(C):
+ class C(A):
pass
- def g_C(c):
- return "C"
- g.register(C, g_C)
- self.assertEqual(g(C()), "C")
- self.assertEqual(g(D()), "C")
-
- def test_classic_classes(self):
- @functools.singledispatch
- def g(obj):
- return "base"
- class C(object):
+ class B(A):
pass
- class D(C):
+ class D(C, B):
pass
- def g_C(c):
- return "C"
- g.register(C, g_C)
- self.assertEqual(g(C()), "C")
- self.assertEqual(g(D()), "C")
+ def g_A(a):
+ return "A"
+ def g_B(b):
+ return "B"
+ g.register(A, g_A)
+ g.register(B, g_B)
+ self.assertEqual(g(A()), "A")
+ self.assertEqual(g(B()), "B")
+ self.assertEqual(g(C()), "A")
+ self.assertEqual(g(D()), "B")
def test_register_decorator(self):
@functools.singledispatch
@@ -80,6 +95,7 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g.__doc__, "Simple test")
@unittest.skipUnless(decimal, 'requires _decimal')
+ @support.cpython_only
def test_c_classes(self):
@functools.singledispatch
def g(obj):
@@ -98,22 +114,55 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(rnd), ("Number got rounded",))
def test_compose_mro(self):
+ # None of the examples in this test depend on haystack ordering.
c = collections
mro = functools._compose_mro
bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
for haystack in permutations(bases):
m = mro(dict, haystack)
- self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object])
- bases = [c.Container, c.Mapping, c.MutableMapping, OrderedDict]
+ self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
+ c.Iterable, c.Container, object])
+ bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
for haystack in permutations(bases):
- m = mro(ChainMap, haystack)
- self.assertEqual(m, [ChainMap, c.MutableMapping, c.Mapping,
+ m = mro(c.ChainMap, haystack)
+ self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
+ c.Sized, c.Iterable, c.Container, object])
+
+ # If there's a generic function with implementations registered for
+ # both Sized and Container, passing a defaultdict to it results in an
+ # ambiguous dispatch which will cause a RuntimeError (see
+ # test_mro_conflicts).
+ bases = [c.Container, c.Sized, str]
+ for haystack in permutations(bases):
+ m = mro(c.defaultdict, [c.Sized, c.Container, str])
+ self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
+ object])
+
+ # MutableSequence below is registered directly on D. In other words, it
+ # preceeds MutableMapping which means single dispatch will always
+ # choose MutableSequence here.
+ class D(c.defaultdict):
+ pass
+ c.MutableSequence.register(D)
+ bases = [c.MutableSequence, c.MutableMapping]
+ for haystack in permutations(bases):
+ m = mro(D, bases)
+ self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
+ c.defaultdict, dict, c.MutableMapping,
+ c.Mapping, c.Sized, c.Iterable, c.Container,
+ object])
+
+ # Container and Callable are registered on different base classes and
+ # a generic function supporting both should always pick the Callable
+ # implementation if a C instance is passed.
+ class C(c.defaultdict):
+ def __call__(self):
+ pass
+ bases = [c.Sized, c.Callable, c.Container, c.Mapping]
+ for haystack in permutations(bases):
+ m = mro(C, haystack)
+ self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
c.Sized, c.Iterable, c.Container, object])
- # Note: The MRO order below depends on haystack ordering.
- m = mro(c.defaultdict, [c.Sized, c.Container, str])
- self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object])
- m = mro(c.defaultdict, [c.Container, c.Sized, str])
- self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object])
def test_register_abc(self):
c = collections
@@ -142,7 +191,7 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(s), "sized")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
- g.register(ChainMap, lambda obj: "chainmap")
+ g.register(c.ChainMap, lambda obj: "chainmap")
self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
self.assertEqual(g(l), "sized")
self.assertEqual(g(s), "sized")
@@ -209,17 +258,38 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(f), "frozen-set")
self.assertEqual(g(t), "tuple")
- def test_mro_conflicts(self):
+ def test_c3_abc(self):
c = collections
+ mro = functools._c3_mro
+ class A(object):
+ pass
+ class B(A):
+ def __len__(self):
+ return 0 # implies Sized
+ #@c.Container.register
+ class C(object):
+ pass
+ c.Container.register(C)
+ class D(object):
+ pass # unrelated
+ class X(D, C, B):
+ def __call__(self):
+ pass # implies Callable
+ expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
+ for abcs in permutations([c.Sized, c.Callable, c.Container]):
+ self.assertEqual(mro(X, abcs=abcs), expected)
+ # unrelated ABCs don't appear in the resulting MRO
+ many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
+ self.assertEqual(mro(X, abcs=many_abcs), expected)
+ def test_mro_conflicts(self):
+ c = collections
@functools.singledispatch
def g(arg):
return "base"
-
class O(c.Sized):
def __len__(self):
return 0
-
o = O()
self.assertEqual(g(o), "base")
g.register(c.Iterable, lambda arg: "iterable")
@@ -231,35 +301,114 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
c.Container.register(O)
self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
-
+ c.Set.register(O)
+ self.assertEqual(g(o), "set") # because c.Set is a subclass of
+ # c.Sized and c.Container
class P(object):
pass
-
p = P()
self.assertEqual(g(p), "base")
c.Iterable.register(P)
self.assertEqual(g(p), "iterable")
c.Container.register(P)
- with self.assertRaises(RuntimeError) as re:
+ with self.assertRaises(RuntimeError) as re_one:
g(p)
- self.assertEqual(
- str(re),
- ("Ambiguous dispatch: <class 'collections.abc.Container'> "
- "or <class 'collections.abc.Iterable'>"),
- )
-
+ self.assertIn(
+ str(re_one.exception),
+ (("Ambiguous dispatch: <class '{prefix}.Container'> "
+ "or <class '{prefix}.Iterable'>").format(prefix=abcoll_prefix),
+ ("Ambiguous dispatch: <class '{prefix}.Iterable'> "
+ "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)),
+ )
class Q(c.Sized):
def __len__(self):
return 0
-
q = Q()
self.assertEqual(g(q), "sized")
c.Iterable.register(Q)
self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
c.Set.register(Q)
self.assertEqual(g(q), "set") # because c.Set is a subclass of
- # c.Sized which is explicitly in
- # __mro__
+ # c.Sized and c.Iterable
+ @functools.singledispatch
+ def h(arg):
+ return "base"
+ @h.register(c.Sized)
+ def _(arg):
+ return "sized"
+ @h.register(c.Container)
+ def _(arg):
+ return "container"
+ # Even though Sized and Container are explicit bases of MutableMapping,
+ # this ABC is implicitly registered on defaultdict which makes all of
+ # MutableMapping's bases implicit as well from defaultdict's
+ # perspective.
+ with self.assertRaises(RuntimeError) as re_two:
+ h(c.defaultdict(lambda: 0))
+ self.assertIn(
+ str(re_two.exception),
+ (("Ambiguous dispatch: <class '{prefix}.Container'> "
+ "or <class '{prefix}.Sized'>").format(prefix=abcoll_prefix),
+ ("Ambiguous dispatch: <class '{prefix}.Sized'> "
+ "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)),
+ )
+ class R(c.defaultdict):
+ pass
+ c.MutableSequence.register(R)
+ @functools.singledispatch
+ def i(arg):
+ return "base"
+ @i.register(c.MutableMapping)
+ def _(arg):
+ return "mapping"
+ @i.register(c.MutableSequence)
+ def _(arg):
+ return "sequence"
+ r = R()
+ self.assertEqual(i(r), "sequence")
+ class S(object):
+ pass
+ class T(S, c.Sized):
+ def __len__(self):
+ return 0
+ t = T()
+ self.assertEqual(h(t), "sized")
+ c.Container.register(T)
+ self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
+ class U(object):
+ def __len__(self):
+ return 0
+ u = U()
+ self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
+ # from the existence of __len__()
+ c.Container.register(U)
+ # There is no preference for registered versus inferred ABCs.
+ with self.assertRaises(RuntimeError) as re_three:
+ h(u)
+ self.assertIn(
+ str(re_three.exception),
+ (("Ambiguous dispatch: <class '{prefix}.Container'> "
+ "or <class '{prefix}.Sized'>").format(prefix=abcoll_prefix),
+ ("Ambiguous dispatch: <class '{prefix}.Sized'> "
+ "or <class '{prefix}.Container'>").format(prefix=abcoll_prefix)),
+ )
+ class V(c.Sized, S):
+ def __len__(self):
+ return 0
+ @functools.singledispatch
+ def j(arg):
+ return "base"
+ @j.register(S)
+ def _(arg):
+ return "s"
+ @j.register(c.Container)
+ def _(arg):
+ return "container"
+ v = V()
+ self.assertEqual(j(v), "s")
+ c.Container.register(V)
+ self.assertEqual(j(v), "container") # because it ends up right after
+ # Sized in the MRO
def test_cache_invalidation(self):
try: