author | Tristan Maat <tristan.maat@codethink.co.uk> | 2019-08-22 17:48:34 +0100
committer | Tristan Maat <tristan.maat@codethink.co.uk> | 2019-09-06 15:55:10 +0100
commit | 47a3f93d9795be6af849c112d4180f0ad50ca23b (patch)
tree | 2d65dd2c24d9d6bd6795f0680811cf95ae3803e4 /src/buildstream/_cas
parent | e71621510de7c55aae4855f8bbb64eb2755346a8 (diff)
download | buildstream-47a3f93d9795be6af849c112d4180f0ad50ca23b.tar.gz
Allow splitting artifact caches
Artifact caches can now be split into storage/index remotes: the
former is expected to be a CASRemote, and the latter a
BuildStream-specific remote with the extensions required to store
BuildStream artifact protos.
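To make the new shape concrete, here is a minimal, self-contained Python sketch of the split. The BaseRemote below is a simplified stand-in for the real buildstream._remote.BaseRemote imported in the diff, and the IndexRemote name is hypothetical; only CASRemote (the "storage" half) appears in this commit.

class BaseRemote:
    def __init__(self, spec):
        self.spec = spec
        self.channel = None        # gRPC channel, set up by shared init code
        self._initialized = False

    def init(self):
        # Shared connection setup lives in the base class; subclasses
        # only attach their protocol stubs in _configure_protocols().
        if not self._initialized:
            self._configure_protocols()
            self._initialized = True

    def _configure_protocols(self):
        raise NotImplementedError()


class CASRemote(BaseRemote):
    # The "storage" half: a plain CAS remote.
    def _configure_protocols(self):
        self.cas = object()  # e.g. ContentAddressableStorageStub(self.channel)


class IndexRemote(BaseRemote):
    # The "index" half (hypothetical name): stores BuildStream
    # artifact protos via the BuildStream-specific services.
    def _configure_protocols(self):
        self.artifact_service = object()  # e.g. ArtifactServiceStub(self.channel)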
Diffstat (limited to 'src/buildstream/_cas')
-rw-r--r-- | src/buildstream/_cas/casremote.py | 255
-rw-r--r-- | src/buildstream/_cas/casserver.py | 56
2 files changed, 135 insertions, 176 deletions
diff --git a/src/buildstream/_cas/casremote.py b/src/buildstream/_cas/casremote.py
index 35bbb68ec..1efed22e6 100644
--- a/src/buildstream/_cas/casremote.py
+++ b/src/buildstream/_cas/casremote.py
@@ -1,9 +1,3 @@
-from collections import namedtuple
-import os
-import multiprocessing
-import signal
-from urllib.parse import urlparse
-
 import grpc

 from .._protos.google.rpc import code_pb2
@@ -11,9 +5,8 @@ from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remo
 from .._protos.build.buildgrid import local_cas_pb2
 from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc

-from .._exceptions import CASRemoteError, LoadError, LoadErrorReason
-from .. import _signals
-from .. import utils
+from .._remote import BaseRemote
+from .._exceptions import CASRemoteError

 # The default limit for gRPC messages is 4 MiB.
 # Limit payload to 1 MiB to leave sufficient headroom for metadata.
@@ -29,13 +22,12 @@ class BlobNotFound(CASRemoteError):

 # Represents a single remote CAS cache.
 #
-class CASRemote():
-    def __init__(self, spec, cascache):
-        self.spec = spec
-        self._initialized = False
+class CASRemote(BaseRemote):
+
+    def __init__(self, spec, cascache, **kwargs):
+        super().__init__(spec, **kwargs)
+
+        self.cascache = cascache

-        self.channel = None
-        self.instance_name = None
         self.cas = None
         self.ref_storage = None
         self.batch_update_supported = None
@@ -44,157 +36,102 @@ class CASRemote():
         self.max_batch_total_size_bytes = None
         self.local_cas_instance_name = None

-    def init(self):
-        if not self._initialized:
-            server_cert_bytes = None
-            client_key_bytes = None
-            client_cert_bytes = None
-
-            url = urlparse(self.spec.url)
-            if url.scheme == 'http':
-                port = url.port or 80
-                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
-            elif url.scheme == 'https':
-                port = url.port or 443
-
-                if self.spec.server_cert:
-                    with open(self.spec.server_cert, 'rb') as f:
-                        server_cert_bytes = f.read()
-
-                if self.spec.client_key:
-                    with open(self.spec.client_key, 'rb') as f:
-                        client_key_bytes = f.read()
-
-                if self.spec.client_cert:
-                    with open(self.spec.client_cert, 'rb') as f:
-                        client_cert_bytes = f.read()
-
-                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
-                                                           private_key=client_key_bytes,
-                                                           certificate_chain=client_cert_bytes)
-                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
-            else:
-                raise CASRemoteError("Unsupported URL: {}".format(self.spec.url))
-
-            self.instance_name = self.spec.instance_name or None
-
-            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
-            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
-            self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
-
-            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
-            try:
-                request = remote_execution_pb2.GetCapabilitiesRequest()
-                if self.instance_name:
-                    request.instance_name = self.instance_name
-                response = self.capabilities.GetCapabilities(request)
-                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
-                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
-                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
-            except grpc.RpcError as e:
-                # Simply use the defaults for servers that don't implement GetCapabilities()
-                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
-                    raise
-
-            # Check whether the server supports BatchReadBlobs()
-            self.batch_read_supported = False
-            try:
-                request = remote_execution_pb2.BatchReadBlobsRequest()
-                if self.instance_name:
-                    request.instance_name = self.instance_name
-                response = self.cas.BatchReadBlobs(request)
-                self.batch_read_supported = True
-            except grpc.RpcError as e:
-                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
-                    raise
-
-            # Check whether the server supports BatchUpdateBlobs()
-            self.batch_update_supported = False
-            try:
-                request = remote_execution_pb2.BatchUpdateBlobsRequest()
-                if self.instance_name:
-                    request.instance_name = self.instance_name
-                response = self.cas.BatchUpdateBlobs(request)
-                self.batch_update_supported = True
-            except grpc.RpcError as e:
-                if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
-                        e.code() != grpc.StatusCode.PERMISSION_DENIED):
-                    raise
-
-            local_cas = self.cascache._get_local_cas()
-            request = local_cas_pb2.GetInstanceNameForRemoteRequest()
-            request.url = self.spec.url
-            if self.spec.instance_name:
-                request.instance_name = self.spec.instance_name
-            if server_cert_bytes:
-                request.server_cert = server_cert_bytes
-            if client_key_bytes:
-                request.client_key = client_key_bytes
-            if client_cert_bytes:
-                request.client_cert = client_cert_bytes
-            response = local_cas.GetInstanceNameForRemote(request)
-            self.local_cas_instance_name = response.instance_name
-
-            self._initialized = True
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.close()
-        return False
-
-    def close(self):
-        if self.channel:
-            self.channel.close()
-            self.channel = None
-
-    # check_remote
+    # _configure_protocols():
     #
-    # Used when checking whether remote_specs work in the buildstream main
-    # thread, runs this in a seperate process to avoid creation of gRPC threads
-    # in the main BuildStream process
-    # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
-    @classmethod
-    def check_remote(cls, remote_spec, cascache, q):
-
-        def __check_remote():
-            try:
-                remote = cls(remote_spec, cascache)
-                remote.init()
-
-                request = buildstream_pb2.StatusRequest()
-                response = remote.ref_storage.Status(request)
-
-                if remote_spec.push and not response.allow_updates:
-                    q.put('CAS server does not allow push')
-                else:
-                    # No error
-                    q.put(None)
-
-            except grpc.RpcError as e:
-                # str(e) is too verbose for errors reported to the user
-                q.put(e.details())
+    # Configure remote-specific protocols. This method should *never*
+    # be called outside of init().
+    #
+    def _configure_protocols(self):
+        self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+        self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
+        self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
+
+        # Figure out what batch sizes the server will accept, falling
+        # back to our _MAX_PAYLOAD_BYTES
+        self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
+        try:
+            request = remote_execution_pb2.GetCapabilitiesRequest()
+            if self.instance_name:
+                request.instance_name = self.instance_name
+            response = self.capabilities.GetCapabilities(request)
+            server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
+            if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
+                self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
+        except grpc.RpcError as e:
+            # Simply use the defaults for servers that don't implement
+            # GetCapabilities()
+            if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                raise
+
+        # Check whether the server supports BatchReadBlobs()
+        self.batch_read_supported = self._check_support(
+            remote_execution_pb2.BatchReadBlobsRequest,
+            self.cas.BatchReadBlobs
+        )
+
+        # Check whether the server supports BatchUpdateBlobs()
+        self.batch_update_supported = self._check_support(
+            remote_execution_pb2.BatchUpdateBlobsRequest,
+            self.cas.BatchUpdateBlobs
+        )
+
+        local_cas = self.cascache._get_local_cas()
+        request = local_cas_pb2.GetInstanceNameForRemoteRequest()
+        request.url = self.spec.url
+        if self.spec.instance_name:
+            request.instance_name = self.spec.instance_name
+        if self.server_cert:
+            request.server_cert = self.server_cert
+        if self.client_key:
+            request.client_key = self.client_key
+        if self.client_cert:
+            request.client_cert = self.client_cert
+        response = local_cas.GetInstanceNameForRemote(request)
+        self.local_cas_instance_name = response.instance_name
+
+    # _check():
+    #
+    # Check if this remote provides everything required for the
+    # particular kind of remote. This is expected to be called as part
+    # of check(), and must be called in a non-main process.
+    #
+    # Returns:
+    #     (str|None): An error message, or None if no error message.
+    #
+    def _check(self):
+        request = buildstream_pb2.StatusRequest()
+        response = self.ref_storage.Status(request)

-            except Exception as e:               # pylint: disable=broad-except
-                # Whatever happens, we need to return it to the calling process
-                #
-                q.put(str(e))
+        if self.spec.push and not response.allow_updates:
+            return 'CAS server does not allow push'

-        p = multiprocessing.Process(target=__check_remote)
+        return None

+    # _check_support():
+    #
+    # Figure out if a remote server supports a given method based on
+    # grpc.StatusCode.UNIMPLEMENTED and grpc.StatusCode.PERMISSION_DENIED.
+    #
+    # Args:
+    #     request_type (callable): The type of request to check.
+    #     invoker (callable): The remote method that will be invoked.
+    #
+    # Returns:
+    #     (bool) - Whether the request is supported.
+    #
+    def _check_support(self, request_type, invoker):
         try:
-            # Keep SIGINT blocked in the child process
-            with _signals.blocked([signal.SIGINT], ignore=False):
-                p.start()
+            request = request_type()
+            if self.instance_name:
+                request.instance_name = self.instance_name
+            invoker(request)
+            return True
+        except grpc.RpcError as e:
+            if not e.code() in (grpc.StatusCode.UNIMPLEMENTED, grpc.StatusCode.PERMISSION_DENIED):
+                raise

-            error = q.get()
-            p.join()
-        except KeyboardInterrupt:
-            utils._kill_process_tree(p.pid)
-            raise
-
-        return error
+        return False

     # push_message():
     #
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index ca7a21955..bb011146e 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -54,9 +54,10 @@ _MAX_PAYLOAD_BYTES = 1024 * 1024
 # Args:
 #     repo (str): Path to CAS repository
 #     enable_push (bool): Whether to allow blob uploads and artifact updates
+#     index_only (bool): Whether to store CAS blobs or only artifacts
 #
 @contextmanager
-def create_server(repo, *, enable_push, quota):
+def create_server(repo, *, enable_push, quota, index_only):
     cas = CASCache(os.path.abspath(repo), cache_quota=quota, protect_session_blobs=False)

     try:
@@ -67,11 +68,12 @@ def create_server(repo, *, enable_push, quota):
         max_workers = (os.cpu_count() or 1) * 5
         server = grpc.server(futures.ThreadPoolExecutor(max_workers))

-        bytestream_pb2_grpc.add_ByteStreamServicer_to_server(
-            _ByteStreamServicer(cas, enable_push=enable_push), server)
+        if not index_only:
+            bytestream_pb2_grpc.add_ByteStreamServicer_to_server(
+                _ByteStreamServicer(cas, enable_push=enable_push), server)

-        remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
-            _ContentAddressableStorageServicer(cas, enable_push=enable_push), server)
+            remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
+                _ContentAddressableStorageServicer(cas, enable_push=enable_push), server)

         remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(
             _CapabilitiesServicer(), server)
@@ -80,7 +82,7 @@ def create_server(repo, *, enable_push, quota):
             _ReferenceStorageServicer(cas, enable_push=enable_push), server)

         artifact_pb2_grpc.add_ArtifactServiceServicer_to_server(
-            _ArtifactServicer(cas, artifactdir), server)
+            _ArtifactServicer(cas, artifactdir, update_cas=not index_only), server)

         source_pb2_grpc.add_SourceServiceServicer_to_server(
             _SourceServicer(sourcedir), server)
@@ -110,9 +112,12 @@ def create_server(repo, *, enable_push, quota):
 @click.option('--quota', type=click.INT,
               help="Maximum disk usage in bytes",
               default=10e9)
+@click.option('--index-only', type=click.BOOL,
+              help="Only provide the BuildStream artifact and source services (\"index\"), not the CAS (\"storage\")",
+              default=False)
 @click.argument('repo')
 def server_main(repo, port, server_key, server_cert, client_certs, enable_push,
-                quota):
+                quota, index_only):
     # Handle SIGTERM by calling sys.exit(0), which will raise a SystemExit exception,
     # properly executing cleanup code in `finally` clauses and context managers.
     # This is required to terminate buildbox-casd on SIGTERM.
@@ -120,7 +125,8 @@ def server_main(repo, port, server_key, server_cert, client_certs, enable_push,

     with create_server(repo,
                        quota=quota,
-                       enable_push=enable_push) as server:
+                       enable_push=enable_push,
+                       index_only=index_only) as server:

         use_tls = bool(server_key)

@@ -434,10 +440,11 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):

 class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):

-    def __init__(self, cas, artifactdir):
+    def __init__(self, cas, artifactdir, *, update_cas=True):
         super().__init__()
         self.cas = cas
         self.artifactdir = artifactdir
+        self.update_cas = update_cas
         os.makedirs(artifactdir, exist_ok=True)

     def GetArtifact(self, request, context):
@@ -449,6 +456,20 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
         with open(artifact_path, 'rb') as f:
             artifact.ParseFromString(f.read())

+        # Artifact-only servers will not have blobs on their system,
+        # so we can't reasonably update their mtimes. Instead, we exit
+        # early, and let the CAS server deal with its blobs.
+        #
+        # FIXME: We could try to run FindMissingBlobs on the other
+        #        server. This is tricky to do from here, of course,
+        #        because we don't know who the other server is, but
+        #        the client could be smart about it - but this might
+        #        make things slower.
+        #
+        #        It needs some more thought...
+        if not self.update_cas:
+            return artifact
+
         # Now update mtimes of files present.
         try:

@@ -481,16 +502,17 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
     def UpdateArtifact(self, request, context):
         artifact = request.artifact

-        # Check that the files specified are in the CAS
-        self._check_directory("files", artifact.files, context)
+        if self.update_cas:
+            # Check that the files specified are in the CAS
+            self._check_directory("files", artifact.files, context)

-        # Unset protocol buffers don't evaluated to False but do return empty
-        # strings, hence str()
-        if str(artifact.public_data):
-            self._check_file("public data", artifact.public_data, context)
+            # Unset protocol buffers don't evaluated to False but do return empty
+            # strings, hence str()
+            if str(artifact.public_data):
+                self._check_file("public data", artifact.public_data, context)

-        for log_file in artifact.logs:
-            self._check_file("log digest", log_file.digest, context)
+            for log_file in artifact.logs:
+                self._check_file("log digest", log_file.digest, context)

         # Add the artifact proto to the cas
         artifact_path = os.path.join(self.artifactdir, request.cache_key)
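The _check_support() helper added above boils down to a small gRPC idiom: send an empty request and treat UNIMPLEMENTED (or PERMISSION_DENIED) as "feature absent" rather than as an error. A standalone sketch of that idiom, usable with any generated stub method:

import grpc

def supports(invoker, request):
    # Probe a server method: UNIMPLEMENTED / PERMISSION_DENIED mean
    # "not supported"; any other RpcError is a genuine failure.
    try:
        invoker(request)
        return True
    except grpc.RpcError as e:
        if e.code() in (grpc.StatusCode.UNIMPLEMENTED,
                        grpc.StatusCode.PERMISSION_DENIED):
            return False
        raise

And a hedged usage sketch for the new index-only server mode, assuming only the create_server() signature shown in the diff; the repository path and port are illustrative, and create_server() yields a grpc.server, so the standard add_insecure_port()/start() calls apply:

import signal
from buildstream._cas.casserver import create_server

# Sketch only: '/srv/bst-index' and port 11001 are assumed values.
with create_server('/srv/bst-index',
                   enable_push=True,
                   quota=int(10e9),
                   index_only=True) as server:
    server.add_insecure_port('localhost:11001')
    server.start()
    try:
        signal.pause()   # serve until interrupted (Unix only)
    except KeyboardInterrupt:
        server.stop(0)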