summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author?ukasz Langa <lukasz@langa.pl>2013-05-31 12:48:16 +0200
committer?ukasz Langa <lukasz@langa.pl>2013-05-31 12:48:16 +0200
commit070c478c0c768a1cf94a0964c594a03a4e520f8d (patch)
treecfaaf52e8908f37cd7ed82c064ad16de7bc2ee6e
parent166a442bd7fc983ecb5cb49aa68dce2a17855d1b (diff)
downloadsingledispatch-070c478c0c768a1cf94a0964c594a03a4e520f8d.tar.gz
update to the ref implementation as of 31-May-2013
* calls to `dispatch()` also cached * dispatching algorithm now now a module-level routine called `_find_impl()` with a simplified implementation and proper documentation * `dispatch()` now handles all caching-related activities * terminology more consisntent: "overload" -> "implementation"
-rw-r--r--README.rst28
-rw-r--r--singledispatch.py82
-rw-r--r--test_singledispatch.py13
3 files changed, 76 insertions, 47 deletions
diff --git a/README.rst b/README.rst
index 802ab5e..e1e345c 100644
--- a/README.rst
+++ b/README.rst
@@ -21,8 +21,9 @@ argument, create your function accordingly::
... print(arg)
To add overloaded implementations to the function, use the
-``register()`` attribute of the generic function. It takes a type
-parameter::
+``register()`` attribute of the generic function. It is a decorator,
+taking a type parameter and decorating a function implementing the
+operation for that type::
>>> @fun.register(int)
... def _(arg, verbose=False):
@@ -59,7 +60,8 @@ each variant independently::
>>> fun_num is fun
False
-When called, the generic function dispatches on the first argument::
+When called, the generic function dispatches on the type of the first
+argument::
>>> fun("Hello, world.")
Hello, world.
@@ -78,15 +80,21 @@ When called, the generic function dispatches on the first argument::
>>> fun(1.23)
0.615
-To get the implementation for a specific type, use the ``dispatch()``
-attribute::
+Where there is no registered implementation for a specific type, its
+method resolution order is used to find a more generic implementation.
+The original function decorated with ``@singledispatch`` is registered
+for the base ``object`` type, which means it is used if no better
+implementation is found.
+
+To check which implementation will the generic function choose for
+a given type, use the ``dispatch()`` attribute::
>>> fun.dispatch(float)
- <function fun_num at 0x104319058>
- >>> fun.dispatch(dict)
- <function fun at 0x103fe4788>
+ <function fun_num at 0x1035a2840>
+ >>> fun.dispatch(dict) # note: default implementation
+ <function fun at 0x103fe0000>
-To access all registered overloads, use the read-only ``registry``
+To access all registered implementations, use the read-only ``registry``
attribute::
>>> fun.registry.keys()
@@ -96,7 +104,7 @@ attribute::
>>> fun.registry[float]
<function fun_num at 0x1035a2840>
>>> fun.registry[object]
- <function fun at 0x103170788>
+ <function fun at 0x103fe0000>
The vanilla documentation is available at
http://docs.python.org/3/library/functools.html#functools.singledispatch.
diff --git a/singledispatch.py b/singledispatch.py
index 5093904..c68ad02 100644
--- a/singledispatch.py
+++ b/singledispatch.py
@@ -18,23 +18,51 @@ from singledispatch_helpers import MappingProxyType, get_cache_token
def _compose_mro(cls, haystack):
"""Calculates the MRO for a given class `cls`, including relevant abstract
- base classes from `haystack`."""
+ base classes from `haystack`.
+
+ """
bases = set(cls.__mro__)
mro = list(cls.__mro__)
- for regcls in haystack:
- if regcls in bases or not issubclass(cls, regcls):
+ for needle in haystack:
+ if (needle in bases or not hasattr(needle, '__mro__')
+ or not issubclass(cls, needle)):
continue # either present in the __mro__ already or unrelated
for index, base in enumerate(mro):
- if not issubclass(base, regcls):
+ if not issubclass(base, needle):
break
- if base in bases and not issubclass(regcls, base):
+ if base in bases and not issubclass(needle, base):
# Conflict resolution: put classes present in __mro__ and their
# subclasses first. See test_mro_conflicts() in test_functools.py
# for examples.
index += 1
- mro.insert(index, regcls)
+ mro.insert(index, needle)
return mro
+def _find_impl(cls, registry):
+ """Returns the best matching implementation for the given class `cls` in
+ `registry`. Where there is no registered implementation for a specific
+ type, its method resolution order is used to find a more generic
+ implementation.
+
+ Note: if `registry` does not contain an implementation for the base
+ `object` type, this function may return None.
+
+ """
+ mro = _compose_mro(cls, registry.keys())
+ match = None
+ for t in mro:
+ if match is not None:
+ # If `match` is an ABC but there is another unrelated, equally
+ # matching ABC. Refuse the temptation to guess.
+ if (t in registry and not issubclass(match, t)
+ and match not in cls.__mro__):
+ raise RuntimeError("Ambiguous dispatch: {} or {}".format(
+ match, t))
+ break
+ if t in registry:
+ match = t
+ return registry.get(match)
+
def singledispatch(func):
"""Single-dispatch generic function decorator.
@@ -49,52 +77,33 @@ def singledispatch(func):
dispatch_cache = WeakKeyDictionary()
cache_token = None
- def dispatch(cls):
+ def dispatch(typ):
"""generic_func.dispatch(type) -> <function implementation>
Runs the dispatch algorithm to return the best available implementation
for the given `type` registered on `generic_func`.
"""
- if cache_token is not None:
- mro = _compose_mro(cls, registry.keys())
- match = None
- for t in mro:
- if not match:
- if t in registry:
- match = t
- continue
- if (t in registry and not issubclass(match, t)
- and match not in cls.__mro__):
- # `match` is an ABC but there is another unrelated, equally
- # matching ABC. Refuse the temptation to guess.
- raise RuntimeError("Ambiguous dispatch: {} or {}".format(
- match, t))
- return registry[match]
- else:
- for t in cls.__mro__:
- if t in registry:
- return registry[t]
- return func
-
- def wrapper(*args, **kw):
nonlocal cache_token
if cache_token is not None:
current_token = get_cache_token()
if cache_token != current_token:
dispatch_cache.clear()
cache_token = current_token
- cls = args[0].__class__
try:
- impl = dispatch_cache[cls]
+ impl = dispatch_cache[typ]
except KeyError:
- impl = dispatch_cache[cls] = dispatch(cls)
- return impl(*args, **kw)
+ try:
+ impl = registry[typ]
+ except KeyError:
+ impl = _find_impl(typ, registry)
+ dispatch_cache[typ] = impl
+ return impl
def register(typ, func=None):
"""generic_func.register(type, func) -> func
- Registers a new overload for the given `type` on a `generic_func`.
+ Registers a new implementation for the given `type` on a `generic_func`.
"""
nonlocal cache_token
@@ -106,10 +115,13 @@ def singledispatch(func):
dispatch_cache.clear()
return func
+ def wrapper(*args, **kw):
+ return dispatch(args[0].__class__)(*args, **kw)
+
registry[object] = func
wrapper.register = register
wrapper.dispatch = dispatch
wrapper.registry = MappingProxyType(registry)
+ wrapper._clear_cache = dispatch_cache.clear
update_wrapper(wrapper, func)
return wrapper
-
diff --git a/test_singledispatch.py b/test_singledispatch.py
index 0234de3..3f243a9 100644
--- a/test_singledispatch.py
+++ b/test_singledispatch.py
@@ -307,12 +307,14 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict])
- self.assertEqual(td.data[dict], g.dispatch(dict))
+ self.assertEqual(td.data[dict],
+ functools._find_impl(dict, g.registry))
self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list])
- self.assertEqual(td.data[list], g.dispatch(list))
+ self.assertEqual(td.data[list],
+ functools._find_impl(list, g.registry))
class X:
pass
c.MutableMapping.register(X) # Will not invalidate the cache,
@@ -335,6 +337,11 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(d), "sized")
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ g.dispatch(list)
+ g.dispatch(dict)
+ self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
+ list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
c.MutableSet.register(X) # Will invalidate the cache.
self.assertEqual(len(td), 2) # Stale cache.
self.assertEqual(g(l), "list")
@@ -348,6 +355,8 @@ class TestSingleDispatch(unittest.TestCase):
g.register(dict, lambda arg: "dict")
self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list")
+ g._clear_cache()
+ self.assertEqual(len(td), 0)
functools.WeakKeyDictionary = _orig_wkd