summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author?ukasz Langa <lukasz@langa.pl>2013-07-02 10:28:51 +0200
committer?ukasz Langa <lukasz@langa.pl>2013-07-02 10:28:51 +0200
commit2841142ce382e70b664a3e2071427a3e11cdeb44 (patch)
tree6065230bcbd33f628bab4cf3e15c4fc03a0dedc9
parent005e4d4b44bf259e5c748b6dad3f79d478773994 (diff)
downloadsingledispatch-2841142ce382e70b664a3e2071427a3e11cdeb44.tar.gz
update to the ref implementation as of 02-July-2013:
* more predictable dispatch thanks to C3-based linearization for ABC support * improved tests and docstrings
-rw-r--r--README.rst37
-rw-r--r--setup.py2
-rw-r--r--singledispatch.py178
-rw-r--r--singledispatch_helpers.py12
-rw-r--r--test_singledispatch.py224
-rw-r--r--tox.ini2
6 files changed, 365 insertions, 90 deletions
diff --git a/README.rst b/README.rst
index e1e345c..39e35c7 100644
--- a/README.rst
+++ b/README.rst
@@ -13,7 +13,7 @@ To define a generic function, decorate it with the ``@singledispatch``
decorator. Note that the dispatch happens on the type of the first
argument, create your function accordingly::
- >>> from functools import singledispatch
+ >>> from singledispatch import singledispatch
>>> @singledispatch
... def fun(arg, verbose=False):
... if verbose:
@@ -145,6 +145,41 @@ members of the core CPython team:
Change Log
----------
+3.4.0.2
+~~~~~~~
+
+Updated to the reference implementation as of 02-July-2013.
+
+* more predictable dispatch order when abstract base classes are in use:
+ abstract base classes are now inserted into the MRO of the argument's
+ class where their functionality is introduced, i.e. issubclass(cls,
+ abc) returns True for the class itself but returns False for all its
+ direct base classes. Implicit ABCs for a given class (either
+ registered or inferred from the presence of a special method like
+ __len__) are inserted directly after the last ABC explicitly listed in
+ the MRO of said class. This also means there are less "ambiguous
+ dispatch" exceptions raised.
+
+* better test coverage and improved docstrings
+
+3.4.0.1
+~~~~~~~
+
+Updated to the reference implementation as of 31-May-2013.
+
+* better performance
+
+* fixed a corner case with PEP 435 enums
+
+* calls to `dispatch()` also cached
+
+* dispatching algorithm now now a module-level routine called `_find_impl()`
+ with a simplified implementation and proper documentation
+
+* `dispatch()` now handles all caching-related activities
+
+* terminology more consistent: "overload" -> "implementation"
+
3.4.0.0
~~~~~~~
diff --git a/setup.py b/setup.py
index 4ca442a..338905e 100644
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@ with codecs.open(
setup (
name = 'singledispatch',
- version = '3.4.0.0',
+ version = '3.4.0.2',
author = 'Ɓukasz Langa',
author_email = 'lukasz@langa.pl',
description = __doc__,
diff --git a/singledispatch.py b/singledispatch.py
index c68ad02..b428920 100644
--- a/singledispatch.py
+++ b/singledispatch.py
@@ -16,46 +16,138 @@ from singledispatch_helpers import MappingProxyType, get_cache_token
### singledispatch() - single-dispatch generic function decorator
################################################################################
-def _compose_mro(cls, haystack):
- """Calculates the MRO for a given class `cls`, including relevant abstract
- base classes from `haystack`.
+def _c3_merge(sequences):
+ """Merges MROs in *sequences* to a single MRO using the C3 algorithm.
+
+ Adapted from http://www.python.org/download/releases/2.3/mro/.
"""
- bases = set(cls.__mro__)
- mro = list(cls.__mro__)
- for needle in haystack:
- if (needle in bases or not hasattr(needle, '__mro__')
- or not issubclass(cls, needle)):
- continue # either present in the __mro__ already or unrelated
- for index, base in enumerate(mro):
- if not issubclass(base, needle):
+ result = []
+ while True:
+ sequences = [s for s in sequences if s] # purge empty sequences
+ if not sequences:
+ return result
+ for s1 in sequences: # find merge candidates among seq heads
+ candidate = s1[0]
+ for s2 in sequences:
+ if candidate in s2[1:]:
+ candidate = None
+ break # reject the current head, it appears later
+ else:
break
- if base in bases and not issubclass(needle, base):
- # Conflict resolution: put classes present in __mro__ and their
- # subclasses first. See test_mro_conflicts() in test_functools.py
- # for examples.
- index += 1
- mro.insert(index, needle)
- return mro
+ if not candidate:
+ raise RuntimeError("Inconsistent hierarchy")
+ result.append(candidate)
+ # remove the chosen candidate
+ for seq in sequences:
+ if seq[0] == candidate:
+ del seq[0]
+
+def _c3_mro(cls, abcs=None):
+ """Computes the method resolution order using extended C3 linearization.
+
+ If no *abcs* are given, the algorithm works exactly like the built-in C3
+ linearization used for method resolution.
+
+ If given, *abcs* is a list of abstract base classes that should be inserted
+ into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
+ result. The algorithm inserts ABCs where their functionality is introduced,
+ i.e. issubclass(cls, abc) returns True for the class itself but returns
+ False for all its direct base classes. Implicit ABCs for a given class
+ (either registered or inferred from the presence of a special method like
+ __len__) are inserted directly after the last ABC explicitly listed in the
+ MRO of said class. If two implicit ABCs end up next to each other in the
+ resulting MRO, their ordering depends on the order of types in *abcs*.
+
+ """
+ for i, base in enumerate(reversed(cls.__bases__)):
+ if hasattr(base, '__abstractmethods__'):
+ boundary = len(cls.__bases__) - i
+ break # Bases up to the last explicit ABC are considered first.
+ else:
+ boundary = 0
+ abcs = list(abcs) if abcs else []
+ explicit_bases = list(cls.__bases__[:boundary])
+ abstract_bases = []
+ other_bases = list(cls.__bases__[boundary:])
+ for base in abcs:
+ if issubclass(cls, base) and not any(
+ issubclass(b, base) for b in cls.__bases__
+ ):
+ # If *cls* is the class that introduces behaviour described by
+ # an ABC *base*, insert said ABC to its MRO.
+ abstract_bases.append(base)
+ for base in abstract_bases:
+ abcs.remove(base)
+ explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
+ abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
+ other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
+ return _c3_merge(
+ [[cls]] +
+ explicit_c3_mros + abstract_c3_mros + other_c3_mros +
+ [explicit_bases] + [abstract_bases] + [other_bases]
+ )
+
+def _compose_mro(cls, types):
+ """Calculates the method resolution order for a given class *cls*.
+
+ Includes relevant abstract base classes (with their respective bases) from
+ the *types* iterable. Uses a modified C3 linearization algorithm.
+
+ """
+ bases = set(cls.__mro__)
+ # Remove entries which are already present in the __mro__ or unrelated.
+ def is_related(typ):
+ return (typ not in bases and hasattr(typ, '__mro__')
+ and issubclass(cls, typ))
+ types = [n for n in types if is_related(n)]
+ # Remove entries which are strict bases of other entries (they will end up
+ # in the MRO anyway.
+ def is_strict_base(typ):
+ for other in types:
+ if typ != other and typ in other.__mro__:
+ return True
+ return False
+ types = [n for n in types if not is_strict_base(n)]
+ # Subclasses of the ABCs in *types* which are also implemented by
+ # *cls* can be used to stabilize ABC ordering.
+ type_set = set(types)
+ mro = []
+ for typ in types:
+ found = []
+ for sub in typ.__subclasses__():
+ if sub not in bases and issubclass(cls, sub):
+ found.append([s for s in sub.__mro__ if s in type_set])
+ if not found:
+ mro.append(typ)
+ continue
+ # Favor subclasses with the biggest number of useful bases
+ found.sort(key=len, reverse=True)
+ for sub in found:
+ for subcls in sub:
+ if subcls not in mro:
+ mro.append(subcls)
+ return _c3_mro(cls, abcs=mro)
def _find_impl(cls, registry):
- """Returns the best matching implementation for the given class `cls` in
- `registry`. Where there is no registered implementation for a specific
- type, its method resolution order is used to find a more generic
- implementation.
+ """Returns the best matching implementation from *registry* for type *cls*.
+
+ Where there is no registered implementation for a specific type, its method
+ resolution order is used to find a more generic implementation.
- Note: if `registry` does not contain an implementation for the base
- `object` type, this function may return None.
+ Note: if *registry* does not contain an implementation for the base
+ *object* type, this function may return None.
"""
mro = _compose_mro(cls, registry.keys())
match = None
for t in mro:
if match is not None:
- # If `match` is an ABC but there is another unrelated, equally
- # matching ABC. Refuse the temptation to guess.
- if (t in registry and not issubclass(match, t)
- and match not in cls.__mro__):
+ # If *match* is an implicit ABC but there is another unrelated,
+ # equally matching implicit ABC, refuse the temptation to guess.
+ if (t in registry and t not in cls.__mro__
+ and match not in cls.__mro__
+ and not issubclass(match, t)):
raise RuntimeError("Ambiguous dispatch: {} or {}".format(
match, t))
break
@@ -69,19 +161,19 @@ def singledispatch(func):
Transforms a function into a generic function, which can have different
behaviours depending upon the type of its first argument. The decorated
function acts as the default implementation, and additional
- implementations can be registered using the 'register()' attribute of
- the generic function.
+ implementations can be registered using the register() attribute of the
+ generic function.
"""
registry = {}
dispatch_cache = WeakKeyDictionary()
cache_token = None
- def dispatch(typ):
- """generic_func.dispatch(type) -> <function implementation>
+ def dispatch(cls):
+ """generic_func.dispatch(cls) -> <function implementation>
Runs the dispatch algorithm to return the best available implementation
- for the given `type` registered on `generic_func`.
+ for the given *cls* registered on *generic_func*.
"""
nonlocal cache_token
@@ -91,26 +183,26 @@ def singledispatch(func):
dispatch_cache.clear()
cache_token = current_token
try:
- impl = dispatch_cache[typ]
+ impl = dispatch_cache[cls]
except KeyError:
try:
- impl = registry[typ]
+ impl = registry[cls]
except KeyError:
- impl = _find_impl(typ, registry)
- dispatch_cache[typ] = impl
+ impl = _find_impl(cls, registry)
+ dispatch_cache[cls] = impl
return impl
- def register(typ, func=None):
- """generic_func.register(type, func) -> func
+ def register(cls, func=None):
+ """generic_func.register(cls, func) -> func
- Registers a new implementation for the given `type` on a `generic_func`.
+ Registers a new implementation for the given *cls* on a *generic_func*.
"""
nonlocal cache_token
if func is None:
- return lambda f: register(typ, f)
- registry[typ] = func
- if cache_token is None and hasattr(typ, '__abstractmethods__'):
+ return lambda f: register(cls, f)
+ registry[cls] = func
+ if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()
return func
diff --git a/singledispatch_helpers.py b/singledispatch_helpers.py
index 5c4aa8b..3a71831 100644
--- a/singledispatch_helpers.py
+++ b/singledispatch_helpers.py
@@ -8,6 +8,7 @@ from __future__ import unicode_literals
from abc import ABCMeta
from collections import MutableMapping, UserDict
+import sys
try:
from thread import get_ident
except ImportError:
@@ -148,3 +149,14 @@ class MappingProxyType(UserDict):
def get_cache_token():
return ABCMeta._abc_invalidation_counter
+
+
+
+class Support(object):
+ def dummy(self):
+ pass
+
+ def cpython_only(self, func):
+ if 'PyPy' in sys.version:
+ return self.dummy
+ return func
diff --git a/test_singledispatch.py b/test_singledispatch.py
index 3f243a9..0705cff 100644
--- a/test_singledispatch.py
+++ b/test_singledispatch.py
@@ -10,10 +10,18 @@ import collections
import decimal
from itertools import permutations
import singledispatch as functools
-from singledispatch_helpers import ChainMap
+from singledispatch_helpers import Support
+try:
+ from collections import ChainMap
+except ImportError:
+ from singledispatch_helpers import ChainMap
+ collections.ChainMap = ChainMap
import unittest
+support = Support()
+
+
class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
@functools.singledispatch
@@ -30,29 +38,24 @@ class TestSingleDispatch(unittest.TestCase):
@functools.singledispatch
def g(obj):
return "base"
- class C:
+ class A:
pass
- class D(C):
+ class C(A):
pass
- def g_C(c):
- return "C"
- g.register(C, g_C)
- self.assertEqual(g(C()), "C")
- self.assertEqual(g(D()), "C")
-
- def test_classic_classes(self):
- @functools.singledispatch
- def g(obj):
- return "base"
- class C:
+ class B(A):
pass
- class D(C):
+ class D(C, B):
pass
- def g_C(c):
- return "C"
- g.register(C, g_C)
- self.assertEqual(g(C()), "C")
- self.assertEqual(g(D()), "C")
+ def g_A(a):
+ return "A"
+ def g_B(b):
+ return "B"
+ g.register(A, g_A)
+ g.register(B, g_B)
+ self.assertEqual(g(A()), "A")
+ self.assertEqual(g(B()), "B")
+ self.assertEqual(g(C()), "A")
+ self.assertEqual(g(D()), "B")
def test_register_decorator(self):
@functools.singledispatch
@@ -77,6 +80,7 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g.__doc__, "Simple test")
@unittest.skipUnless(decimal, 'requires _decimal')
+ @support.cpython_only
def test_c_classes(self):
@functools.singledispatch
def g(obj):
@@ -95,22 +99,55 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(rnd), ("Number got rounded",))
def test_compose_mro(self):
+ # None of the examples in this test depend on haystack ordering.
c = collections
mro = functools._compose_mro
bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
for haystack in permutations(bases):
m = mro(dict, haystack)
- self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object])
+ self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
+ c.Iterable, c.Container, object])
bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
for haystack in permutations(bases):
- m = mro(ChainMap, haystack)
- self.assertEqual(m, [ChainMap, c.MutableMapping, c.Mapping,
+ m = mro(c.ChainMap, haystack)
+ self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
+ c.Sized, c.Iterable, c.Container, object])
+
+ # If there's a generic function with implementations registered for
+ # both Sized and Container, passing a defaultdict to it results in an
+ # ambiguous dispatch which will cause a RuntimeError (see
+ # test_mro_conflicts).
+ bases = [c.Container, c.Sized, str]
+ for haystack in permutations(bases):
+ m = mro(c.defaultdict, [c.Sized, c.Container, str])
+ self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
+ object])
+
+ # MutableSequence below is registered directly on D. In other words, it
+ # preceeds MutableMapping which means single dispatch will always
+ # choose MutableSequence here.
+ class D(c.defaultdict):
+ pass
+ c.MutableSequence.register(D)
+ bases = [c.MutableSequence, c.MutableMapping]
+ for haystack in permutations(bases):
+ m = mro(D, bases)
+ self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
+ c.defaultdict, dict, c.MutableMapping,
+ c.Mapping, c.Sized, c.Iterable, c.Container,
+ object])
+
+ # Container and Callable are registered on different base classes and
+ # a generic function supporting both should always pick the Callable
+ # implementation if a C instance is passed.
+ class C(c.defaultdict):
+ def __call__(self):
+ pass
+ bases = [c.Sized, c.Callable, c.Container, c.Mapping]
+ for haystack in permutations(bases):
+ m = mro(C, haystack)
+ self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
c.Sized, c.Iterable, c.Container, object])
- # Note: The MRO order below depends on haystack ordering.
- m = mro(c.defaultdict, [c.Sized, c.Container, str])
- self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object])
- m = mro(c.defaultdict, [c.Container, c.Sized, str])
- self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object])
def test_register_abc(self):
c = collections
@@ -139,7 +176,7 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(s), "sized")
self.assertEqual(g(f), "sized")
self.assertEqual(g(t), "sized")
- g.register(ChainMap, lambda obj: "chainmap")
+ g.register(c.ChainMap, lambda obj: "chainmap")
self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
self.assertEqual(g(l), "sized")
self.assertEqual(g(s), "sized")
@@ -206,17 +243,37 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(f), "frozen-set")
self.assertEqual(g(t), "tuple")
- def test_mro_conflicts(self):
+ def test_c3_abc(self):
c = collections
+ mro = functools._c3_mro
+ class A(object):
+ pass
+ class B(A):
+ def __len__(self):
+ return 0 # implies Sized
+ @c.Container.register
+ class C(object):
+ pass
+ class D(object):
+ pass # unrelated
+ class X(D, C, B):
+ def __call__(self):
+ pass # implies Callable
+ expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
+ for abcs in permutations([c.Sized, c.Callable, c.Container]):
+ self.assertEqual(mro(X, abcs=abcs), expected)
+ # unrelated ABCs don't appear in the resulting MRO
+ many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
+ self.assertEqual(mro(X, abcs=many_abcs), expected)
+ def test_mro_conflicts(self):
+ c = collections
@functools.singledispatch
def g(arg):
return "base"
-
class O(c.Sized):
def __len__(self):
return 0
-
o = O()
self.assertEqual(g(o), "base")
g.register(c.Iterable, lambda arg: "iterable")
@@ -228,35 +285,114 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
c.Container.register(O)
self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
-
+ c.Set.register(O)
+ self.assertEqual(g(o), "set") # because c.Set is a subclass of
+ # c.Sized and c.Container
class P:
pass
-
p = P()
self.assertEqual(g(p), "base")
c.Iterable.register(P)
self.assertEqual(g(p), "iterable")
c.Container.register(P)
- with self.assertRaises(RuntimeError) as re:
+ with self.assertRaises(RuntimeError) as re_one:
g(p)
- self.assertEqual(
- str(re),
- ("Ambiguous dispatch: <class 'collections.abc.Container'> "
- "or <class 'collections.abc.Iterable'>"),
- )
-
+ self.assertIn(
+ str(re_one.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Iterable'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
+ "or <class 'collections.abc.Container'>")),
+ )
class Q(c.Sized):
def __len__(self):
return 0
-
q = Q()
self.assertEqual(g(q), "sized")
c.Iterable.register(Q)
self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
c.Set.register(Q)
self.assertEqual(g(q), "set") # because c.Set is a subclass of
- # c.Sized which is explicitly in
- # __mro__
+ # c.Sized and c.Iterable
+ @functools.singledispatch
+ def h(arg):
+ return "base"
+ @h.register(c.Sized)
+ def _(arg):
+ return "sized"
+ @h.register(c.Container)
+ def _(arg):
+ return "container"
+ # Even though Sized and Container are explicit bases of MutableMapping,
+ # this ABC is implicitly registered on defaultdict which makes all of
+ # MutableMapping's bases implicit as well from defaultdict's
+ # perspective.
+ with self.assertRaises(RuntimeError) as re_two:
+ h(c.defaultdict(lambda: 0))
+ self.assertIn(
+ str(re_two.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Sized'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class R(c.defaultdict):
+ pass
+ c.MutableSequence.register(R)
+ @functools.singledispatch
+ def i(arg):
+ return "base"
+ @i.register(c.MutableMapping)
+ def _(arg):
+ return "mapping"
+ @i.register(c.MutableSequence)
+ def _(arg):
+ return "sequence"
+ r = R()
+ self.assertEqual(i(r), "sequence")
+ class S:
+ pass
+ class T(S, c.Sized):
+ def __len__(self):
+ return 0
+ t = T()
+ self.assertEqual(h(t), "sized")
+ c.Container.register(T)
+ self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
+ class U:
+ def __len__(self):
+ return 0
+ u = U()
+ self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
+ # from the existence of __len__()
+ c.Container.register(U)
+ # There is no preference for registered versus inferred ABCs.
+ with self.assertRaises(RuntimeError) as re_three:
+ h(u)
+ self.assertIn(
+ str(re_three.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Sized'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class V(c.Sized, S):
+ def __len__(self):
+ return 0
+ @functools.singledispatch
+ def j(arg):
+ return "base"
+ @j.register(S)
+ def _(arg):
+ return "s"
+ @j.register(c.Container)
+ def _(arg):
+ return "container"
+ v = V()
+ self.assertEqual(j(v), "s")
+ c.Container.register(V)
+ self.assertEqual(j(v), "container") # because it ends up right after
+ # Sized in the MRO
def test_cache_invalidation(self):
from collections import UserDict
diff --git a/tox.ini b/tox.ini
index bcd09a3..b66e223 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = py32,py33
+envlist = py33
[testenv]
commands =