author     Jürg Billeter <j@bitron.ch>   2018-03-15 10:13:14 +0100
committer  Jürg Billeter <j@bitron.ch>   2018-05-31 10:34:26 +0200
commit     ae9aab02a9c2a469040bb8aaf4bc54924924eb0d
tree       a78f785636f1311fdc150bd6ea06935dff0a0202
parent     79cc4c6fc69afe2faf847c7330c360926be7c868
download   buildstream-ae9aab02a9c2a469040bb8aaf4bc54924924eb0d.tar.gz
_artifactcache/cascache.py: Add remote cache support
 buildstream/_artifactcache/artifactcache.py |  31
 buildstream/_artifactcache/cascache.py      | 369
 buildstream/_project.py                     |   2
 3 files changed, 391 insertions, 11 deletions
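The artifactcache.py hunk below widens the ArtifactCacheSpec namedtuple with three optional TLS fields and back-fills trailing defaults so existing two-field callers keep working. A minimal standalone sketch of that pattern, with a plain namedtuple standing in for the real class (the URL and certificate paths are made-up placeholders, not values from the commit):

from collections import namedtuple

# Sketch of the widened spec: three new optional fields for TLS material.
ArtifactCacheSpec = namedtuple('ArtifactCacheSpec',
                               'url push server_cert client_key client_cert')

# Trailing defaults make the three new fields optional, so code that still
# constructs a spec from only (url, push) keeps working unchanged.
ArtifactCacheSpec.__new__.__defaults__ = (None, None, None)

old_style = ArtifactCacheSpec('https://cache.example.com:11001', True)
new_style = ArtifactCacheSpec('https://cache.example.com:11001', True,
                              server_cert='/etc/certs/server.crt',
                              client_key='/etc/certs/client.key',
                              client_cert='/etc/certs/client.crt')

print(old_style.server_cert)   # None: filled in by the trailing defaults
print(new_style.client_cert)   # /etc/certs/client.crt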
diff --git a/buildstream/_artifactcache/artifactcache.py b/buildstream/_artifactcache/artifactcache.py
index 726091502..1a0d14f74 100644
--- a/buildstream/_artifactcache/artifactcache.py
+++ b/buildstream/_artifactcache/artifactcache.py
@@ -36,22 +36,38 @@ from .. import _yaml
 # push (bool): Whether we should attempt to push artifacts to this cache,
 #              in addition to pulling from it.
 #
-class ArtifactCacheSpec(namedtuple('ArtifactCacheSpec', 'url push')):
+class ArtifactCacheSpec(namedtuple('ArtifactCacheSpec', 'url push server_cert client_key client_cert')):

     # _new_from_config_node
     #
     # Creates an ArtifactCacheSpec() from a YAML loaded node
     #
     @staticmethod
-    def _new_from_config_node(spec_node):
-        _yaml.node_validate(spec_node, ['url', 'push'])
+    def _new_from_config_node(spec_node, basedir=None):
+        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert'])
         url = _yaml.node_get(spec_node, str, 'url')
         push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
         if not url:
             provenance = _yaml.node_get_provenance(spec_node)
             raise LoadError(LoadErrorReason.INVALID_DATA,
                             "{}: empty artifact cache URL".format(provenance))
-        return ArtifactCacheSpec(url, push)
+
+        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
+        if server_cert and basedir:
+            server_cert = os.path.join(basedir, server_cert)
+
+        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
+        if client_key and basedir:
+            client_key = os.path.join(basedir, client_key)
+
+        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
+        if client_cert and basedir:
+            client_cert = os.path.join(basedir, client_cert)
+
+        return ArtifactCacheSpec(url, push, server_cert, client_key, client_cert)
+
+
+ArtifactCacheSpec.__new__.__defaults__ = (None, None, None)


 # An ArtifactCache manages artifacts.
@@ -139,6 +155,7 @@ class ArtifactCache():
     #
     # Args:
     #     config_node (dict): The config block, which may contain the 'artifacts' key
+    #     basedir (str): The base directory for relative paths
     #
     # Returns:
     #     A list of ArtifactCacheSpec instances.
@@ -147,15 +164,15 @@ class ArtifactCache():
     #     LoadError, if the config block contains invalid keys.
     #
     @staticmethod
-    def specs_from_config_node(config_node):
+    def specs_from_config_node(config_node, basedir=None):
         cache_specs = []

         artifacts = config_node.get('artifacts', [])
         if isinstance(artifacts, Mapping):
-            cache_specs.append(ArtifactCacheSpec._new_from_config_node(artifacts))
+            cache_specs.append(ArtifactCacheSpec._new_from_config_node(artifacts, basedir))
         elif isinstance(artifacts, list):
             for spec_node in artifacts:
-                cache_specs.append(ArtifactCacheSpec._new_from_config_node(spec_node))
+                cache_specs.append(ArtifactCacheSpec._new_from_config_node(spec_node, basedir))
         else:
             provenance = _yaml.node_get_provenance(config_node, key='artifacts')
             raise _yaml.LoadError(_yaml.LoadErrorReason.INVALID_DATA,
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py
index 5ff045527..880d93ba4 100644
--- a/buildstream/_artifactcache/cascache.py
+++ b/buildstream/_artifactcache/cascache.py
@@ -19,13 +19,21 @@
 #        Jürg Billeter <juerg.billeter@codethink.co.uk>

 import hashlib
+import itertools
+import multiprocessing
 import os
+import signal
 import stat
 import tempfile
+from urllib.parse import urlparse

-from google.devtools.remoteexecution.v1test import remote_execution_pb2
+import grpc

-from .. import utils
+from google.bytestream import bytestream_pb2, bytestream_pb2_grpc
+from google.devtools.remoteexecution.v1test import remote_execution_pb2, remote_execution_pb2_grpc
+from buildstream import buildstream_pb2, buildstream_pb2_grpc
+
+from .. import _signals, utils
 from .._exceptions import ArtifactError

 from . import ArtifactCache
@@ -36,15 +44,28 @@ from . import ArtifactCache
 #
 # Args:
 #     context (Context): The BuildStream context
+#     enable_push (bool): Whether pushing is allowed by the platform
+#
+# Pushing is explicitly disabled by the platform in some cases,
+# like when we are falling back to functioning without using
+# user namespaces.
 #
 class CASCache(ArtifactCache):

-    def __init__(self, context):
+    def __init__(self, context, *, enable_push=True):
         super().__init__(context)

         self.casdir = os.path.join(context.artifactdir, 'cas')
         os.makedirs(os.path.join(self.casdir, 'tmp'), exist_ok=True)

+        self._enable_push = enable_push
+
+        # Per-project list of _CASRemote instances.
+        self._remotes = {}
+
+        self._has_fetch_remotes = False
+        self._has_push_remotes = False
+
     ################################################
     #     Implementation of abstract methods      #
     ################################################
@@ -115,6 +136,205 @@ class CASCache(ArtifactCache):

         return modified, removed, added

+    def initialize_remotes(self, *, on_failure=None):
+        remote_specs = self.global_remote_specs
+
+        for project in self.project_remote_specs:
+            remote_specs += self.project_remote_specs[project]
+
+        remote_specs = list(utils._deduplicate(remote_specs))
+
+        remotes = {}
+        q = multiprocessing.Queue()
+        for remote_spec in remote_specs:
+            # Use subprocess to avoid creation of gRPC threads in main BuildStream process
+            p = multiprocessing.Process(target=self._initialize_remote, args=(remote_spec, q))
+
+            try:
+                # Keep SIGINT blocked in the child process
+                with _signals.blocked([signal.SIGINT], ignore=False):
+                    p.start()
+
+                error = q.get()
+                p.join()
+            except KeyboardInterrupt:
+                utils._kill_process_tree(p.pid)
+                raise
+
+            if error and on_failure:
+                on_failure(remote_spec.url, error)
+            elif error:
+                raise ArtifactError(error)
+            else:
+                self._has_fetch_remotes = True
+                if remote_spec.push:
+                    self._has_push_remotes = True
+
+                remotes[remote_spec.url] = _CASRemote(remote_spec)
+
+        for project in self.context.get_projects():
+            remote_specs = self.global_remote_specs
+            if project in self.project_remote_specs:
+                remote_specs = list(utils._deduplicate(remote_specs + self.project_remote_specs[project]))
+
+            project_remotes = []
+
+            for remote_spec in remote_specs:
+                # Errors are already handled in the loop above,
+                # skip unreachable remotes here.
+                if remote_spec.url not in remotes:
+                    continue
+
+                remote = remotes[remote_spec.url]
+                project_remotes.append(remote)
+
+            self._remotes[project] = project_remotes
+
+    def has_fetch_remotes(self, *, element=None):
+        if not self._has_fetch_remotes:
+            # No project has fetch remotes
+            return False
+        elif element is None:
+            # At least one (sub)project has fetch remotes
+            return True
+        else:
+            # Check whether the specified element's project has fetch remotes
+            remotes_for_project = self._remotes[element._get_project()]
+            return bool(remotes_for_project)
+
+    def has_push_remotes(self, *, element=None):
+        if not self._has_push_remotes or not self._enable_push:
+            # No project has push remotes
+            return False
+        elif element is None:
+            # At least one (sub)project has push remotes
+            return True
+        else:
+            # Check whether the specified element's project has push remotes
+            remotes_for_project = self._remotes[element._get_project()]
+            return any(remote.spec.push for remote in remotes_for_project)
+
+    def pull(self, element, key, *, progress=None):
+        ref = self.get_artifact_fullname(element, key)
+
+        project = element._get_project()
+
+        for remote in self._remotes[project]:
+            try:
+                remote.init()
+
+                request = buildstream_pb2.GetArtifactRequest()
+                request.key = ref
+                response = remote.artifact_cache.GetArtifact(request)
+
+                tree = remote_execution_pb2.Digest()
+                tree.hash = response.artifact.hash
+                tree.size_bytes = response.artifact.size_bytes
+
+                self._fetch_directory(remote, tree)
+
+                self.set_ref(ref, tree)
+
+                # no need to pull from additional remotes
+                return True
+
+            except grpc.RpcError as e:
+                if e.code() != grpc.StatusCode.NOT_FOUND:
+                    raise
+
+        return False
+
+    def link_key(self, element, oldkey, newkey):
+        oldref = self.get_artifact_fullname(element, oldkey)
+        newref = self.get_artifact_fullname(element, newkey)
+
+        tree = self.resolve_ref(oldref)
+
+        self.set_ref(newref, tree)
+
+    def push(self, element, keys):
+        refs = [self.get_artifact_fullname(element, key) for key in keys]
+
+        project = element._get_project()
+
+        push_remotes = [r for r in self._remotes[project] if r.spec.push]
+
+        pushed = False
+
+        for remote in push_remotes:
+            remote.init()
+
+            for ref in refs:
+                tree = self.resolve_ref(ref)
+
+                # Check whether ref is already on the server in which case
+                # there is no need to push the artifact
+                try:
+                    request = buildstream_pb2.GetArtifactRequest()
+                    request.key = ref
+                    response = remote.artifact_cache.GetArtifact(request)
+
+                    if response.artifact.hash == tree.hash and response.artifact.size_bytes == tree.size_bytes:
+                        # ref is already on the server with the same tree
+                        continue
+
+                except grpc.RpcError as e:
+                    if e.code() != grpc.StatusCode.NOT_FOUND:
+                        raise
+
+                missing_blobs = {}
+                required_blobs = self._required_blobs(tree)
+
+                # Limit size of FindMissingBlobs request
+                for required_blobs_group in _grouper(required_blobs, 512):
+                    request = remote_execution_pb2.FindMissingBlobsRequest()
+
+                    for required_digest in required_blobs_group:
+                        d = request.blob_digests.add()
+                        d.hash = required_digest.hash
+                        d.size_bytes = required_digest.size_bytes
+
+                    response = remote.cas.FindMissingBlobs(request)
+                    for digest in response.missing_blob_digests:
+                        d = remote_execution_pb2.Digest()
+                        d.hash = digest.hash
+                        d.size_bytes = digest.size_bytes
+                        missing_blobs[d.hash] = d
+
+                # Upload any blobs missing on the server
+                for digest in missing_blobs.values():
+                    def request_stream():
+                        resource_name = os.path.join(digest.hash, str(digest.size_bytes))
+                        with open(self.objpath(digest), 'rb') as f:
+                            assert os.fstat(f.fileno()).st_size == digest.size_bytes
+                            offset = 0
+                            finished = False
+                            remaining = digest.size_bytes
+                            while not finished:
+                                chunk_size = min(remaining, 64 * 1024)
+                                remaining -= chunk_size
+
+                                request = bytestream_pb2.WriteRequest()
+                                request.write_offset = offset
+                                # max. 64 kB chunks
+                                request.data = f.read(chunk_size)
+                                request.resource_name = resource_name
+                                request.finish_write = remaining <= 0
+                                yield request
+                                offset += chunk_size
+                                finished = request.finish_write
+
+                    response = remote.bytestream.Write(request_stream())
+
+                request = buildstream_pb2.UpdateArtifactRequest()
+                request.keys.append(ref)
+                request.artifact.hash = tree.hash
+                request.artifact.size_bytes = tree.size_bytes
+                remote.artifact_cache.UpdateArtifact(request)
+
+                pushed = True
+
+        return pushed
+
     ################################################
     #             API Private Methods             #
     ################################################
@@ -344,3 +564,146 @@ class CASCache(ArtifactCache):
                                  path=os.path.join(path, dir_a.directories[a].name))
                 a += 1
                 b += 1
+
+    def _initialize_remote(self, remote_spec, q):
+        try:
+            remote = _CASRemote(remote_spec)
+            remote.init()
+
+            request = buildstream_pb2.StatusRequest()
+            response = remote.artifact_cache.Status(request)
+
+            if remote_spec.push and not response.allow_updates:
+                q.put('Artifact server does not allow push')
+            else:
+                # No error
+                q.put(None)
+
+        except Exception as e:               # pylint: disable=broad-except
+            # Whatever happens, we need to return it to the calling process
+            #
+            q.put(str(e))
+
+    def _required_blobs(self, tree):
+        # parse directory, and recursively add blobs
+        d = remote_execution_pb2.Digest()
+        d.hash = tree.hash
+        d.size_bytes = tree.size_bytes
+        yield d
+
+        directory = remote_execution_pb2.Directory()
+
+        with open(self.objpath(tree), 'rb') as f:
+            directory.ParseFromString(f.read())
+
+        for filenode in directory.files:
+            d = remote_execution_pb2.Digest()
+            d.hash = filenode.digest.hash
+            d.size_bytes = filenode.digest.size_bytes
+            yield d
+
+        for dirnode in directory.directories:
+            yield from self._required_blobs(dirnode.digest)
+
+    def _fetch_blob(self, remote, digest, out):
+        resource_name = os.path.join(digest.hash, str(digest.size_bytes))
+        request = bytestream_pb2.ReadRequest()
+        request.resource_name = resource_name
+        request.read_offset = 0
+        for response in remote.bytestream.Read(request):
+            out.write(response.data)
+
+        out.flush()
+        assert digest.size_bytes == os.fstat(out.fileno()).st_size
+
+    def _fetch_directory(self, remote, tree):
+        objpath = self.objpath(tree)
+        if os.path.exists(objpath):
+            # already in local cache
+            return
+
+        with tempfile.NamedTemporaryFile(dir=os.path.join(self.casdir, 'tmp')) as out:
+            self._fetch_blob(remote, tree, out)
+
+            directory = remote_execution_pb2.Directory()
+
+            with open(out.name, 'rb') as f:
+                directory.ParseFromString(f.read())
+
+            for filenode in directory.files:
+                fileobjpath = self.objpath(tree)
+                if os.path.exists(fileobjpath):
+                    # already in local cache
+                    continue
+
+                with tempfile.NamedTemporaryFile(dir=os.path.join(self.casdir, 'tmp')) as f:
+                    self._fetch_blob(remote, filenode.digest, f)
+
+                    digest = self.add_object(path=f.name)
+                    assert digest.hash == filenode.digest.hash
+
+            for dirnode in directory.directories:
+                self._fetch_directory(remote, dirnode.digest)
+
+            # place directory blob only in final location when we've downloaded
+            # all referenced blobs to avoid dangling references in the repository
+            digest = self.add_object(path=out.name)
+            assert digest.hash == tree.hash
+
+
+# Represents a single remote CAS cache.
+#
+class _CASRemote():
+    def __init__(self, spec):
+        self.spec = spec
+        self._initialized = False
+        self.channel = None
+        self.bytestream = None
+        self.cas = None
+        self.artifact_cache = None
+
+    def init(self):
+        if not self._initialized:
+            url = urlparse(self.spec.url)
+            if url.scheme == 'http':
+                port = url.port or 80
+                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
+            elif url.scheme == 'https':
+                port = url.port or 443
+
+                if self.spec.server_cert:
+                    with open(self.spec.server_cert, 'rb') as f:
+                        server_cert_bytes = f.read()
+                else:
+                    server_cert_bytes = None
+
+                if self.spec.client_key:
+                    with open(self.spec.client_key, 'rb') as f:
+                        client_key_bytes = f.read()
+                else:
+                    client_key_bytes = None
+
+                if self.spec.client_cert:
+                    with open(self.spec.client_cert, 'rb') as f:
+                        client_cert_bytes = f.read()
+                else:
+                    client_cert_bytes = None
+
+                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
+                                                           private_key=client_key_bytes,
+                                                           certificate_chain=client_cert_bytes)
+                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
+            else:
+                raise ArtifactError("Unsupported URL: {}".format(self.spec.url))
+
+            self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
+            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+            self.artifact_cache = buildstream_pb2_grpc.ArtifactCacheStub(self.channel)
+
+            self._initialized = True
+
+
+def _grouper(iterable, n):
+    # pylint: disable=stop-iteration-return
+    while True:
+        yield itertools.chain([next(iterable)], itertools.islice(iterable, n - 1))
diff --git a/buildstream/_project.py b/buildstream/_project.py
index 5344e954e..87f14ee0d 100644
--- a/buildstream/_project.py
+++ b/buildstream/_project.py
@@ -296,7 +296,7 @@ class Project():
         #

         # Load artifacts pull/push configuration for this project
-        self.artifact_cache_specs = ArtifactCache.specs_from_config_node(config)
+        self.artifact_cache_specs = ArtifactCache.specs_from_config_node(config, self.directory)

         # Workspace configurations
         self.workspaces = Workspaces(self)
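A few notes on the new cascache.py code, each with a small, hedged sketch. First, initialize_remotes() probes every remote in a short-lived child process and reports back through a multiprocessing.Queue, keeping gRPC threads out of the main BuildStream process. The sketch below mirrors that shape; check_remote() is a hypothetical stand-in for _initialize_remote() and its gRPC status check, and the URLs are made up:

import multiprocessing

# Probe a remote in a child process and hand back either None (success)
# or an error string through the queue, like _initialize_remote() does.
def check_remote(url, queue):
    try:
        if not url.startswith(('http://', 'https://')):
            raise ValueError("Unsupported URL: {}".format(url))
        queue.put(None)                # no error
    except Exception as e:             # pylint: disable=broad-except
        queue.put(str(e))

if __name__ == '__main__':
    q = multiprocessing.Queue()
    for url in ['https://cache.example.com:11001', 'ftp://nope']:
        p = multiprocessing.Process(target=check_remote, args=(url, q))
        p.start()
        error = q.get()
        p.join()
        print(url, '->', error or 'ok')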
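Second, request_stream() in push() streams each missing blob as a sequence of ByteStream WriteRequest messages of at most 64 KiB, marking the last one with finish_write. The standalone sketch below follows the same loop but yields plain (offset, data, finish_write) tuples instead of protobuf messages, so it runs without grpc installed; the payload size is arbitrary:

import io

# Read a file-like object in 64 KiB pieces, reporting the write offset for
# each piece and flagging the final one, the way the Write RPC expects.
def chunked_writes(f, size_bytes, chunk_size=64 * 1024):
    offset = 0
    remaining = size_bytes
    finished = False
    while not finished:
        n = min(remaining, chunk_size)
        remaining -= n
        data = f.read(n)
        finished = remaining <= 0      # finish_write on the last chunk
        yield (offset, data, finished)
        offset += n

payload = b'x' * (150 * 1024)          # hypothetical 150 KiB blob
for offset, data, finish in chunked_writes(io.BytesIO(payload), len(payload)):
    print(offset, len(data), finish)   # 0/65536/False, 65536/65536/False, 131072/22528/True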
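Finally, _grouper() batches the digest iterator into groups of at most 512 for each FindMissingBlobs request, relying on a bare next() inside the generator (hence the pylint stop-iteration-return suppression); under PEP 479, the default from Python 3.7 on, that escaping StopIteration becomes a RuntimeError. A sketch of an equivalent helper that handles exhaustion explicitly, not the commit's code:

import itertools

# Variant of the _grouper() idea that returns instead of letting
# StopIteration escape, so it also behaves on Python 3.7+ (PEP 479).
# It expects an iterator and yields lazy groups of at most n items;
# each group must be consumed before advancing to the next one.
def grouper(iterable, n):
    while True:
        try:
            first = next(iterable)
        except StopIteration:
            return
        yield itertools.chain([first], itertools.islice(iterable, n - 1))

for group in grouper(iter(range(10)), 4):
    print(list(group))   # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]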