From 2841142ce382e70b664a3e2071427a3e11cdeb44 Mon Sep 17 00:00:00 2001 From: ?ukasz Langa Date: Tue, 2 Jul 2013 10:28:51 +0200 Subject: update to the ref implementation as of 02-July-2013: * more predictable dispatch thanks to C3-based linearization for ABC support * improved tests and docstrings --- README.rst | 37 +++++++- setup.py | 2 +- singledispatch.py | 178 +++++++++++++++++++++++++++--------- singledispatch_helpers.py | 12 +++ test_singledispatch.py | 224 +++++++++++++++++++++++++++++++++++++--------- tox.ini | 2 +- 6 files changed, 365 insertions(+), 90 deletions(-) diff --git a/README.rst b/README.rst index e1e345c..39e35c7 100644 --- a/README.rst +++ b/README.rst @@ -13,7 +13,7 @@ To define a generic function, decorate it with the ``@singledispatch`` decorator. Note that the dispatch happens on the type of the first argument, create your function accordingly:: - >>> from functools import singledispatch + >>> from singledispatch import singledispatch >>> @singledispatch ... def fun(arg, verbose=False): ... if verbose: @@ -145,6 +145,41 @@ members of the core CPython team: Change Log ---------- +3.4.0.2 +~~~~~~~ + +Updated to the reference implementation as of 02-July-2013. + +* more predictable dispatch order when abstract base classes are in use: + abstract base classes are now inserted into the MRO of the argument's + class where their functionality is introduced, i.e. issubclass(cls, + abc) returns True for the class itself but returns False for all its + direct base classes. Implicit ABCs for a given class (either + registered or inferred from the presence of a special method like + __len__) are inserted directly after the last ABC explicitly listed in + the MRO of said class. This also means there are less "ambiguous + dispatch" exceptions raised. + +* better test coverage and improved docstrings + +3.4.0.1 +~~~~~~~ + +Updated to the reference implementation as of 31-May-2013. + +* better performance + +* fixed a corner case with PEP 435 enums + +* calls to `dispatch()` also cached + +* dispatching algorithm now now a module-level routine called `_find_impl()` + with a simplified implementation and proper documentation + +* `dispatch()` now handles all caching-related activities + +* terminology more consistent: "overload" -> "implementation" + 3.4.0.0 ~~~~~~~ diff --git a/setup.py b/setup.py index 4ca442a..338905e 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ with codecs.open( setup ( name = 'singledispatch', - version = '3.4.0.0', + version = '3.4.0.2', author = 'Ɓukasz Langa', author_email = 'lukasz@langa.pl', description = __doc__, diff --git a/singledispatch.py b/singledispatch.py index c68ad02..b428920 100644 --- a/singledispatch.py +++ b/singledispatch.py @@ -16,46 +16,138 @@ from singledispatch_helpers import MappingProxyType, get_cache_token ### singledispatch() - single-dispatch generic function decorator ################################################################################ -def _compose_mro(cls, haystack): - """Calculates the MRO for a given class `cls`, including relevant abstract - base classes from `haystack`. +def _c3_merge(sequences): + """Merges MROs in *sequences* to a single MRO using the C3 algorithm. + + Adapted from http://www.python.org/download/releases/2.3/mro/. """ - bases = set(cls.__mro__) - mro = list(cls.__mro__) - for needle in haystack: - if (needle in bases or not hasattr(needle, '__mro__') - or not issubclass(cls, needle)): - continue # either present in the __mro__ already or unrelated - for index, base in enumerate(mro): - if not issubclass(base, needle): + result = [] + while True: + sequences = [s for s in sequences if s] # purge empty sequences + if not sequences: + return result + for s1 in sequences: # find merge candidates among seq heads + candidate = s1[0] + for s2 in sequences: + if candidate in s2[1:]: + candidate = None + break # reject the current head, it appears later + else: break - if base in bases and not issubclass(needle, base): - # Conflict resolution: put classes present in __mro__ and their - # subclasses first. See test_mro_conflicts() in test_functools.py - # for examples. - index += 1 - mro.insert(index, needle) - return mro + if not candidate: + raise RuntimeError("Inconsistent hierarchy") + result.append(candidate) + # remove the chosen candidate + for seq in sequences: + if seq[0] == candidate: + del seq[0] + +def _c3_mro(cls, abcs=None): + """Computes the method resolution order using extended C3 linearization. + + If no *abcs* are given, the algorithm works exactly like the built-in C3 + linearization used for method resolution. + + If given, *abcs* is a list of abstract base classes that should be inserted + into the resulting MRO. Unrelated ABCs are ignored and don't end up in the + result. The algorithm inserts ABCs where their functionality is introduced, + i.e. issubclass(cls, abc) returns True for the class itself but returns + False for all its direct base classes. Implicit ABCs for a given class + (either registered or inferred from the presence of a special method like + __len__) are inserted directly after the last ABC explicitly listed in the + MRO of said class. If two implicit ABCs end up next to each other in the + resulting MRO, their ordering depends on the order of types in *abcs*. + + """ + for i, base in enumerate(reversed(cls.__bases__)): + if hasattr(base, '__abstractmethods__'): + boundary = len(cls.__bases__) - i + break # Bases up to the last explicit ABC are considered first. + else: + boundary = 0 + abcs = list(abcs) if abcs else [] + explicit_bases = list(cls.__bases__[:boundary]) + abstract_bases = [] + other_bases = list(cls.__bases__[boundary:]) + for base in abcs: + if issubclass(cls, base) and not any( + issubclass(b, base) for b in cls.__bases__ + ): + # If *cls* is the class that introduces behaviour described by + # an ABC *base*, insert said ABC to its MRO. + abstract_bases.append(base) + for base in abstract_bases: + abcs.remove(base) + explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases] + abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases] + other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases] + return _c3_merge( + [[cls]] + + explicit_c3_mros + abstract_c3_mros + other_c3_mros + + [explicit_bases] + [abstract_bases] + [other_bases] + ) + +def _compose_mro(cls, types): + """Calculates the method resolution order for a given class *cls*. + + Includes relevant abstract base classes (with their respective bases) from + the *types* iterable. Uses a modified C3 linearization algorithm. + + """ + bases = set(cls.__mro__) + # Remove entries which are already present in the __mro__ or unrelated. + def is_related(typ): + return (typ not in bases and hasattr(typ, '__mro__') + and issubclass(cls, typ)) + types = [n for n in types if is_related(n)] + # Remove entries which are strict bases of other entries (they will end up + # in the MRO anyway. + def is_strict_base(typ): + for other in types: + if typ != other and typ in other.__mro__: + return True + return False + types = [n for n in types if not is_strict_base(n)] + # Subclasses of the ABCs in *types* which are also implemented by + # *cls* can be used to stabilize ABC ordering. + type_set = set(types) + mro = [] + for typ in types: + found = [] + for sub in typ.__subclasses__(): + if sub not in bases and issubclass(cls, sub): + found.append([s for s in sub.__mro__ if s in type_set]) + if not found: + mro.append(typ) + continue + # Favor subclasses with the biggest number of useful bases + found.sort(key=len, reverse=True) + for sub in found: + for subcls in sub: + if subcls not in mro: + mro.append(subcls) + return _c3_mro(cls, abcs=mro) def _find_impl(cls, registry): - """Returns the best matching implementation for the given class `cls` in - `registry`. Where there is no registered implementation for a specific - type, its method resolution order is used to find a more generic - implementation. + """Returns the best matching implementation from *registry* for type *cls*. + + Where there is no registered implementation for a specific type, its method + resolution order is used to find a more generic implementation. - Note: if `registry` does not contain an implementation for the base - `object` type, this function may return None. + Note: if *registry* does not contain an implementation for the base + *object* type, this function may return None. """ mro = _compose_mro(cls, registry.keys()) match = None for t in mro: if match is not None: - # If `match` is an ABC but there is another unrelated, equally - # matching ABC. Refuse the temptation to guess. - if (t in registry and not issubclass(match, t) - and match not in cls.__mro__): + # If *match* is an implicit ABC but there is another unrelated, + # equally matching implicit ABC, refuse the temptation to guess. + if (t in registry and t not in cls.__mro__ + and match not in cls.__mro__ + and not issubclass(match, t)): raise RuntimeError("Ambiguous dispatch: {} or {}".format( match, t)) break @@ -69,19 +161,19 @@ def singledispatch(func): Transforms a function into a generic function, which can have different behaviours depending upon the type of its first argument. The decorated function acts as the default implementation, and additional - implementations can be registered using the 'register()' attribute of - the generic function. + implementations can be registered using the register() attribute of the + generic function. """ registry = {} dispatch_cache = WeakKeyDictionary() cache_token = None - def dispatch(typ): - """generic_func.dispatch(type) -> + def dispatch(cls): + """generic_func.dispatch(cls) -> Runs the dispatch algorithm to return the best available implementation - for the given `type` registered on `generic_func`. + for the given *cls* registered on *generic_func*. """ nonlocal cache_token @@ -91,26 +183,26 @@ def singledispatch(func): dispatch_cache.clear() cache_token = current_token try: - impl = dispatch_cache[typ] + impl = dispatch_cache[cls] except KeyError: try: - impl = registry[typ] + impl = registry[cls] except KeyError: - impl = _find_impl(typ, registry) - dispatch_cache[typ] = impl + impl = _find_impl(cls, registry) + dispatch_cache[cls] = impl return impl - def register(typ, func=None): - """generic_func.register(type, func) -> func + def register(cls, func=None): + """generic_func.register(cls, func) -> func - Registers a new implementation for the given `type` on a `generic_func`. + Registers a new implementation for the given *cls* on a *generic_func*. """ nonlocal cache_token if func is None: - return lambda f: register(typ, f) - registry[typ] = func - if cache_token is None and hasattr(typ, '__abstractmethods__'): + return lambda f: register(cls, f) + registry[cls] = func + if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() return func diff --git a/singledispatch_helpers.py b/singledispatch_helpers.py index 5c4aa8b..3a71831 100644 --- a/singledispatch_helpers.py +++ b/singledispatch_helpers.py @@ -8,6 +8,7 @@ from __future__ import unicode_literals from abc import ABCMeta from collections import MutableMapping, UserDict +import sys try: from thread import get_ident except ImportError: @@ -148,3 +149,14 @@ class MappingProxyType(UserDict): def get_cache_token(): return ABCMeta._abc_invalidation_counter + + + +class Support(object): + def dummy(self): + pass + + def cpython_only(self, func): + if 'PyPy' in sys.version: + return self.dummy + return func diff --git a/test_singledispatch.py b/test_singledispatch.py index 3f243a9..0705cff 100644 --- a/test_singledispatch.py +++ b/test_singledispatch.py @@ -10,10 +10,18 @@ import collections import decimal from itertools import permutations import singledispatch as functools -from singledispatch_helpers import ChainMap +from singledispatch_helpers import Support +try: + from collections import ChainMap +except ImportError: + from singledispatch_helpers import ChainMap + collections.ChainMap = ChainMap import unittest +support = Support() + + class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @functools.singledispatch @@ -30,29 +38,24 @@ class TestSingleDispatch(unittest.TestCase): @functools.singledispatch def g(obj): return "base" - class C: + class A: pass - class D(C): + class C(A): pass - def g_C(c): - return "C" - g.register(C, g_C) - self.assertEqual(g(C()), "C") - self.assertEqual(g(D()), "C") - - def test_classic_classes(self): - @functools.singledispatch - def g(obj): - return "base" - class C: + class B(A): pass - class D(C): + class D(C, B): pass - def g_C(c): - return "C" - g.register(C, g_C) - self.assertEqual(g(C()), "C") - self.assertEqual(g(D()), "C") + def g_A(a): + return "A" + def g_B(b): + return "B" + g.register(A, g_A) + g.register(B, g_B) + self.assertEqual(g(A()), "A") + self.assertEqual(g(B()), "B") + self.assertEqual(g(C()), "A") + self.assertEqual(g(D()), "B") def test_register_decorator(self): @functools.singledispatch @@ -77,6 +80,7 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g.__doc__, "Simple test") @unittest.skipUnless(decimal, 'requires _decimal') + @support.cpython_only def test_c_classes(self): @functools.singledispatch def g(obj): @@ -95,22 +99,55 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(rnd), ("Number got rounded",)) def test_compose_mro(self): + # None of the examples in this test depend on haystack ordering. c = collections mro = functools._compose_mro bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] for haystack in permutations(bases): m = mro(dict, haystack) - self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object]) + self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized, + c.Iterable, c.Container, object]) bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] for haystack in permutations(bases): - m = mro(ChainMap, haystack) - self.assertEqual(m, [ChainMap, c.MutableMapping, c.Mapping, + m = mro(c.ChainMap, haystack) + self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) + + # If there's a generic function with implementations registered for + # both Sized and Container, passing a defaultdict to it results in an + # ambiguous dispatch which will cause a RuntimeError (see + # test_mro_conflicts). + bases = [c.Container, c.Sized, str] + for haystack in permutations(bases): + m = mro(c.defaultdict, [c.Sized, c.Container, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, + object]) + + # MutableSequence below is registered directly on D. In other words, it + # preceeds MutableMapping which means single dispatch will always + # choose MutableSequence here. + class D(c.defaultdict): + pass + c.MutableSequence.register(D) + bases = [c.MutableSequence, c.MutableMapping] + for haystack in permutations(bases): + m = mro(D, bases) + self.assertEqual(m, [D, c.MutableSequence, c.Sequence, + c.defaultdict, dict, c.MutableMapping, + c.Mapping, c.Sized, c.Iterable, c.Container, + object]) + + # Container and Callable are registered on different base classes and + # a generic function supporting both should always pick the Callable + # implementation if a C instance is passed. + class C(c.defaultdict): + def __call__(self): + pass + bases = [c.Sized, c.Callable, c.Container, c.Mapping] + for haystack in permutations(bases): + m = mro(C, haystack) + self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, c.Sized, c.Iterable, c.Container, object]) - # Note: The MRO order below depends on haystack ordering. - m = mro(c.defaultdict, [c.Sized, c.Container, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object]) - m = mro(c.defaultdict, [c.Container, c.Sized, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object]) def test_register_abc(self): c = collections @@ -139,7 +176,7 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") - g.register(ChainMap, lambda obj: "chainmap") + g.register(c.ChainMap, lambda obj: "chainmap") self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") @@ -206,17 +243,37 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") - def test_mro_conflicts(self): + def test_c3_abc(self): c = collections + mro = functools._c3_mro + class A(object): + pass + class B(A): + def __len__(self): + return 0 # implies Sized + @c.Container.register + class C(object): + pass + class D(object): + pass # unrelated + class X(D, C, B): + def __call__(self): + pass # implies Callable + expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] + for abcs in permutations([c.Sized, c.Callable, c.Container]): + self.assertEqual(mro(X, abcs=abcs), expected) + # unrelated ABCs don't appear in the resulting MRO + many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] + self.assertEqual(mro(X, abcs=many_abcs), expected) + def test_mro_conflicts(self): + c = collections @functools.singledispatch def g(arg): return "base" - class O(c.Sized): def __len__(self): return 0 - o = O() self.assertEqual(g(o), "base") g.register(c.Iterable, lambda arg: "iterable") @@ -228,35 +285,114 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ c.Container.register(O) self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ - + c.Set.register(O) + self.assertEqual(g(o), "set") # because c.Set is a subclass of + # c.Sized and c.Container class P: pass - p = P() self.assertEqual(g(p), "base") c.Iterable.register(P) self.assertEqual(g(p), "iterable") c.Container.register(P) - with self.assertRaises(RuntimeError) as re: + with self.assertRaises(RuntimeError) as re_one: g(p) - self.assertEqual( - str(re), - ("Ambiguous dispatch: " - "or "), - ) - + self.assertIn( + str(re_one.exception), + (("Ambiguous dispatch: " + "or "), + ("Ambiguous dispatch: " + "or ")), + ) class Q(c.Sized): def __len__(self): return 0 - q = Q() self.assertEqual(g(q), "sized") c.Iterable.register(Q) self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ c.Set.register(Q) self.assertEqual(g(q), "set") # because c.Set is a subclass of - # c.Sized which is explicitly in - # __mro__ + # c.Sized and c.Iterable + @functools.singledispatch + def h(arg): + return "base" + @h.register(c.Sized) + def _(arg): + return "sized" + @h.register(c.Container) + def _(arg): + return "container" + # Even though Sized and Container are explicit bases of MutableMapping, + # this ABC is implicitly registered on defaultdict which makes all of + # MutableMapping's bases implicit as well from defaultdict's + # perspective. + with self.assertRaises(RuntimeError) as re_two: + h(c.defaultdict(lambda: 0)) + self.assertIn( + str(re_two.exception), + (("Ambiguous dispatch: " + "or "), + ("Ambiguous dispatch: " + "or ")), + ) + class R(c.defaultdict): + pass + c.MutableSequence.register(R) + @functools.singledispatch + def i(arg): + return "base" + @i.register(c.MutableMapping) + def _(arg): + return "mapping" + @i.register(c.MutableSequence) + def _(arg): + return "sequence" + r = R() + self.assertEqual(i(r), "sequence") + class S: + pass + class T(S, c.Sized): + def __len__(self): + return 0 + t = T() + self.assertEqual(h(t), "sized") + c.Container.register(T) + self.assertEqual(h(t), "sized") # because it's explicitly in the MRO + class U: + def __len__(self): + return 0 + u = U() + self.assertEqual(h(u), "sized") # implicit Sized subclass inferred + # from the existence of __len__() + c.Container.register(U) + # There is no preference for registered versus inferred ABCs. + with self.assertRaises(RuntimeError) as re_three: + h(u) + self.assertIn( + str(re_three.exception), + (("Ambiguous dispatch: " + "or "), + ("Ambiguous dispatch: " + "or ")), + ) + class V(c.Sized, S): + def __len__(self): + return 0 + @functools.singledispatch + def j(arg): + return "base" + @j.register(S) + def _(arg): + return "s" + @j.register(c.Container) + def _(arg): + return "container" + v = V() + self.assertEqual(j(v), "s") + c.Container.register(V) + self.assertEqual(j(v), "container") # because it ends up right after + # Sized in the MRO def test_cache_invalidation(self): from collections import UserDict diff --git a/tox.ini b/tox.ini index bcd09a3..b66e223 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py32,py33 +envlist = py33 [testenv] commands = -- cgit v1.2.1