summaryrefslogtreecommitdiff
path: root/src/setuptools_scm/_entrypoints.py
blob: 62a18e13a9dd8036fa157b9461241af01c8290b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from __future__ import annotations

from typing import Any
from typing import Callable
from typing import cast
from typing import Iterator
from typing import overload
from typing import TYPE_CHECKING

from typing_extensions import Protocol

from . import _log
from . import version

if TYPE_CHECKING:
    from ._config import Configuration
    from . import _types as _t


# Module logger, a child of the package logger defined in ``_log``; the
# entry-point lookup functions in this module emit their debug messages here.
log = _log.log.getChild("entrypoints")


class EntrypointProtocol(Protocol):
    """Structural type for the parts of an entry point this module relies on.

    Only the ``name`` attribute and the ``load()`` method are used by the
    helpers in this module; any object providing both satisfies the protocol.
    """

    name: str

    def load(self) -> Any:
        """Load and return the object the entry point refers to."""


def _version_from_entrypoints(
    config: Configuration, fallback: bool = False
) -> version.ScmVersion | None:
    """Return the first version produced by a matching parse entry point.

    The entry point group is ``setuptools_scm.parse_scm`` searched from
    ``config.absolute_root``, or ``setuptools_scm.parse_scm_fallback``
    searched from ``config.fallback_root`` when *fallback* is true.

    Each matching entry point is loaded and called as
    ``fn(root, config=config)``; the first result that is not ``None`` is
    returned. ``None`` is returned when no entry point yields a version.
    """
    entrypoint, root = (
        ("setuptools_scm.parse_scm_fallback", config.fallback_root)
        if fallback
        else ("setuptools_scm.parse_scm", config.absolute_root)
    )

    # imported lazily at call time (presumably to avoid an import cycle with
    # ``.discover`` — NOTE(review): confirm)
    from .discover import iter_matching_entrypoints

    log.debug("version_from_ep %s in %s", entrypoint, root)
    for candidate in iter_matching_entrypoints(root, entrypoint, config):
        parse = candidate.load()
        found: version.ScmVersion | None = parse(root, config=config)
        log.debug("%s found %r", candidate, found)
        if found is not None:
            return found
    return None


try:
    from importlib_metadata import entry_points
    from importlib_metadata import EntryPoint
except ImportError:
    from importlib.metadata import entry_points  # type: ignore [no-redef, import]
    from importlib.metadata import EntryPoint  # type: ignore [no-redef]


def iter_entry_points(
    group: str, name: str | None = None
) -> Iterator[EntrypointProtocol]:
    """Iterate over the installed entry points registered under *group*.

    :param group: entry point group to look up.
    :param name: when given, only entry points with this name are produced.
    :returns: an iterator over the selected entry points.
    """
    selected = entry_points(group=group)
    if name is not None:
        # narrow the group's entry points down to the requested name
        selected = selected.select(  # type: ignore [no-untyped-call]
            name=name,
        )
    return cast(Iterator[EntrypointProtocol], iter(selected))


def _get_ep(group: str, name: str) -> Any | None:
    """Load the first entry point called *name* in *group*.

    :returns: the loaded object, or ``None`` when no such entry point exists.
    """
    for candidate in iter_entry_points(group, name):
        log.debug("ep found: %s", candidate.name)
        # only the first match is ever loaded
        return candidate.load()
    return None


def _get_from_object_reference_str(path: str, group: str) -> Any | None:
    # todo: remove for importlib native spelling
    ep: EntrypointProtocol = EntryPoint(path, path, group)
    try:
        return ep.load()
    except (AttributeError, ModuleNotFoundError):
        return None


def _iter_version_schemes(
    entrypoint: str,
    scheme_value: _t.VERSION_SCHEMES,
    _memo: set[object] | None = None,
) -> Iterator[Callable[[version.ScmVersion], str]]:
    if _memo is None:
        _memo = set()
    if isinstance(scheme_value, str):
        scheme_value = cast(
            "_t.VERSION_SCHEMES",
            _get_ep(entrypoint, scheme_value)
            or _get_from_object_reference_str(scheme_value, entrypoint),
        )

    if isinstance(scheme_value, (list, tuple)):
        for variant in scheme_value:
            if variant not in _memo:
                _memo.add(variant)
                yield from _iter_version_schemes(entrypoint, variant, _memo=_memo)
    elif callable(scheme_value):
        yield scheme_value


@overload
def _call_version_scheme(
    version: version.ScmVersion,
    entrypoint: str,
    given_value: _t.VERSION_SCHEMES,
    default: str,
) -> str:
    ...


@overload
def _call_version_scheme(
    version: version.ScmVersion,
    entrypoint: str,
    given_value: _t.VERSION_SCHEMES,
    default: None,
) -> str | None:
    ...


def _call_version_scheme(
    version: version.ScmVersion,
    entrypoint: str,
    given_value: _t.VERSION_SCHEMES,
    default: str | None,
) -> str | None:
    """Apply the version schemes described by *given_value* to *version*.

    :param version: the SCM version passed to each scheme.
    :param entrypoint: entry point group used to resolve scheme names.
    :param given_value: a scheme callable, a scheme name, or a list/tuple of
        these, as accepted by ``_iter_version_schemes``.
    :param default: value returned when no scheme produces a result.
    :returns: the first non-``None`` result of calling a scheme, else *default*.

    Note: the parameter ``version`` shadows the imported ``version`` module
    inside this body; only the parameter is used here.
    """
    for scheme in _iter_version_schemes(entrypoint, given_value):
        result = scheme(version)
        # schemes are tried in order; the first non-None result wins
        if result is not None:
            return result
    return default