diff options
Diffstat (limited to 'src/buildstream/_cas/casserver.py')
-rw-r--r-- | src/buildstream/_cas/casserver.py | 268 |
1 files changed, 50 insertions, 218 deletions
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py index 71d7d9071..013fb07dd 100644 --- a/src/buildstream/_cas/casserver.py +++ b/src/buildstream/_cas/casserver.py @@ -1,5 +1,6 @@ # # Copyright (C) 2018 Codethink Limited +# Copyright (C) 2020 Bloomberg Finance LP # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -26,22 +27,17 @@ import signal import sys import grpc -from google.protobuf.message import DecodeError import click +from .._protos.build.bazel.remote.asset.v1 import remote_asset_pb2_grpc from .._protos.build.bazel.remote.execution.v2 import ( remote_execution_pb2, remote_execution_pb2_grpc, ) from .._protos.google.bytestream import bytestream_pb2_grpc -from .._protos.build.buildgrid import local_cas_pb2 from .._protos.buildstream.v2 import ( buildstream_pb2, buildstream_pb2_grpc, - artifact_pb2, - artifact_pb2_grpc, - source_pb2, - source_pb2_grpc, ) # Note: We'd ideally like to avoid imports from the core codebase as @@ -115,7 +111,6 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.Le try: root = os.path.abspath(repo) - sourcedir = os.path.join(root, "source_protos") # Use max_workers default from Python 3.5+ max_workers = (os.cpu_count() or 1) * 5 @@ -132,23 +127,16 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.Le remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(_CapabilitiesServicer(), server) + # Remote Asset API + remote_asset_pb2_grpc.add_FetchServicer_to_server(_FetchServicer(casd_channel), server) + if enable_push: + remote_asset_pb2_grpc.add_PushServicer_to_server(_PushServicer(casd_channel), server) + + # BuildStream protocols buildstream_pb2_grpc.add_ReferenceStorageServicer_to_server( _ReferenceStorageServicer(casd_channel, root, enable_push=enable_push), server ) - artifact_pb2_grpc.add_ArtifactServiceServicer_to_server( - _ArtifactServicer(casd_channel, root, update_cas=not index_only), server - ) - - source_pb2_grpc.add_SourceServiceServicer_to_server(_SourceServicer(sourcedir), server) - - # Create up reference storage and artifact capabilities - artifact_capabilities = buildstream_pb2.ArtifactCapabilities(allow_updates=enable_push) - source_capabilities = buildstream_pb2.SourceCapabilities(allow_updates=enable_push) - buildstream_pb2_grpc.add_CapabilitiesServicer_to_server( - _BuildStreamCapabilitiesServicer(artifact_capabilities, source_capabilities), server - ) - yield server finally: @@ -295,6 +283,48 @@ class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer): return response +class _FetchServicer(remote_asset_pb2_grpc.FetchServicer): + def __init__(self, casd): + super().__init__() + self.fetch = casd.get_asset_fetch() + self.logger = logging.getLogger("buildstream._cas.casserver") + + def FetchBlob(self, request, context): + self.logger.debug("FetchBlob '%s'", request.uris) + try: + return self.fetch.FetchBlob(request) + except grpc.RpcError as err: + context.abort(err.code(), err.details()) + + def FetchDirectory(self, request, context): + self.logger.debug("FetchDirectory '%s'", request.uris) + try: + return self.fetch.FetchDirectory(request) + except grpc.RpcError as err: + context.abort(err.code(), err.details()) + + +class _PushServicer(remote_asset_pb2_grpc.PushServicer): + def __init__(self, casd): + super().__init__() + self.push = casd.get_asset_push() + self.logger = logging.getLogger("buildstream._cas.casserver") + + def PushBlob(self, request, context): + self.logger.debug("PushBlob '%s'", request.uris) + try: + return self.push.PushBlob(request) + except grpc.RpcError as err: + context.abort(err.code(), err.details()) + + def PushDirectory(self, request, context): + self.logger.debug("PushDirectory '%s'", request.uris) + try: + return self.push.PushDirectory(request) + except grpc.RpcError as err: + context.abort(err.code(), err.details()) + + class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer): def __init__(self, casd, cas_root, *, enable_push): super().__init__() @@ -393,201 +423,3 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer): response.allow_updates = self.enable_push return response - - -class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer): - def __init__(self, casd, root, *, update_cas=True): - super().__init__() - self.cas = casd.get_cas() - self.local_cas = casd.get_local_cas() - self.root = root - self.artifactdir = os.path.join(root, "artifacts", "refs") - self.update_cas = update_cas - self.logger = logging.getLogger("buildstream._cas.casserver") - - # object_path(): - # - # Get the path to an object's file. - # - # Args: - # digest - The digest of the object. - # - # Returns: - # str - The path to the object's file. - # - def object_path(self, digest) -> str: - return os.path.join(self.root, "cas", "objects", digest.hash[:2], digest.hash[2:]) - - # resolve_digest(): - # - # Read the directory corresponding to a digest. - # - # Args: - # digest - The digest corresponding to a directory. - # - # Returns: - # remote_execution_pb2.Directory - The directory. - # - # Raises: - # FileNotFoundError - If the digest object doesn't exist. - def resolve_digest(self, digest): - directory = remote_execution_pb2.Directory() - with open(self.object_path(digest), "rb") as f: - directory.ParseFromString(f.read()) - return directory - - def GetArtifact(self, request, context): - self.logger.info("'%s'", request.cache_key) - artifact_path = os.path.join(self.artifactdir, request.cache_key) - if not os.path.exists(artifact_path): - context.abort(grpc.StatusCode.NOT_FOUND, "Artifact proto not found") - - artifact = artifact_pb2.Artifact() - with open(artifact_path, "rb") as f: - artifact.ParseFromString(f.read()) - - os.utime(artifact_path) - - # Artifact-only servers will not have blobs on their system, - # so we can't reasonably update their mtimes. Instead, we exit - # early, and let the CAS server deal with its blobs. - # - # FIXME: We could try to run FindMissingBlobs on the other - # server. This is tricky to do from here, of course, - # because we don't know who the other server is, but - # the client could be smart about it - but this might - # make things slower. - # - # It needs some more thought... - if not self.update_cas: - return artifact - - # Now update mtimes of files present. - try: - - if str(artifact.files): - request = local_cas_pb2.FetchTreeRequest() - request.root_digest.CopyFrom(artifact.files) - request.fetch_file_blobs = True - self.local_cas.FetchTree(request) - - if str(artifact.buildtree): - try: - request = local_cas_pb2.FetchTreeRequest() - request.root_digest.CopyFrom(artifact.buildtree) - request.fetch_file_blobs = True - self.local_cas.FetchTree(request) - except grpc.RpcError as err: - # buildtrees might not be there - if err.code() != grpc.StatusCode.NOT_FOUND: - raise - - if str(artifact.public_data): - request = remote_execution_pb2.FindMissingBlobsRequest() - d = request.blob_digests.add() - d.CopyFrom(artifact.public_data) - self.cas.FindMissingBlobs(request) - - request = remote_execution_pb2.FindMissingBlobsRequest() - for log_file in artifact.logs: - d = request.blob_digests.add() - d.CopyFrom(log_file.digest) - self.cas.FindMissingBlobs(request) - - except grpc.RpcError as err: - if err.code() == grpc.StatusCode.NOT_FOUND: - os.unlink(artifact_path) - context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files incomplete") - else: - context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files not valid") - - return artifact - - def UpdateArtifact(self, request, context): - self.logger.info("'%s' -> '%s'", request.artifact, request.cache_key) - artifact = request.artifact - - if self.update_cas: - # Check that the files specified are in the CAS - if str(artifact.files): - self._check_directory("files", artifact.files, context) - - # Unset protocol buffers don't evaluated to False but do return empty - # strings, hence str() - if str(artifact.public_data): - self._check_file("public data", artifact.public_data, context) - - for log_file in artifact.logs: - self._check_file("log digest", log_file.digest, context) - - # Add the artifact proto to the cas - artifact_path = os.path.join(self.artifactdir, request.cache_key) - os.makedirs(os.path.dirname(artifact_path), exist_ok=True) - with save_file_atomic(artifact_path, mode="wb") as f: - f.write(artifact.SerializeToString()) - - return artifact - - def _check_directory(self, name, digest, context): - try: - self.resolve_digest(digest) - except FileNotFoundError: - self.logger.warning("Artifact %s specified but no files found", name) - context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but no files found".format(name)) - except DecodeError: - self.logger.warning("Artifact %s specified but directory not found", name) - context.abort( - grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but directory not found".format(name) - ) - - def _check_file(self, name, digest, context): - if not os.path.exists(self.object_path(digest)): - context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but not found".format(name)) - - -class _BuildStreamCapabilitiesServicer(buildstream_pb2_grpc.CapabilitiesServicer): - def __init__(self, artifact_capabilities, source_capabilities): - self.artifact_capabilities = artifact_capabilities - self.source_capabilities = source_capabilities - - def GetCapabilities(self, request, context): - response = buildstream_pb2.ServerCapabilities() - response.artifact_capabilities.CopyFrom(self.artifact_capabilities) - response.source_capabilities.CopyFrom(self.source_capabilities) - return response - - -class _SourceServicer(source_pb2_grpc.SourceServiceServicer): - def __init__(self, sourcedir): - self.sourcedir = sourcedir - self.logger = logging.getLogger("buildstream._cas.casserver") - - def GetSource(self, request, context): - self.logger.info("'%s'", request.cache_key) - try: - source_proto = self._get_source(request.cache_key) - except FileNotFoundError: - context.abort(grpc.StatusCode.NOT_FOUND, "Source not found") - except DecodeError: - context.abort(grpc.StatusCode.NOT_FOUND, "Sources gives invalid directory") - - return source_proto - - def UpdateSource(self, request, context): - self.logger.info("'%s' -> '%s'", request.source, request.cache_key) - self._set_source(request.cache_key, request.source) - return request.source - - def _get_source(self, cache_key): - path = os.path.join(self.sourcedir, cache_key) - source_proto = source_pb2.Source() - with open(path, "r+b") as f: - source_proto.ParseFromString(f.read()) - os.utime(path) - return source_proto - - def _set_source(self, cache_key, source_proto): - path = os.path.join(self.sourcedir, cache_key) - os.makedirs(os.path.dirname(path), exist_ok=True) - with save_file_atomic(path, "w+b") as f: - f.write(source_proto.SerializeToString()) |