#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Source: root/singledispatch.py
# (blob e1009b99b6acc675bbc94f069d0b25e1d6f14a19)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

__all__ = ['singledispatch']

from functools import update_wrapper
from weakref import WeakKeyDictionary
from singledispatch_helpers import MappingProxyType, get_cache_token

################################################################################
### singledispatch() - single-dispatch generic function decorator
################################################################################

def _compose_mro(cls, haystack):
    """Return the MRO of `cls` as a list, extended with matching ABCs.

    Every class in `haystack` that `cls` is a (possibly virtual) subclass of,
    and that is not already part of ``cls.__mro__``, is spliced into a copy of
    the MRO. Classes from `haystack` are processed in iteration order, and
    each insertion is visible to the ones that follow.
    """
    linearization = cls.__mro__
    native = set(linearization)
    merged = list(linearization)
    for candidate in haystack:
        if candidate not in native and issubclass(cls, candidate):
            # Locate the first entry that does not satisfy `candidate`; when
            # every entry satisfies it, settle on the last position.
            slot = next(
                (position for position, entry in enumerate(merged)
                 if not issubclass(entry, candidate)),
                len(merged) - 1,
            )
            anchor = merged[slot]
            if anchor in native and not issubclass(candidate, anchor):
                # Conflict resolution: classes that are really present in
                # __mro__ (and their subclasses) take precedence, so the
                # candidate goes right after the anchor instead of before it.
                slot += 1
            merged.insert(slot, candidate)
        # Otherwise the candidate is already in __mro__ or is unrelated.
    return merged

def singledispatch(func):
    """Single-dispatch generic function decorator.

    Transforms a function into a generic function, which can have different
    behaviours depending upon the type of its first argument. The decorated
    function acts as the default implementation, and additional
    implementations can be registered using the 'register()' attribute of
    the generic function.

    The returned wrapper carries three extra attributes:

    * ``register(type, func=None)`` -- add an implementation for `type`;
    * ``dispatch(type)`` -- return the implementation selected for `type`;
    * ``registry`` -- a `MappingProxyType` wrapped around the
      type -> implementation mapping.

    """
    registry = {}
    # Memo of dispatch() results, filled in by wrapper(). Keys are held
    # weakly, so caching an implementation does not keep a class alive.
    dispatch_cache = WeakKeyDictionary()
    # A function object doubles as a mutable namespace, letting the closures
    # below rebind `cache_token` without the `nonlocal` statement.
    def ns(): pass
    # Stays None until a type exposing `__abstractmethods__` is registered;
    # from then on it holds the last seen value of get_cache_token().
    ns.cache_token = None

    def dispatch(cls):
        """generic_func.dispatch(type) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given `type` registered on `generic_func`.

        Raises RuntimeError when two unrelated registered ABCs match equally
        well.

        """
        if ns.cache_token is None:
            # No ABC has been registered: a plain walk of the real MRO is
            # enough.
            for t in cls.__mro__:
                if t in registry:
                    return registry[t]
            return func

        mro = _compose_mro(cls, registry.keys())
        match = None
        for t in mro:
            # Compare against None explicitly: a class may be falsy when its
            # metaclass defines __len__ or __bool__.
            if match is None:
                if t in registry:
                    match = t
                continue
            if (t in registry and not issubclass(match, t)
                              and match not in cls.__mro__):
                # `match` is an ABC but there is another unrelated, equally
                # matching ABC. Refuse the temptation to guess.
                raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
                    match, t))
            break
        if match is None:
            return func
        # Look the match up in the registry even when it is the last MRO
        # entry, so that an implementation re-registered for `object` is
        # honoured instead of silently falling back to the original `func`.
        return registry[match]

    def wrapper(*args, **kw):
        # Dispatches on the class of the first positional argument; calling
        # with no positional arguments raises IndexError from `args[0]`.
        if ns.cache_token is not None:
            # A changed token means the memoized dispatch results can no
            # longer be trusted, so drop them all.
            current_token = get_cache_token()
            if ns.cache_token != current_token:
                dispatch_cache.clear()
                ns.cache_token = current_token
        cls = args[0].__class__
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            impl = dispatch_cache[cls] = dispatch(cls)
        return impl(*args, **kw)

    def register(typ, func=None):
        """generic_func.register(type, func) -> func

        Registers a new overload for the given `type` on a `generic_func`.
        When `func` is omitted, returns a decorator that registers the
        function it is applied to.

        """
        if func is None:
            return lambda f: register(typ, f)
        registry[typ] = func
        if ns.cache_token is None and hasattr(typ, '__abstractmethods__'):
            # First registration of a type with `__abstractmethods__`:
            # switch to the ABC-aware dispatch algorithm from now on.
            ns.cache_token = get_cache_token()
        # Any new registration may change the best match for cached classes.
        dispatch_cache.clear()
        return func

    registry[object] = func
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = MappingProxyType(registry)
    update_wrapper(wrapper, func)
    return wrapper