From 070c478c0c768a1cf94a0964c594a03a4e520f8d Mon Sep 17 00:00:00 2001 From: ?ukasz Langa Date: Fri, 31 May 2013 12:48:16 +0200 Subject: update to the ref implementation as of 31-May-2013 * calls to `dispatch()` also cached * dispatching algorithm now now a module-level routine called `_find_impl()` with a simplified implementation and proper documentation * `dispatch()` now handles all caching-related activities * terminology more consisntent: "overload" -> "implementation" --- singledispatch.py | 82 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 47 insertions(+), 35 deletions(-) (limited to 'singledispatch.py') diff --git a/singledispatch.py b/singledispatch.py index 5093904..c68ad02 100644 --- a/singledispatch.py +++ b/singledispatch.py @@ -18,23 +18,51 @@ from singledispatch_helpers import MappingProxyType, get_cache_token def _compose_mro(cls, haystack): """Calculates the MRO for a given class `cls`, including relevant abstract - base classes from `haystack`.""" + base classes from `haystack`. + + """ bases = set(cls.__mro__) mro = list(cls.__mro__) - for regcls in haystack: - if regcls in bases or not issubclass(cls, regcls): + for needle in haystack: + if (needle in bases or not hasattr(needle, '__mro__') + or not issubclass(cls, needle)): continue # either present in the __mro__ already or unrelated for index, base in enumerate(mro): - if not issubclass(base, regcls): + if not issubclass(base, needle): break - if base in bases and not issubclass(regcls, base): + if base in bases and not issubclass(needle, base): # Conflict resolution: put classes present in __mro__ and their # subclasses first. See test_mro_conflicts() in test_functools.py # for examples. index += 1 - mro.insert(index, regcls) + mro.insert(index, needle) return mro +def _find_impl(cls, registry): + """Returns the best matching implementation for the given class `cls` in + `registry`. Where there is no registered implementation for a specific + type, its method resolution order is used to find a more generic + implementation. + + Note: if `registry` does not contain an implementation for the base + `object` type, this function may return None. + + """ + mro = _compose_mro(cls, registry.keys()) + match = None + for t in mro: + if match is not None: + # If `match` is an ABC but there is another unrelated, equally + # matching ABC. Refuse the temptation to guess. + if (t in registry and not issubclass(match, t) + and match not in cls.__mro__): + raise RuntimeError("Ambiguous dispatch: {} or {}".format( + match, t)) + break + if t in registry: + match = t + return registry.get(match) + def singledispatch(func): """Single-dispatch generic function decorator. @@ -49,52 +77,33 @@ def singledispatch(func): dispatch_cache = WeakKeyDictionary() cache_token = None - def dispatch(cls): + def dispatch(typ): """generic_func.dispatch(type) -> Runs the dispatch algorithm to return the best available implementation for the given `type` registered on `generic_func`. """ - if cache_token is not None: - mro = _compose_mro(cls, registry.keys()) - match = None - for t in mro: - if not match: - if t in registry: - match = t - continue - if (t in registry and not issubclass(match, t) - and match not in cls.__mro__): - # `match` is an ABC but there is another unrelated, equally - # matching ABC. Refuse the temptation to guess. - raise RuntimeError("Ambiguous dispatch: {} or {}".format( - match, t)) - return registry[match] - else: - for t in cls.__mro__: - if t in registry: - return registry[t] - return func - - def wrapper(*args, **kw): nonlocal cache_token if cache_token is not None: current_token = get_cache_token() if cache_token != current_token: dispatch_cache.clear() cache_token = current_token - cls = args[0].__class__ try: - impl = dispatch_cache[cls] + impl = dispatch_cache[typ] except KeyError: - impl = dispatch_cache[cls] = dispatch(cls) - return impl(*args, **kw) + try: + impl = registry[typ] + except KeyError: + impl = _find_impl(typ, registry) + dispatch_cache[typ] = impl + return impl def register(typ, func=None): """generic_func.register(type, func) -> func - Registers a new overload for the given `type` on a `generic_func`. + Registers a new implementation for the given `type` on a `generic_func`. """ nonlocal cache_token @@ -106,10 +115,13 @@ def singledispatch(func): dispatch_cache.clear() return func + def wrapper(*args, **kw): + return dispatch(args[0].__class__)(*args, **kw) + registry[object] = func wrapper.register = register wrapper.dispatch = dispatch wrapper.registry = MappingProxyType(registry) + wrapper._clear_cache = dispatch_cache.clear update_wrapper(wrapper, func) return wrapper - -- cgit v1.2.1