summaryrefslogtreecommitdiff
path: root/src/setuptools_scm/_entrypoints.py
blob: 9b5b09355c32efa4424db62eb5a60668cbda473b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations

import warnings
from typing import Any
from typing import Callable
from typing import cast
from typing import Iterator
from typing import overload
from typing import TYPE_CHECKING

from . import version
from ._trace import trace

# Names used only in annotations. Because of ``from __future__ import
# annotations`` at the top of this module, annotations are not evaluated at
# runtime, so the real imports are only performed for static type checkers.
# Note that ``_t`` is therefore NOT bound at runtime: any runtime use must be
# inside an annotation or a string (see the ``cast("_t.VERSION_SCHEMES", ...)``
# call below).
if TYPE_CHECKING:
    from ._config import Configuration
    from typing_extensions import Protocol
    from . import _types as _t
else:
    # Runtime placeholder; only meaningful inside (unevaluated) annotations.
    Configuration = Any

    # Minimal runtime stand-in so the name ``Protocol`` exists without
    # importing typing_extensions; presumably for subclassing elsewhere in the
    # package — NOTE(review): confirm where this is used.
    class Protocol:
        pass


def _version_from_entrypoints(
    config: Configuration, fallback: bool = False
) -> version.ScmVersion | None:
    """Ask the registered SCM parser entry points for the project version.

    :param config: the configuration; its ``absolute_root`` (or
        ``fallback_root`` when *fallback* is true) is passed to each parser
        as the root directory, together with ``config=config``.
    :param fallback: when true, consult the
        ``setuptools_scm.parse_scm_fallback`` group rooted at
        ``config.fallback_root`` instead of ``setuptools_scm.parse_scm``
        rooted at ``config.absolute_root``.
    :returns: the first non-``None`` result produced by a matching entry
        point, or ``None`` if none of them yields a version.
    """
    if fallback:
        entrypoint = "setuptools_scm.parse_scm_fallback"
        root = config.fallback_root
    else:
        entrypoint = "setuptools_scm.parse_scm"
        root = config.absolute_root

    # Local import, presumably to avoid an import cycle with ``.discover``.
    from .discover import iter_matching_entrypoints

    trace("version_from_ep", entrypoint, root)
    for ep in iter_matching_entrypoints(root, entrypoint, config):
        fn = ep.load()
        maybe_version: version.ScmVersion | None = fn(root, config=config)
        # Trace what this parser actually returned (not the ``version`` module).
        trace(ep, maybe_version)
        if maybe_version is not None:
            return maybe_version
    return None


# Resolve the entry-point API, in order of preference: the stdlib
# ``importlib.metadata``, then the ``importlib_metadata`` backport, and finally
# a stub that warns and reports no entry points at all.
try:
    from importlib.metadata import entry_points  # type: ignore
    from importlib.metadata import EntryPoint
except ImportError:
    try:
        from importlib_metadata import entry_points
        from importlib_metadata import EntryPoint
    except ImportError:
        from collections import defaultdict

        def entry_points() -> dict[str, list[_t.EntrypointProtocol]]:
            """Fallback used when no metadata API can be imported.

            Emits a warning and returns an empty ``defaultdict(list)``, so
            indexing any group yields an empty list instead of raising
            ``KeyError``.
            """
            warnings.warn(
                "importlib metadata missing, "
                "this may happen at build time for python3.7"
            )
            return defaultdict(list)

        class EntryPoint:  # type: ignore
            """Placeholder accepting any constructor arguments.

            It defines no ``load`` method, so ``EntryPoint(...).load()`` raises
            ``AttributeError`` (which ``_get_from_object_reference_str``
            catches and turns into ``None``).
            """

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                pass  # entry_points() already provides the warning
                pass  # entry_points() already provides the warning


def iter_entry_points(
    group: str, name: str | None = None
) -> Iterator[_t.EntrypointProtocol]:
    """Iterate over the entry points registered in *group*.

    :param group: entry-point group name, e.g. ``"setuptools_scm.parse_scm"``.
    :param name: when given, only entry points with exactly this name are
        yielded.
    :returns: an iterator over the matching entry points; it is empty when
        the group has no registrations.
    """
    all_eps = entry_points()
    if hasattr(all_eps, "select"):
        # Selectable API (Python 3.10+ importlib.metadata / newer backport).
        eps = all_eps.select(group=group)
    else:
        # Older APIs return a plain mapping of group -> entry points, which
        # has no key for groups nobody registered; indexing would raise
        # KeyError, so fall back to an empty sequence instead.
        eps = all_eps.get(group, ())
    if name is None:
        return iter(eps)
    return (ep for ep in eps if ep.name == name)


def _get_ep(group: str, name: str) -> Any | None:
    """Load the first entry point called *name* in *group*.

    ``iter_entry_points`` is defined in this very module, so it is used
    directly rather than re-imported from ``._entrypoints``.

    :returns: the loaded object, or ``None`` when no such entry point exists.
    """
    for ep in iter_entry_points(group, name):
        trace("ep found:", ep.name)
        return ep.load()
    return None


def _get_from_object_reference_str(path: str) -> Any | None:
    try:
        return EntryPoint(path, path, None).load()
    except (AttributeError, ModuleNotFoundError):
        return None


def _iter_version_schemes(
    entrypoint: str,
    scheme_value: _t.VERSION_SCHEMES,
    _memo: set[object] | None = None,
) -> Iterator[Callable[[version.ScmVersion], str]]:
    if _memo is None:
        _memo = set()
    if isinstance(scheme_value, str):
        scheme_value = cast(
            "_t.VERSION_SCHEMES",
            _get_ep(entrypoint, scheme_value)
            or _get_from_object_reference_str(scheme_value),
        )

    if isinstance(scheme_value, (list, tuple)):
        for variant in scheme_value:
            if variant not in _memo:
                _memo.add(variant)
                yield from _iter_version_schemes(entrypoint, variant, _memo=_memo)
    elif callable(scheme_value):
        yield scheme_value


@overload
def _call_version_scheme(
    version: version.ScmVersion, entypoint: str, given_value: str, default: str
) -> str:
    ...


@overload
def _call_version_scheme(
    version: version.ScmVersion, entypoint: str, given_value: str, default: None
) -> str | None:
    ...


def _call_version_scheme(
    version: version.ScmVersion, entypoint: str, given_value: str, default: str | None
) -> str | None:
    """Return the first non-``None`` result among the configured schemes.

    The schemes described by *given_value* (names are resolved in the
    *entypoint* entry-point group) are called with *version* one at a time,
    lazily; evaluation stops at the first one returning something other
    than ``None``.

    :param version: the SCM version handed to every scheme.
    :param entypoint: entry-point group used to resolve scheme names. The
        misspelling is kept so keyword callers keep working.
    :param given_value: scheme name(s) or callable(s) to try, in order.
    :param default: returned when no scheme produces a result.
    """
    candidates = (
        scheme(version) for scheme in _iter_version_schemes(entypoint, given_value)
    )
    return next((found for found in candidates if found is not None), default)