diff options
author | ?ukasz Langa <lukasz@langa.pl> | 2013-05-31 12:48:16 +0200 |
---|---|---|
committer | ?ukasz Langa <lukasz@langa.pl> | 2013-05-31 12:48:16 +0200 |
commit | 070c478c0c768a1cf94a0964c594a03a4e520f8d (patch) | |
tree | cfaaf52e8908f37cd7ed82c064ad16de7bc2ee6e | |
parent | 166a442bd7fc983ecb5cb49aa68dce2a17855d1b (diff) | |
download | singledispatch-070c478c0c768a1cf94a0964c594a03a4e520f8d.tar.gz |
update to the ref implementation as of 31-May-2013
* calls to `dispatch()` also cached
* dispatching algorithm now now a module-level routine called `_find_impl()`
with a simplified implementation and proper documentation
* `dispatch()` now handles all caching-related activities
* terminology more consisntent: "overload" -> "implementation"
-rw-r--r-- | README.rst | 28 | ||||
-rw-r--r-- | singledispatch.py | 82 | ||||
-rw-r--r-- | test_singledispatch.py | 13 |
3 files changed, 76 insertions, 47 deletions
@@ -21,8 +21,9 @@ argument, create your function accordingly:: ... print(arg) To add overloaded implementations to the function, use the -``register()`` attribute of the generic function. It takes a type -parameter:: +``register()`` attribute of the generic function. It is a decorator, +taking a type parameter and decorating a function implementing the +operation for that type:: >>> @fun.register(int) ... def _(arg, verbose=False): @@ -59,7 +60,8 @@ each variant independently:: >>> fun_num is fun False -When called, the generic function dispatches on the first argument:: +When called, the generic function dispatches on the type of the first +argument:: >>> fun("Hello, world.") Hello, world. @@ -78,15 +80,21 @@ When called, the generic function dispatches on the first argument:: >>> fun(1.23) 0.615 -To get the implementation for a specific type, use the ``dispatch()`` -attribute:: +Where there is no registered implementation for a specific type, its +method resolution order is used to find a more generic implementation. +The original function decorated with ``@singledispatch`` is registered +for the base ``object`` type, which means it is used if no better +implementation is found. + +To check which implementation will the generic function choose for +a given type, use the ``dispatch()`` attribute:: >>> fun.dispatch(float) - <function fun_num at 0x104319058> - >>> fun.dispatch(dict) - <function fun at 0x103fe4788> + <function fun_num at 0x1035a2840> + >>> fun.dispatch(dict) # note: default implementation + <function fun at 0x103fe0000> -To access all registered overloads, use the read-only ``registry`` +To access all registered implementations, use the read-only ``registry`` attribute:: >>> fun.registry.keys() @@ -96,7 +104,7 @@ attribute:: >>> fun.registry[float] <function fun_num at 0x1035a2840> >>> fun.registry[object] - <function fun at 0x103170788> + <function fun at 0x103fe0000> The vanilla documentation is available at http://docs.python.org/3/library/functools.html#functools.singledispatch. diff --git a/singledispatch.py b/singledispatch.py index 5093904..c68ad02 100644 --- a/singledispatch.py +++ b/singledispatch.py @@ -18,23 +18,51 @@ from singledispatch_helpers import MappingProxyType, get_cache_token def _compose_mro(cls, haystack): """Calculates the MRO for a given class `cls`, including relevant abstract - base classes from `haystack`.""" + base classes from `haystack`. + + """ bases = set(cls.__mro__) mro = list(cls.__mro__) - for regcls in haystack: - if regcls in bases or not issubclass(cls, regcls): + for needle in haystack: + if (needle in bases or not hasattr(needle, '__mro__') + or not issubclass(cls, needle)): continue # either present in the __mro__ already or unrelated for index, base in enumerate(mro): - if not issubclass(base, regcls): + if not issubclass(base, needle): break - if base in bases and not issubclass(regcls, base): + if base in bases and not issubclass(needle, base): # Conflict resolution: put classes present in __mro__ and their # subclasses first. See test_mro_conflicts() in test_functools.py # for examples. index += 1 - mro.insert(index, regcls) + mro.insert(index, needle) return mro +def _find_impl(cls, registry): + """Returns the best matching implementation for the given class `cls` in + `registry`. Where there is no registered implementation for a specific + type, its method resolution order is used to find a more generic + implementation. + + Note: if `registry` does not contain an implementation for the base + `object` type, this function may return None. + + """ + mro = _compose_mro(cls, registry.keys()) + match = None + for t in mro: + if match is not None: + # If `match` is an ABC but there is another unrelated, equally + # matching ABC. Refuse the temptation to guess. + if (t in registry and not issubclass(match, t) + and match not in cls.__mro__): + raise RuntimeError("Ambiguous dispatch: {} or {}".format( + match, t)) + break + if t in registry: + match = t + return registry.get(match) + def singledispatch(func): """Single-dispatch generic function decorator. @@ -49,52 +77,33 @@ def singledispatch(func): dispatch_cache = WeakKeyDictionary() cache_token = None - def dispatch(cls): + def dispatch(typ): """generic_func.dispatch(type) -> <function implementation> Runs the dispatch algorithm to return the best available implementation for the given `type` registered on `generic_func`. """ - if cache_token is not None: - mro = _compose_mro(cls, registry.keys()) - match = None - for t in mro: - if not match: - if t in registry: - match = t - continue - if (t in registry and not issubclass(match, t) - and match not in cls.__mro__): - # `match` is an ABC but there is another unrelated, equally - # matching ABC. Refuse the temptation to guess. - raise RuntimeError("Ambiguous dispatch: {} or {}".format( - match, t)) - return registry[match] - else: - for t in cls.__mro__: - if t in registry: - return registry[t] - return func - - def wrapper(*args, **kw): nonlocal cache_token if cache_token is not None: current_token = get_cache_token() if cache_token != current_token: dispatch_cache.clear() cache_token = current_token - cls = args[0].__class__ try: - impl = dispatch_cache[cls] + impl = dispatch_cache[typ] except KeyError: - impl = dispatch_cache[cls] = dispatch(cls) - return impl(*args, **kw) + try: + impl = registry[typ] + except KeyError: + impl = _find_impl(typ, registry) + dispatch_cache[typ] = impl + return impl def register(typ, func=None): """generic_func.register(type, func) -> func - Registers a new overload for the given `type` on a `generic_func`. + Registers a new implementation for the given `type` on a `generic_func`. """ nonlocal cache_token @@ -106,10 +115,13 @@ def singledispatch(func): dispatch_cache.clear() return func + def wrapper(*args, **kw): + return dispatch(args[0].__class__)(*args, **kw) + registry[object] = func wrapper.register = register wrapper.dispatch = dispatch wrapper.registry = MappingProxyType(registry) + wrapper._clear_cache = dispatch_cache.clear update_wrapper(wrapper, func) return wrapper - diff --git a/test_singledispatch.py b/test_singledispatch.py index 0234de3..3f243a9 100644 --- a/test_singledispatch.py +++ b/test_singledispatch.py @@ -307,12 +307,14 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(len(td), 1) self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.set_ops, [dict, list, dict]) - self.assertEqual(td.data[dict], g.dispatch(dict)) + self.assertEqual(td.data[dict], + functools._find_impl(dict, g.registry)) self.assertEqual(g(l), "list") self.assertEqual(len(td), 2) self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.set_ops, [dict, list, dict, list]) - self.assertEqual(td.data[list], g.dispatch(list)) + self.assertEqual(td.data[list], + functools._find_impl(list, g.registry)) class X: pass c.MutableMapping.register(X) # Will not invalidate the cache, @@ -335,6 +337,11 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(d), "sized") self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + g.dispatch(list) + g.dispatch(dict) + self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, + list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) c.MutableSet.register(X) # Will invalidate the cache. self.assertEqual(len(td), 2) # Stale cache. self.assertEqual(g(l), "list") @@ -348,6 +355,8 @@ class TestSingleDispatch(unittest.TestCase): g.register(dict, lambda arg: "dict") self.assertEqual(g(d), "dict") self.assertEqual(g(l), "list") + g._clear_cache() + self.assertEqual(len(td), 0) functools.WeakKeyDictionary = _orig_wkd |