diff options
-rw-r--r-- | buildstream/_cas/casserver.py | 83 |
1 files changed, 82 insertions, 1 deletions
diff --git a/buildstream/_cas/casserver.py b/buildstream/_cas/casserver.py index 8f8ef4efe..f88db717a 100644 --- a/buildstream/_cas/casserver.py +++ b/buildstream/_cas/casserver.py @@ -28,12 +28,14 @@ import errno import threading import grpc +from google.protobuf.message import DecodeError import click from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc -from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc from .._protos.google.rpc import code_pb2 +from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, \ + artifact_pb2, artifact_pb2_grpc from .._exceptions import CASError @@ -62,6 +64,7 @@ def create_server(repo, *, enable_push, max_head_size=int(10e9), min_head_size=int(2e9)): cas = CASCache(os.path.abspath(repo)) + artifactdir = os.path.join(os.path.abspath(repo), 'artifacts', 'refs') # Use max_workers default from Python 3.5+ max_workers = (os.cpu_count() or 1) * 5 @@ -81,6 +84,9 @@ def create_server(repo, *, enable_push, buildstream_pb2_grpc.add_ReferenceStorageServicer_to_server( _ReferenceStorageServicer(cas, enable_push=enable_push), server) + artifact_pb2_grpc.add_ArtifactServiceServicer_to_server( + _ArtifactServicer(cas, artifactdir), server) + return server @@ -405,6 +411,81 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer): return response +class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer): + + def __init__(self, cas, artifactdir): + super().__init__() + self.cas = cas + self.artifactdir = artifactdir + os.makedirs(artifactdir, exist_ok=True) + + def GetArtifact(self, request, context): + artifact_path = os.path.join(self.artifactdir, request.cache_key) + if not os.path.exists(artifact_path): + context.abort(grpc.StatusCode.NOT_FOUND, "Artifact proto not found") + + artifact = artifact_pb2.Artifact() + with open(artifact_path, 'rb') as f: + artifact.ParseFromString(f.read()) + + files_digest = artifact.files + + # Now update mtimes of files present. + try: + self.cas.update_tree_mtime(files_digest) + except FileNotFoundError: + os.unlink(artifact_path) + context.abort(grpc.StatusCode.NOT_FOUND, + "Artifact files incomplete") + except DecodeError: + context.abort(grpc.StatusCode.NOT_FOUND, + "Artifact files not valid") + + return artifact + + def UpdateArtifact(self, request, context): + artifact = request.artifact + + # Check that the files specified are in the CAS + self._check_directory("files", artifact.files, context) + + # Unset protocol buffers don't evaluated to False but do return empty + # strings, hence str() + if str(artifact.buildtree): + self._check_directory("buildtree", artifact.buildtree, context) + + if str(artifact.public_data): + self._check_file("public data", artifact.public_data, context) + + for log_file in artifact.logs: + self._check_file("log digest", log_file.digest, context) + + # Add the artifact proto to the cas + artifact_path = os.path.join(self.artifactdir, request.cache_key) + os.makedirs(os.path.dirname(artifact_path), exist_ok=True) + with open(artifact_path, 'wb') as f: + f.write(artifact.SerializeToString()) + + return artifact + + def _check_directory(self, name, digest, context): + try: + directory = remote_execution_pb2.Directory() + with open(self.cas.objpath(digest), 'rb') as f: + directory.ParseFromString(f.read()) + except FileNotFoundError: + context.abort(grpc.StatusCode.FAILED_PRECONDITION, + "Artifact {} specified but no files found".format(name)) + except DecodeError: + context.abort(grpc.StatusCode.FAILED_PRECONDITION, + "Artifact {} specified but directory not found".format(name)) + + def _check_file(self, name, digest, context): + if not os.path.exists(self.cas.objpath(digest)): + context.abort(grpc.StatusCode.FAILED_PRECONDITION, + "Artifact {} specified but not found".format(name)) + + def _digest_from_download_resource_name(resource_name): parts = resource_name.split('/') |