summaryrefslogtreecommitdiff
path: root/singledispatch.py
diff options
context:
space:
mode:
author?ukasz Langa <lukasz@langa.pl>2013-05-31 12:59:38 +0200
committer?ukasz Langa <lukasz@langa.pl>2013-05-31 12:59:38 +0200
commit7608b45137e9b06d127c25550821f11a644479dd (patch)
tree98be66fef0526496152aca2cdebac701073d5856 /singledispatch.py
parentb0e6c5214c258c0e861e058df3c63411c8eae6d2 (diff)
parent070c478c0c768a1cf94a0964c594a03a4e520f8d (diff)
downloadsingledispatch-7608b45137e9b06d127c25550821f11a644479dd.tar.gz
Merged with upstream and made compatible with 2.6 - 3.3
Diffstat (limited to 'singledispatch.py')
-rw-r--r--singledispatch.py81
1 files changed, 47 insertions, 34 deletions
diff --git a/singledispatch.py b/singledispatch.py
index e1009b9..c9e2362 100644
--- a/singledispatch.py
+++ b/singledispatch.py
@@ -18,23 +18,51 @@ from singledispatch_helpers import MappingProxyType, get_cache_token
def _compose_mro(cls, haystack):
"""Calculates the MRO for a given class `cls`, including relevant abstract
- base classes from `haystack`."""
+ base classes from `haystack`.
+
+ """
bases = set(cls.__mro__)
mro = list(cls.__mro__)
- for regcls in haystack:
- if regcls in bases or not issubclass(cls, regcls):
+ for needle in haystack:
+ if (needle in bases or not hasattr(needle, '__mro__')
+ or not issubclass(cls, needle)):
continue # either present in the __mro__ already or unrelated
for index, base in enumerate(mro):
- if not issubclass(base, regcls):
+ if not issubclass(base, needle):
break
- if base in bases and not issubclass(regcls, base):
+ if base in bases and not issubclass(needle, base):
# Conflict resolution: put classes present in __mro__ and their
# subclasses first. See test_mro_conflicts() in test_functools.py
# for examples.
index += 1
- mro.insert(index, regcls)
+ mro.insert(index, needle)
return mro
+def _find_impl(cls, registry):
+ """Returns the best matching implementation for the given class `cls` in
+ `registry`. Where there is no registered implementation for a specific
+ type, its method resolution order is used to find a more generic
+ implementation.
+
+ Note: if `registry` does not contain an implementation for the base
+ `object` type, this function may return None.
+
+ """
+ mro = _compose_mro(cls, registry.keys())
+ match = None
+ for t in mro:
+ if match is not None:
+ # If `match` is an ABC but there is another unrelated, equally
+ # matching ABC. Refuse the temptation to guess.
+ if (t in registry and not issubclass(match, t)
+ and match not in cls.__mro__):
+ raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
+ match, t))
+ break
+ if t in registry:
+ match = t
+ return registry.get(match)
+
def singledispatch(func):
"""Single-dispatch generic function decorator.
@@ -50,7 +78,7 @@ def singledispatch(func):
def ns(): pass
ns.cache_token = None
- def dispatch(cls):
+ def dispatch(typ):
"""generic_func.dispatch(type) -> <function implementation>
Runs the dispatch algorithm to return the best available implementation
@@ -58,43 +86,24 @@ def singledispatch(func):
"""
if ns.cache_token is not None:
- mro = _compose_mro(cls, registry.keys())
- match = None
- for t in mro:
- if not match:
- if t in registry:
- match = t
- continue
- if (t in registry and not issubclass(match, t)
- and match not in cls.__mro__):
- # `match` is an ABC but there is another unrelated, equally
- # matching ABC. Refuse the temptation to guess.
- raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
- match, t))
- return registry[match]
- else:
- for t in cls.__mro__:
- if t in registry:
- return registry[t]
- return func
-
- def wrapper(*args, **kw):
- if ns.cache_token is not None:
current_token = get_cache_token()
if ns.cache_token != current_token:
dispatch_cache.clear()
ns.cache_token = current_token
- cls = args[0].__class__
try:
- impl = dispatch_cache[cls]
+ impl = dispatch_cache[typ]
except KeyError:
- impl = dispatch_cache[cls] = dispatch(cls)
- return impl(*args, **kw)
+ try:
+ impl = registry[typ]
+ except KeyError:
+ impl = _find_impl(typ, registry)
+ dispatch_cache[typ] = impl
+ return impl
def register(typ, func=None):
"""generic_func.register(type, func) -> func
- Registers a new overload for the given `type` on a `generic_func`.
+ Registers a new implementation for the given `type` on a `generic_func`.
"""
if func is None:
@@ -105,10 +114,14 @@ def singledispatch(func):
dispatch_cache.clear()
return func
+ def wrapper(*args, **kw):
+ return dispatch(args[0].__class__)(*args, **kw)
+
registry[object] = func
wrapper.register = register
wrapper.dispatch = dispatch
wrapper.registry = MappingProxyType(registry)
+ wrapper._clear_cache = dispatch_cache.clear
update_wrapper(wrapper, func)
return wrapper