summaryrefslogtreecommitdiff
path: root/bin/ci/gitlab_gql.py
blob: bd58f320b426c3490f42ff06deabb5feb4c48b09 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
#!/usr/bin/env python3

import re
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from dataclasses import dataclass, field
from os import getenv
from pathlib import Path
from typing import Any, Iterable, Optional, Pattern, Union

import yaml
from filecache import DAY, filecache
from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport
from graphql import DocumentNode

Dag = dict[str, list[str]]
TOKEN_DIR = Path(getenv("XDG_CONFIG_HOME") or Path.home() / ".config")


def get_token_from_default_dir() -> str:
    """Return the path of the default GitLab token file as a string.

    The file is expected at ``$XDG_CONFIG_HOME/gitlab-token`` (see
    ``TOKEN_DIR``).

    Raises:
        FileNotFoundError: if the token file does not exist.
    """
    token_file = TOKEN_DIR / "gitlab-token"
    try:
        # strict=True makes resolve() raise FileNotFoundError when the file
        # is missing.  The previous bare resolve() never raised, leaving the
        # except branch dead and returning a nonexistent path.  Also return
        # str to match the annotation (Path worked only incidentally).
        return str(token_file.resolve(strict=True))
    except FileNotFoundError as ex:
        print(
            f"Could not find {token_file}, please provide a token file as an argument"
        )
        raise ex


def get_project_root_dir():
    """Return the repository root (three levels above this script).

    The presence of ``.gitlab-ci.yml`` at the root is used as a sanity check.

    Raises:
        FileNotFoundError: if ``.gitlab-ci.yml`` is not found at the
            computed root.
    """
    root_path = Path(__file__).parent.parent.parent.resolve()
    gitlab_file = root_path / ".gitlab-ci.yml"
    # A plain `assert` disappears under `python -O`; raise explicitly so the
    # sanity check always runs.
    if not gitlab_file.exists():
        raise FileNotFoundError(f"{gitlab_file} not found; unexpected project layout")

    return root_path


@dataclass
class GitlabGQL:
    """Thin wrapper around a gql `Client` pointed at GitLab's GraphQL API.

    Results of `query` are cached on disk for one day via `filecache`.
    """

    _transport: Any = field(init=False)
    client: Client = field(init=False)
    url: str = "https://gitlab.freedesktop.org/api/graphql"
    token: Optional[str] = None

    def __post_init__(self):
        self._setup_gitlab_gql_client()

    def _setup_gitlab_gql_client(self) -> Client:
        # Attach the token as a bearer header only when one was given.
        auth_headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
        self._transport = AIOHTTPTransport(url=self.url, headers=auth_headers)

        # Build the GraphQL client on top of that transport.
        self.client = Client(
            transport=self._transport, fetch_schema_from_transport=True
        )

    @filecache(DAY)
    def query(
        self, gql_file: Union[Path, str], params: dict[str, Any]
    ) -> dict[str, Any]:
        """Run the GraphQL query stored next to this script in `gql_file`."""
        query_path = Path(__file__).parent / gql_file

        document: DocumentNode = gql(query_path.read_text())

        # Execute the query on the transport.
        return self.client.execute(document, variable_values=params)

    def invalidate_query_cache(self):
        # `filecache` exposes its persistent store as `_db` on the wrapper.
        self.query._db.clear()


def create_job_needs_dag(
    gl_gql: GitlabGQL, params
) -> tuple[Dag, dict[str, dict[str, Any]]]:
    """Build the job dependency DAG of a pipeline via GraphQL.

    Returns a tuple ``(dag, jobs)`` where ``dag`` maps each job name to the
    *transitive* set of job names it needs, and ``jobs`` maps job names to
    their remaining job data (with "needs" popped out).

    Raises:
        RuntimeError: when the query yields no pipeline for `params`.
    """
    result = gl_gql.query("pipeline_details.gql", params)
    pipeline = result["project"]["pipeline"]
    if not pipeline:
        raise RuntimeError(f"Could not find any pipelines for {params}")

    dag = {}
    jobs = {}
    for stage in pipeline["stages"]["nodes"]:
        for group in stage["groups"]["nodes"]:
            for job in group["jobs"]["nodes"]:
                direct_needs = {dep["name"] for dep in job.pop("needs")["nodes"]}
                jobs[job["name"]] = job
                dag[job["name"]] = direct_needs

    # Expand each job's direct needs to its transitive closure.  Updating
    # `dag` in place lets later jobs reuse already-expanded entries.
    for name in dag:
        closure = dag[name]
        while True:
            expanded = closure | {t for dep in closure for t in dag[dep]}
            if expanded == closure:
                break
            closure = expanded
        dag[name] = closure

    return dag, jobs


def filter_dag(dag: Dag, regex: Pattern) -> Dag:
    """Return only the DAG entries whose job name matches `regex`."""
    filtered = {}
    for name, requirements in dag.items():
        if re.match(regex, name):
            filtered[name] = requirements
    return filtered


def print_dag(dag: Dag) -> None:
    """Print each job name followed by a tab-indented line of its needs."""
    for name, requirements in dag.items():
        print(f"{name}:\n\t{' '.join(requirements)}\n")


def fetch_merged_yaml(gl_gql: GitlabGQL, params) -> dict[str, Any]:
    """Return the fully merged .gitlab-ci.yml configuration as a dict.

    Sends the local top-level ``.gitlab-ci.yml`` content to GitLab's
    ``ciConfig`` GraphQL endpoint and parses the merged YAML it returns.
    Note: mutates `params` by adding a "content" key.

    Raises:
        ValueError: (after invalidating the day-long query cache, so a
            later `git push` is picked up) when GitLab returns no merged
            YAML for the given SHA.
    """
    gitlab_yml_file = get_project_root_dir() / ".gitlab-ci.yml"
    # gitlab_yml_file is already a Path; the old Path() wrap was redundant.
    content = gitlab_yml_file.read_text().strip()
    params["content"] = content
    raw_response = gl_gql.query("job_details.gql", params)
    if merged_yaml := raw_response["ciConfig"]["mergedYaml"]:
        return yaml.safe_load(merged_yaml)

    # Do not cache a negative result; the SHA may exist remotely later.
    gl_gql.invalidate_query_cache()
    raise ValueError(
        """
    Could not fetch any content for merged YAML,
    please verify if the git SHA exists in remote.
    Maybe you forgot to `git push`?  """
    )


def recursive_fill(job, relationship_field, target_data, acc_data: dict, merged_yaml):
    """Accumulate a job's `target_data` dict across its inheritance chain.

    Walks `relationship_field` (e.g. "extends") recursively through
    `merged_yaml`, merging each ancestor's `target_data` (e.g. "variables")
    into `acc_data`.  Ancestors are merged first, so the nearest job's own
    entries win on key collisions.

    Returns the accumulated dict (also mutated in place).
    """
    if relatives := job.get(relationship_field):
        # GitLab allows both a single parent name and a list of them.
        if isinstance(relatives, str):
            relatives = [relatives]

        for relative in relatives:
            parent_job = merged_yaml[relative]
            # Bug fix: the recursive call previously dropped
            # relationship_field and target_data, passing acc_data in their
            # place, which raised TypeError on any job using inheritance.
            acc_data = recursive_fill(
                parent_job, relationship_field, target_data, acc_data, merged_yaml
            )

    acc_data |= job.get(target_data, {})

    return acc_data


def get_variables(job, merged_yaml, project_path, sha) -> dict[str, str]:
    """Resolve the effective variable environment for a job.

    Layers image-tag variables, pipeline-global variables and the job's own
    variables (later layers win), injects a few CI_* builtins, then expands
    ``$VAR``/``${VAR}`` references until a fixpoint is reached.
    """
    image_tags_path = get_project_root_dir() / ".gitlab-ci" / "image-tags.yml"
    image_tags = yaml.safe_load(image_tags_path.read_text())

    variables = image_tags["variables"]
    variables |= merged_yaml["variables"]
    variables |= job["variables"]

    # Builtins that GitLab would normally provide at runtime.
    variables["CI_PROJECT_PATH"] = project_path
    variables["CI_PROJECT_NAME"] = project_path.split("/")[1]
    variables["CI_REGISTRY_IMAGE"] = "registry.freedesktop.org/${CI_PROJECT_PATH}"
    variables["CI_COMMIT_SHA"] = sha

    # Keep substituting until a full pass makes no more changes.
    while recurse_among_variables_space(variables):
        pass

    return variables


# Adapted from: https://stackoverflow.com/a/2158532/1079223
def flatten(xs):
    """Yield every atom of an arbitrarily nested iterable.

    Strings and bytes are treated as atoms, not as nested iterables.
    """
    for item in xs:
        if isinstance(item, (str, bytes)) or not isinstance(item, Iterable):
            yield item
        else:
            yield from flatten(item)


def get_full_script(job) -> list[str]:
    """Concatenate before_script, script and after_script into one flat list.

    Each section is preceded by a `# <section>` marker line and followed by
    a blank line; missing sections contribute only their markers.
    """
    full_script: list[str] = []
    for section in ("before_script", "script", "after_script"):
        full_script.append(f"# {section}")
        full_script.extend(flatten(job.get(section, [])))
        full_script.append("")
    return full_script


def recurse_among_variables_space(var_graph) -> bool:
    """Perform one substitution pass over the variable graph, in place.

    For every value, replaces ``$VAR``/``${VAR}`` references with the
    current value of ``VAR`` when ``VAR`` exists in `var_graph`.

    Returns:
        True if any value changed, i.e. the caller should run another pass.
    """
    updated = False
    for var, value in var_graph.items():
        value = str(value)
        dep_vars = []
        if match := re.findall(r"(\$[{]?[\w\d_]*[}]?)", value):
            all_dep_vars = [v.lstrip("${").rstrip("}") for v in match]
            # Only variables we actually know about can be substituted.
            dep_vars = [v for v in all_dep_vars if v in var_graph]

        for dep_var in dep_vars:
            dep_value = str(var_graph[dep_var])
            old_value = var_graph[var]
            new_value = old_value.replace(f"${{{dep_var}}}", dep_value)
            new_value = new_value.replace(f"${dep_var}", dep_value)
            var_graph[var] = new_value
            # Bug fix: compare against the pre-substitution value.  The old
            # code compared against dep_value, so a full replacement (e.g.
            # "$B" -> "$C") looked like "no change" and chained references
            # were left unresolved after the caller's fixpoint loop stopped.
            updated |= new_value != old_value

    return updated


def get_job_final_definiton(job_name, merged_yaml, project_path, sha):
    """Print a job's resolved variables, full script and container image.

    NOTE: the misspelling "definiton" is kept — it is this function's
    public name.
    """
    job = merged_yaml[job_name]
    variables = get_variables(job, merged_yaml, project_path, sha)

    print("# --------- variables ---------------")
    for name, value in sorted(variables.items()):
        print(f"export {name}={value!r}")

    # TODO: Recurse into needs to get full script
    # TODO: maybe create a extra yaml file to avoid too much rework
    script_lines = get_full_script(job)
    print()
    print()
    print("# --------- full script ---------------")
    print("\n".join(script_lines))

    if image := variables.get("MESA_IMAGE"):
        print()
        print()
        print("# --------- container image ---------------")
        print(image)


def parse_args() -> Namespace:
    """Parse the CLI arguments for this GitLab GraphQL debugging tool.

    Besides the parsed flags, the returned namespace carries
    ``gitlab_token``: the contents of the token file given by
    ``--gitlab-token-file``.
    """
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        description="CLI and library with utility functions to debug jobs via Gitlab GraphQL",
        epilog=f"""Example:
        {Path(__file__).name} --rev $(git rev-parse HEAD) --print-job-dag""",
    )
    parser.add_argument("-pp", "--project-path", type=str, default="mesa/mesa")
    parser.add_argument("--sha", "--rev", type=str, required=True)
    parser.add_argument(
        "--regex",
        type=str,
        required=False,
        help="Regex pattern for the job name to be considered",
    )
    parser.add_argument("--print-dag", action="store_true", help="Print job needs DAG")
    parser.add_argument(
        "--print-merged-yaml",
        action="store_true",
        help="Print the resulting YAML for the specific SHA",
    )
    parser.add_argument(
        "--print-job-manifest", type=str, help="Print the resulting job data"
    )
    # NOTE(review): this default is computed eagerly when the parser is
    # built, so a missing default token file raises even when
    # --gitlab-token-file is passed explicitly — confirm this is intended.
    parser.add_argument(
        "--gitlab-token-file",
        type=str,
        default=get_token_from_default_dir(),
        help="force GitLab token, otherwise it's read from $XDG_CONFIG_HOME/gitlab-token",
    )

    args = parser.parse_args()
    # The token file's raw contents become the bearer token for GraphQL.
    args.gitlab_token = Path(args.gitlab_token_file).read_text()
    return args


def main():
    """CLI entry point: run whichever actions the parsed flags request."""
    args = parse_args()
    gl = GitlabGQL(token=args.gitlab_token)

    if args.print_dag:
        dag, _jobs = create_job_needs_dag(
            gl, {"projectPath": args.project_path, "sha": args.sha}
        )

        # Optionally narrow the DAG to jobs matching the given pattern.
        if args.regex:
            dag = filter_dag(dag, re.compile(args.regex))
        print_dag(dag)

    if args.print_merged_yaml:
        merged = fetch_merged_yaml(
            gl, {"projectPath": args.project_path, "sha": args.sha}
        )
        print(merged)

    if args.print_job_manifest:
        merged_yaml = fetch_merged_yaml(
            gl, {"projectPath": args.project_path, "sha": args.sha}
        )
        get_job_final_definiton(
            args.print_job_manifest, merged_yaml, args.project_path, args.sha
        )


# Run the CLI only when executed as a script, so the module can also be
# imported as a library.
if __name__ == "__main__":
    main()