summaryrefslogtreecommitdiff
path: root/src/buildstream/_includes.py
blob: bc0d7718b5c35c9a0382e235ca952083b1f976ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from . import _yaml
from .node import MappingNode, ScalarNode, SequenceNode
from ._exceptions import LoadError, LoadErrorReason


# Includes()
#
# This takes care of processing include directives "(@)".
#
# Args:
#    loader (Loader): The Loader object
#    copy_tree (bool): Whether to make a copy of the tree in the
#                      provenance. Should be True if it is intended
#                      to be serialized.
class Includes:
    def __init__(self, loader, *, copy_tree=False):
        self._loader = loader
        # Cache of loaded include files, keyed by (loader, file path), so
        # that _include_file() loads each file at most once per loader.
        self._loaded = {}
        self._copy_tree = copy_tree

    # process()
    #
    # Process recursively include directives in a YAML node.
    #
    # If the node has a "(@)" key, that key is removed, every file it lists
    # is loaded and processed recursively, and the result is composited
    # under the node with _composite_under(). Afterwards every value of the
    # node is processed in turn so that nested directives are handled too.
    #
    # Args:
    #    node: A YAML mapping node
    #    included (set): Fail for recursion if trying to load any files in this set.
    #                    When omitted, a fresh empty set is used for this call.
    #    current_loader (Loader): Use alternative loader (for junction files).
    #                             Defaults to the loader given at construction.
    #    only_local (bool): Whether to ignore junction files
    #
    # Raises:
    #    LoadError: With reason MISSING_FILE or LOADING_DIRECTORY when an
    #               included path cannot be loaded for that reason, with
    #               reason RECURSIVE_INCLUDE when a file includes itself
    #               (directly or indirectly). Any other LoadError raised
    #               while loading is propagated unchanged.
    def process(self, node, *, included=None, current_loader=None, only_local=False):
        # Never use a mutable default here: a default set would be shared
        # between every call and every Includes instance.
        if included is None:
            included = set()

        if current_loader is None:
            current_loader = self._loader

        includes_node = node.get_node("(@)", allowed_types=[ScalarNode, SequenceNode], allow_none=True)

        if includes_node:
            # The directive is either a single file name or a list of them
            if type(includes_node) is ScalarNode:  # pylint: disable=unidiomatic-typecheck
                includes = [includes_node.as_str()]
            else:
                includes = includes_node.as_str_list()

            del node["(@)"]

            for include in reversed(includes):
                # A ":" marks a "junction:path" include
                if only_local and ":" in include:
                    continue
                try:
                    include_node, file_path, sub_loader = self._include_file(include, current_loader)
                except LoadError as e:
                    include_provenance = includes_node.get_provenance()
                    if e.reason == LoadErrorReason.MISSING_FILE:
                        message = "{}: Include block references a file that could not be found: '{}'.".format(
                            include_provenance, include
                        )
                        raise LoadError(message, LoadErrorReason.MISSING_FILE) from e
                    if e.reason == LoadErrorReason.LOADING_DIRECTORY:
                        message = "{}: Include block references a directory instead of a file: '{}'.".format(
                            include_provenance, include
                        )
                        raise LoadError(message, LoadErrorReason.LOADING_DIRECTORY) from e

                    # Otherwise, we don't know the reason, so just raise
                    raise

                if file_path in included:
                    include_provenance = includes_node.get_provenance()
                    raise LoadError(
                        "{}: trying to recursively include {}".format(include_provenance, file_path),
                        LoadErrorReason.RECURSIVE_INCLUDE,
                    )
                # Because the included node will be modified, we need
                # to copy it so that we do not modify the toplevel
                # node of the provenance (nor the node cached in
                # self._loaded, which is what _include_file() returns).
                include_node = include_node.clone()

                # Track the file only while it is being processed, so that
                # the same file may still be included again from a sibling.
                try:
                    included.add(file_path)
                    self.process(include_node, included=included, current_loader=sub_loader, only_local=only_local)
                finally:
                    included.remove(file_path)

                include_node._composite_under(node)

        for value in node.values():
            self._process_value(value, included=included, current_loader=current_loader, only_local=only_local)

    # _include_file()
    #
    # Load an include YAML file with a loader.
    #
    # Loaded files are cached in self._loaded, so the node returned here is
    # shared between calls and must be cloned before being modified.
    #
    # Args:
    #    include (str): file path relative to loader's project directory.
    #                   Can be prefixed with junction name ("junction:path").
    #    loader (Loader): Loader for the current project.
    #
    # Returns:
    #    (tuple): The loaded node, the full path of the loaded file, and the
    #             loader the file was resolved with (the junction's loader
    #             when a junction prefix was given).
    def _include_file(self, include, loader):
        shortname = include
        if ":" in include:
            # Only split on the first ":" so the remainder is kept as the path
            junction, include = include.split(":", 1)
            current_loader = loader._get_loader(junction)
        else:
            current_loader = loader
        project = current_loader.project
        file_path = os.path.join(project.directory, include)
        key = (current_loader, file_path)
        if key not in self._loaded:
            self._loaded[key] = _yaml.load(file_path, shortname=shortname, project=project, copy_tree=self._copy_tree)
        return self._loaded[key], file_path, current_loader

    # _process_value()
    #
    # Select processing for value that could be a list or a dictionary.
    #
    # Mapping nodes are handed to process(), sequence nodes are walked
    # element by element, and any other value is left untouched.
    #
    # Args:
    #    value: Value to process. Can be a list or a dictionary.
    #    included (set): Fail for recursion if trying to load any files in this set.
    #                    When omitted, a fresh empty set is used for this call.
    #    current_loader (Loader): Use alternative loader (for junction files)
    #    only_local (bool): Whether to ignore junction files
    def _process_value(self, value, *, included=None, current_loader=None, only_local=False):
        # See process(): avoid sharing a default set between calls
        if included is None:
            included = set()

        value_type = type(value)

        if value_type is MappingNode:
            self.process(value, included=included, current_loader=current_loader, only_local=only_local)
        elif value_type is SequenceNode:
            for v in value:
                self._process_value(v, included=included, current_loader=current_loader, only_local=only_local)