summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author?ukasz Langa <lukasz@langa.pl>2013-05-31 12:59:38 +0200
committer?ukasz Langa <lukasz@langa.pl>2013-05-31 12:59:38 +0200
commit7608b45137e9b06d127c25550821f11a644479dd (patch)
tree98be66fef0526496152aca2cdebac701073d5856
parentb0e6c5214c258c0e861e058df3c63411c8eae6d2 (diff)
parent070c478c0c768a1cf94a0964c594a03a4e520f8d (diff)
downloadsingledispatch-7608b45137e9b06d127c25550821f11a644479dd.tar.gz
Merged with upstream and made compatible with 2.6 - 3.3
-rw-r--r--README.rst28
-rw-r--r--singledispatch.py81
-rw-r--r--test_singledispatch.py13
3 files changed, 76 insertions, 46 deletions
diff --git a/README.rst b/README.rst
index 802ab5e..e1e345c 100644
--- a/README.rst
+++ b/README.rst
@@ -21,8 +21,9 @@ argument, create your function accordingly::
... print(arg)
To add overloaded implementations to the function, use the
-``register()`` attribute of the generic function. It takes a type
-parameter::
+``register()`` attribute of the generic function. It is a decorator,
+taking a type parameter and decorating a function implementing the
+operation for that type::
>>> @fun.register(int)
... def _(arg, verbose=False):
@@ -59,7 +60,8 @@ each variant independently::
>>> fun_num is fun
False
-When called, the generic function dispatches on the first argument::
+When called, the generic function dispatches on the type of the first
+argument::
>>> fun("Hello, world.")
Hello, world.
@@ -78,15 +80,21 @@ When called, the generic function dispatches on the first argument::
>>> fun(1.23)
0.615
-To get the implementation for a specific type, use the ``dispatch()``
-attribute::
+Where there is no registered implementation for a specific type, its
+method resolution order is used to find a more generic implementation.
+The original function decorated with ``@singledispatch`` is registered
+for the base ``object`` type, which means it is used if no better
+implementation is found.
+
+To check which implementation will the generic function choose for
+a given type, use the ``dispatch()`` attribute::
>>> fun.dispatch(float)
- <function fun_num at 0x104319058>
- >>> fun.dispatch(dict)
- <function fun at 0x103fe4788>
+ <function fun_num at 0x1035a2840>
+ >>> fun.dispatch(dict) # note: default implementation
+ <function fun at 0x103fe0000>
-To access all registered overloads, use the read-only ``registry``
+To access all registered implementations, use the read-only ``registry``
attribute::
>>> fun.registry.keys()
@@ -96,7 +104,7 @@ attribute::
>>> fun.registry[float]
<function fun_num at 0x1035a2840>
>>> fun.registry[object]
- <function fun at 0x103170788>
+ <function fun at 0x103fe0000>
The vanilla documentation is available at
http://docs.python.org/3/library/functools.html#functools.singledispatch.
diff --git a/singledispatch.py b/singledispatch.py
index e1009b9..c9e2362 100644
--- a/singledispatch.py
+++ b/singledispatch.py
@@ -18,23 +18,51 @@ from singledispatch_helpers import MappingProxyType, get_cache_token
def _compose_mro(cls, haystack):
"""Calculates the MRO for a given class `cls`, including relevant abstract
- base classes from `haystack`."""
+ base classes from `haystack`.
+
+ """
bases = set(cls.__mro__)
mro = list(cls.__mro__)
- for regcls in haystack:
- if regcls in bases or not issubclass(cls, regcls):
+ for needle in haystack:
+ if (needle in bases or not hasattr(needle, '__mro__')
+ or not issubclass(cls, needle)):
continue # either present in the __mro__ already or unrelated
for index, base in enumerate(mro):
- if not issubclass(base, regcls):
+ if not issubclass(base, needle):
break
- if base in bases and not issubclass(regcls, base):
+ if base in bases and not issubclass(needle, base):
# Conflict resolution: put classes present in __mro__ and their
# subclasses first. See test_mro_conflicts() in test_functools.py
# for examples.
index += 1
- mro.insert(index, regcls)
+ mro.insert(index, needle)
return mro
+def _find_impl(cls, registry):
+ """Returns the best matching implementation for the given class `cls` in
+ `registry`. Where there is no registered implementation for a specific
+ type, its method resolution order is used to find a more generic
+ implementation.
+
+ Note: if `registry` does not contain an implementation for the base
+ `object` type, this function may return None.
+
+ """
+ mro = _compose_mro(cls, registry.keys())
+ match = None
+ for t in mro:
+ if match is not None:
+ # If `match` is an ABC but there is another unrelated, equally
+ # matching ABC. Refuse the temptation to guess.
+ if (t in registry and not issubclass(match, t)
+ and match not in cls.__mro__):
+ raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
+ match, t))
+ break
+ if t in registry:
+ match = t
+ return registry.get(match)
+
def singledispatch(func):
"""Single-dispatch generic function decorator.
@@ -50,7 +78,7 @@ def singledispatch(func):
def ns(): pass
ns.cache_token = None
- def dispatch(cls):
+ def dispatch(typ):
"""generic_func.dispatch(type) -> <function implementation>
Runs the dispatch algorithm to return the best available implementation
@@ -58,43 +86,24 @@ def singledispatch(func):
"""
if ns.cache_token is not None:
- mro = _compose_mro(cls, registry.keys())
- match = None
- for t in mro:
- if not match:
- if t in registry:
- match = t
- continue
- if (t in registry and not issubclass(match, t)
- and match not in cls.__mro__):
- # `match` is an ABC but there is another unrelated, equally
- # matching ABC. Refuse the temptation to guess.
- raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
- match, t))
- return registry[match]
- else:
- for t in cls.__mro__:
- if t in registry:
- return registry[t]
- return func
-
- def wrapper(*args, **kw):
- if ns.cache_token is not None:
current_token = get_cache_token()
if ns.cache_token != current_token:
dispatch_cache.clear()
ns.cache_token = current_token
- cls = args[0].__class__
try:
- impl = dispatch_cache[cls]
+ impl = dispatch_cache[typ]
except KeyError:
- impl = dispatch_cache[cls] = dispatch(cls)
- return impl(*args, **kw)
+ try:
+ impl = registry[typ]
+ except KeyError:
+ impl = _find_impl(typ, registry)
+ dispatch_cache[typ] = impl
+ return impl
def register(typ, func=None):
"""generic_func.register(type, func) -> func
- Registers a new overload for the given `type` on a `generic_func`.
+ Registers a new implementation for the given `type` on a `generic_func`.
"""
if func is None:
@@ -105,10 +114,14 @@ def singledispatch(func):
dispatch_cache.clear()
return func
+ def wrapper(*args, **kw):
+ return dispatch(args[0].__class__)(*args, **kw)
+
registry[object] = func
wrapper.register = register
wrapper.dispatch = dispatch
wrapper.registry = MappingProxyType(registry)
+ wrapper._clear_cache = dispatch_cache.clear
update_wrapper(wrapper, func)
return wrapper
diff --git a/test_singledispatch.py b/test_singledispatch.py
index 484e88e..60872f8 100644
--- a/test_singledispatch.py
+++ b/test_singledispatch.py
@@ -313,12 +313,14 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict])
- self.assertEqual(td.data[dict], g.dispatch(dict))
+ self.assertEqual(td.data[dict],
+ functools._find_impl(dict, g.registry))
self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list])
- self.assertEqual(td.data[list], g.dispatch(list))
+ self.assertEqual(td.data[list],
+ functools._find_impl(list, g.registry))
class X(object):
pass
c.MutableMapping.register(X) # Will not invalidate the cache,
@@ -341,6 +343,11 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(d), "sized")
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ g.dispatch(list)
+ g.dispatch(dict)
+ self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
+ list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
c.MutableSet.register(X) # Will invalidate the cache.
self.assertEqual(len(td), 2) # Stale cache.
self.assertEqual(g(l), "list")
@@ -354,6 +361,8 @@ class TestSingleDispatch(unittest.TestCase):
g.register(dict, lambda arg: "dict")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list")
+ g._clear_cache()
+ self.assertEqual(len(td), 0)
functools.WeakKeyDictionary = _orig_wkd