diff options
| author | Michele Simionato <michele.simionato@gmail.com> | 2015-07-21 16:43:59 +0200 |
|---|---|---|
| committer | Michele Simionato <michele.simionato@gmail.com> | 2015-07-21 16:43:59 +0200 |
| commit | 10e035eaf93997a63a30687f6fe33b918714ba61 (patch) | |
| tree | 7bc04aceba74e15306e54b542414cdda99291de1 /src/tests/test.py | |
| parent | 5e4bc16af3f5d0e6eca3ac95ddb89ef4331c8fb8 (diff) | |
| download | python-decorator-git-10e035eaf93997a63a30687f6fe33b918714ba61.tar.gz | |
First version of dispatch_on
Diffstat (limited to 'src/tests/test.py')
| -rw-r--r-- | src/tests/test.py | 376 |
1 files changed, 376 insertions, 0 deletions
diff --git a/src/tests/test.py b/src/tests/test.py new file mode 100644 index 0000000..57f6a06 --- /dev/null +++ b/src/tests/test.py @@ -0,0 +1,376 @@ +from __future__ import absolute_import +import sys +import doctest +import unittest +import decimal +import inspect +import collections +from decorator import dispatch_on +try: + from . import documentation +except (SystemError, ValueError): + import documentation + + +class DocumentationTestCase(unittest.TestCase): + def test(self): + doctest.testmod(documentation) + + +class ExtraTestCase(unittest.TestCase): + def test_signature(self): + if hasattr(inspect, 'signature'): + sig = inspect.signature(documentation.f1) + self.assertEqual(str(sig), '(x)') + +# ################### test dispatch_on ############################# # +# adapted from test_functools in Python 3.5 +singledispatch = dispatch_on('obj') + + +class TestSingleDispatch(unittest.TestCase): + def test_simple_overloads(self): + @singledispatch + def g(obj): + return "base" + + @g.register(int) + def g_int(i): + return "integer" + + self.assertEqual(g("str"), "base") + self.assertEqual(g(1), "integer") + self.assertEqual(g([1, 2, 3]), "base") + + def test_mro(self): + @singledispatch + def g(obj): + return "base" + + class A(object): + pass + + class C(A): + pass + + class B(A): + pass + + class D(C, B): + pass + + @g.register(A) + def g_A(a): + return "A" + + @g.register(B) + def g_B(b): + return "B" + + self.assertEqual(g(A()), "A") + self.assertEqual(g(B()), "B") + self.assertEqual(g(C()), "A") + self.assertEqual(g(D()), "B") + + def test_register_decorator(self): + @singledispatch + def g(obj): + return "base" + + @g.register(int) + def g_int(i): + return "int %s" % (i,) + self.assertEqual(g(""), "base") + self.assertEqual(g(12), "int 12") + self.assertIs(g.typemap[int, ], g_int) + + def test_register_error(self): + @singledispatch + def g(obj): + return "base" + + with self.assertRaises(TypeError): + @g.register(int) + def g_int(): + return "int" + + def test_wrapping_attributes(self): + @singledispatch + def g(obj): + "Simple test" + return "Test" + self.assertEqual(g.__name__, "g") + if sys.flags.optimize < 2: + self.assertEqual(g.__doc__, "Simple test") + + def test_c_classes(self): + @singledispatch + def g(obj): + return "base" + + @g.register(decimal.DecimalException) + def _(obj): + return obj.args + subn = decimal.Subnormal("Exponent < Emin") + rnd = decimal.Rounded("Number got rounded") + self.assertEqual(g(subn), ("Exponent < Emin",)) + self.assertEqual(g(rnd), ("Number got rounded",)) + + @g.register(decimal.Subnormal) + def _g(obj): + return "Too small to care." + self.assertEqual(g(subn), "Too small to care.") + self.assertEqual(g(rnd), ("Number got rounded",)) + + def test_register_abc(self): + c = collections + d = {"a": "b"} + l = [1, 2, 3] + s = {object(), None} + f = frozenset(s) + t = (1, 2, 3) + + @singledispatch + def g(obj): + return "base" + + self.assertEqual(g(d), "base") + self.assertEqual(g(l), "base") + self.assertEqual(g(s), "base") + self.assertEqual(g(f), "base") + self.assertEqual(g(t), "base") + + g.register(c.Sized)(lambda obj: "sized") + self.assertEqual(g(d), "sized") + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + + g.register(c.MutableMapping)(lambda obj: "mutablemapping") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + + if hasattr(c, 'ChainMap'): + g.register(c.ChainMap)(lambda obj: "chainmap") + # irrelevant ABCs registered + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + + g.register(c.MutableSequence)(lambda obj: "mutablesequence") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + + g.register(c.MutableSet)(lambda obj: "mutableset") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + + g.register(c.Mapping)(lambda obj: "mapping") + self.assertEqual(g(d), "mutablemapping") # not specific enough + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + + g.register(c.Sequence)(lambda obj: "sequence") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sequence") + + g.register(c.Set)(lambda obj: "set") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + + g.register(dict)(lambda obj: "dict") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + + g.register(list)(lambda obj: "list") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + + g.register(set)(lambda obj: "concrete-set") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + + g.register(frozenset)(lambda obj: "frozen-set") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "frozen-set") + self.assertEqual(g(t), "sequence") + + g.register(tuple)(lambda obj: "tuple") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "frozen-set") + self.assertEqual(g(t), "tuple") + + self.assertEqual( + [abc.__name__ for abc in g.abcs[0]], + ['MutableMapping', 'MutableSequence', 'MutableSet', 'Mapping', + 'Sequence', 'Set', 'Sized']) + + def test_mro_conflicts(self): + c = collections + + @singledispatch + def g(obj): + return "base" + + class O(c.Sized): + def __len__(self): + return 0 + o = O() + self.assertEqual(g(o), "base") + g.register(c.Iterable)(lambda arg: "iterable") + g.register(c.Container)(lambda arg: "container") + g.register(c.Sized)(lambda arg: "sized") + g.register(c.Set)(lambda arg: "set") + self.assertEqual(g(o), "sized") + c.Iterable.register(O) + self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ + c.Container.register(O) + return + self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ + c.Set.register(O) + self.assertEqual(g(o), "set") + # because c.Set is a subclass of + # c.Sized and c.Container + + class P: + pass + p = P() + self.assertEqual(g(p), "base") + c.Iterable.register(P) + self.assertEqual(g(p), "iterable") + c.Container.register(P) + + with self.assertRaises(RuntimeError) as re_one: + g(p) + self.assertIn( + str(re_one.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Iterable'>"), + ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " + "or <class 'collections.abc.Container'>")), + ) + class Q(c.Sized): + def __len__(self): + return 0 + q = Q() + self.assertEqual(g(q), "sized") + c.Iterable.register(Q) + self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ + c.Set.register(Q) + self.assertEqual(g(q), "set") # because c.Set is a subclass of + # c.Sized and c.Iterable + @functools.singledispatch + def h(arg): + return "base" + @h.register(c.Sized) + def _(arg): + return "sized" + @h.register(c.Container) + def _(arg): + return "container" + # Even though Sized and Container are explicit bases of MutableMapping, + # this ABC is implicitly registered on defaultdict which makes all of + # MutableMapping's bases implicit as well from defaultdict's + # perspective. + with self.assertRaises(RuntimeError) as re_two: + h(c.defaultdict(lambda: 0)) + self.assertIn( + str(re_two.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Sized'>"), + ("Ambiguous dispatch: <class 'collections.abc.Sized'> " + "or <class 'collections.abc.Container'>")), + ) + class R(c.defaultdict): + pass + c.MutableSequence.register(R) + @functools.singledispatch + def i(arg): + return "base" + @i.register(c.MutableMapping) + def _(arg): + return "mapping" + @i.register(c.MutableSequence) + def _(arg): + return "sequence" + r = R() + self.assertEqual(i(r), "sequence") + class S: + pass + class T(S, c.Sized): + def __len__(self): + return 0 + t = T() + self.assertEqual(h(t), "sized") + c.Container.register(T) + self.assertEqual(h(t), "sized") # because it's explicitly in the MRO + class U: + def __len__(self): + return 0 + u = U() + self.assertEqual(h(u), "sized") # implicit Sized subclass inferred + # from the existence of __len__() + c.Container.register(U) + # There is no preference for registered versus inferred ABCs. + with self.assertRaises(RuntimeError) as re_three: + h(u) + self.assertIn( + str(re_three.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Sized'>"), + ("Ambiguous dispatch: <class 'collections.abc.Sized'> " + "or <class 'collections.abc.Container'>")), + ) + class V(c.Sized, S): + def __len__(self): + return 0 + @functools.singledispatch + def j(arg): + return "base" + @j.register(S) + def _(arg): + return "s" + @j.register(c.Container) + def _(arg): + return "container" + v = V() + self.assertEqual(j(v), "s") + c.Container.register(V) + self.assertEqual(j(v), "container") # because it ends up right after + # Sized in the MRO + +if __name__ == '__main__': + unittest.main() |
