author     Tristan Maat <tristan.maat@codethink.co.uk>  2019-08-22 17:48:34 +0100
committer  Tristan Maat <tristan.maat@codethink.co.uk>  2019-09-06 15:55:10 +0100
commit     47a3f93d9795be6af849c112d4180f0ad50ca23b (patch)
tree       2d65dd2c24d9d6bd6795f0680811cf95ae3803e4 /src/buildstream/_cas
parent     e71621510de7c55aae4855f8bbb64eb2755346a8 (diff)
download   buildstream-47a3f93d9795be6af849c112d4180f0ad50ca23b.tar.gz
Allow splitting artifact caches
This is now split into storage/index remotes, where the former is expected to be a CASRemote and the latter a BuildStream-specific remote with the extensions required to store BuildStream artifact protos.
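A minimal sketch (not part of this commit) of how the two halves are meant to cooperate once split; the helpers get_artifact_proto and fetch_blobs are hypothetical stand-ins for the real pull logic:

    # Hypothetical illustration: the index remote serves BuildStream
    # artifact protos, while a plain CASRemote serves the blobs they
    # reference.
    class SplitRemotePair:
        def __init__(self, index_remote, storage_remote):
            self.index = index_remote        # BuildStream-specific remote
            self.storage = storage_remote    # a CASRemote

        def pull(self, cache_key):
            # Look up the artifact proto on the index...
            artifact = self.index.get_artifact_proto(cache_key)
            # ...then download the referenced blobs from storage.
            self.storage.fetch_blobs(artifact)
            return artifact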
Diffstat (limited to 'src/buildstream/_cas')
-rw-r--r--  src/buildstream/_cas/casremote.py  | 255
-rw-r--r--  src/buildstream/_cas/casserver.py  |  56
2 files changed, 135 insertions, 176 deletions
diff --git a/src/buildstream/_cas/casremote.py b/src/buildstream/_cas/casremote.py
index 35bbb68ec..1efed22e6 100644
--- a/src/buildstream/_cas/casremote.py
+++ b/src/buildstream/_cas/casremote.py
@@ -1,9 +1,3 @@
-from collections import namedtuple
-import os
-import multiprocessing
-import signal
-from urllib.parse import urlparse
-
import grpc
from .._protos.google.rpc import code_pb2
@@ -11,9 +5,8 @@ from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from .._protos.build.buildgrid import local_cas_pb2
from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
-from .._exceptions import CASRemoteError, LoadError, LoadErrorReason
-from .. import _signals
-from .. import utils
+from .._remote import BaseRemote
+from .._exceptions import CASRemoteError
# The default limit for gRPC messages is 4 MiB.
# Limit payload to 1 MiB to leave sufficient headroom for metadata.
@@ -29,13 +22,12 @@ class BlobNotFound(CASRemoteError):
# Represents a single remote CAS cache.
#
-class CASRemote():
- def __init__(self, spec, cascache):
- self.spec = spec
- self._initialized = False
+class CASRemote(BaseRemote):
+
+ def __init__(self, spec, cascache, **kwargs):
+ super().__init__(spec, **kwargs)
+
self.cascache = cascache
- self.channel = None
- self.instance_name = None
self.cas = None
self.ref_storage = None
self.batch_update_supported = None
@@ -44,157 +36,102 @@ class CASRemote():
self.max_batch_total_size_bytes = None
self.local_cas_instance_name = None
- def init(self):
- if not self._initialized:
- server_cert_bytes = None
- client_key_bytes = None
- client_cert_bytes = None
-
- url = urlparse(self.spec.url)
- if url.scheme == 'http':
- port = url.port or 80
- self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
- elif url.scheme == 'https':
- port = url.port or 443
-
- if self.spec.server_cert:
- with open(self.spec.server_cert, 'rb') as f:
- server_cert_bytes = f.read()
-
- if self.spec.client_key:
- with open(self.spec.client_key, 'rb') as f:
- client_key_bytes = f.read()
-
- if self.spec.client_cert:
- with open(self.spec.client_cert, 'rb') as f:
- client_cert_bytes = f.read()
-
- credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
- private_key=client_key_bytes,
- certificate_chain=client_cert_bytes)
- self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
- else:
- raise CASRemoteError("Unsupported URL: {}".format(self.spec.url))
-
- self.instance_name = self.spec.instance_name or None
-
- self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
- self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
- self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
-
- self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
- try:
- request = remote_execution_pb2.GetCapabilitiesRequest()
- if self.instance_name:
- request.instance_name = self.instance_name
- response = self.capabilities.GetCapabilities(request)
- server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
- if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
- self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
- except grpc.RpcError as e:
- # Simply use the defaults for servers that don't implement GetCapabilities()
- if e.code() != grpc.StatusCode.UNIMPLEMENTED:
- raise
-
- # Check whether the server supports BatchReadBlobs()
- self.batch_read_supported = False
- try:
- request = remote_execution_pb2.BatchReadBlobsRequest()
- if self.instance_name:
- request.instance_name = self.instance_name
- response = self.cas.BatchReadBlobs(request)
- self.batch_read_supported = True
- except grpc.RpcError as e:
- if e.code() != grpc.StatusCode.UNIMPLEMENTED:
- raise
-
- # Check whether the server supports BatchUpdateBlobs()
- self.batch_update_supported = False
- try:
- request = remote_execution_pb2.BatchUpdateBlobsRequest()
- if self.instance_name:
- request.instance_name = self.instance_name
- response = self.cas.BatchUpdateBlobs(request)
- self.batch_update_supported = True
- except grpc.RpcError as e:
- if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
- e.code() != grpc.StatusCode.PERMISSION_DENIED):
- raise
-
- local_cas = self.cascache._get_local_cas()
- request = local_cas_pb2.GetInstanceNameForRemoteRequest()
- request.url = self.spec.url
- if self.spec.instance_name:
- request.instance_name = self.spec.instance_name
- if server_cert_bytes:
- request.server_cert = server_cert_bytes
- if client_key_bytes:
- request.client_key = client_key_bytes
- if client_cert_bytes:
- request.client_cert = client_cert_bytes
- response = local_cas.GetInstanceNameForRemote(request)
- self.local_cas_instance_name = response.instance_name
-
- self._initialized = True
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self.close()
- return False
-
- def close(self):
- if self.channel:
- self.channel.close()
- self.channel = None
-
- # check_remote
+ # _configure_protocols():
#
- # Used when checking whether remote_specs work in the buildstream main
- # thread, runs this in a separate process to avoid creation of gRPC threads
- # in the main BuildStream process
- # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
- @classmethod
- def check_remote(cls, remote_spec, cascache, q):
-
- def __check_remote():
- try:
- remote = cls(remote_spec, cascache)
- remote.init()
-
- request = buildstream_pb2.StatusRequest()
- response = remote.ref_storage.Status(request)
-
- if remote_spec.push and not response.allow_updates:
- q.put('CAS server does not allow push')
- else:
- # No error
- q.put(None)
-
- except grpc.RpcError as e:
- # str(e) is too verbose for errors reported to the user
- q.put(e.details())
+ # Configure remote-specific protocols. This method should *never*
+ # be called outside of init().
+ #
+ def _configure_protocols(self):
+ self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+ self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
+ self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
+
+ # Figure out what batch sizes the server will accept, falling
+ # back to our _MAX_PAYLOAD_BYTES
+ self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
+ try:
+ request = remote_execution_pb2.GetCapabilitiesRequest()
+ if self.instance_name:
+ request.instance_name = self.instance_name
+ response = self.capabilities.GetCapabilities(request)
+ server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
+ if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
+ self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
+ except grpc.RpcError as e:
+ # Simply use the defaults for servers that don't implement
+ # GetCapabilities()
+ if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+ raise
+
+ # Check whether the server supports BatchReadBlobs()
+ self.batch_read_supported = self._check_support(
+ remote_execution_pb2.BatchReadBlobsRequest,
+ self.cas.BatchReadBlobs
+ )
+
+ # Check whether the server supports BatchUpdateBlobs()
+ self.batch_update_supported = self._check_support(
+ remote_execution_pb2.BatchUpdateBlobsRequest,
+ self.cas.BatchUpdateBlobs
+ )
+
+ local_cas = self.cascache._get_local_cas()
+ request = local_cas_pb2.GetInstanceNameForRemoteRequest()
+ request.url = self.spec.url
+ if self.spec.instance_name:
+ request.instance_name = self.spec.instance_name
+ if self.server_cert:
+ request.server_cert = self.server_cert
+ if self.client_key:
+ request.client_key = self.client_key
+ if self.client_cert:
+ request.client_cert = self.client_cert
+ response = local_cas.GetInstanceNameForRemote(request)
+ self.local_cas_instance_name = response.instance_name
+
+ # _check():
+ #
+ # Check if this remote provides everything required for the
+ # particular kind of remote. This is expected to be called as part
+ # of check(), and must be called in a non-main process.
+ #
+ # Returns:
+ # (str|None): An error message, or None if no error message.
+ #
+ def _check(self):
+ request = buildstream_pb2.StatusRequest()
+ response = self.ref_storage.Status(request)
- except Exception as e: # pylint: disable=broad-except
- # Whatever happens, we need to return it to the calling process
- #
- q.put(str(e))
+ if self.spec.push and not response.allow_updates:
+ return 'CAS server does not allow push'
- p = multiprocessing.Process(target=__check_remote)
+ return None
+ # _check_support():
+ #
+ # Figure out if a remote server supports a given method based on
+ # grpc.StatusCode.UNIMPLEMENTED and grpc.StatusCode.PERMISSION_DENIED.
+ #
+ # Args:
+ # request_type (callable): The type of request to check.
+ # invoker (callable): The remote method that will be invoked.
+ #
+ # Returns:
+ # (bool): Whether the request is supported.
+ #
+ def _check_support(self, request_type, invoker):
try:
- # Keep SIGINT blocked in the child process
- with _signals.blocked([signal.SIGINT], ignore=False):
- p.start()
+ request = request_type()
+ if self.instance_name:
+ request.instance_name = self.instance_name
+ invoker(request)
+ return True
+ except grpc.RpcError as e:
+ if e.code() not in (grpc.StatusCode.UNIMPLEMENTED, grpc.StatusCode.PERMISSION_DENIED):
+ raise
- error = q.get()
- p.join()
- except KeyboardInterrupt:
- utils._kill_process_tree(p.pid)
- raise
-
- return error
+ return False
# push_message():
#
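The _check_support() helper added above encodes a general gRPC capability-probing pattern: send an empty request and treat UNIMPLEMENTED (and, for write methods, PERMISSION_DENIED) as "not supported". A standalone sketch of the same probe, assuming an already-connected grpc channel and the in-tree _protos package:

    import grpc

    from buildstream._protos.build.bazel.remote.execution.v2 import (
        remote_execution_pb2, remote_execution_pb2_grpc)

    def probe_batch_read(channel, instance_name=None):
        # An empty BatchReadBlobsRequest is harmless: servers that
        # implement the method return an empty response, older ones
        # raise UNIMPLEMENTED.
        stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(channel)
        request = remote_execution_pb2.BatchReadBlobsRequest()
        if instance_name:
            request.instance_name = instance_name
        try:
            stub.BatchReadBlobs(request)
            return True
        except grpc.RpcError as e:
            if e.code() in (grpc.StatusCode.UNIMPLEMENTED,
                            grpc.StatusCode.PERMISSION_DENIED):
                return False
            raise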
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index ca7a21955..bb011146e 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -54,9 +54,10 @@ _MAX_PAYLOAD_BYTES = 1024 * 1024
# Args:
# repo (str): Path to CAS repository
# enable_push (bool): Whether to allow blob uploads and artifact updates
+# index_only (bool): Whether to store CAS blobs or only artifacts
#
@contextmanager
-def create_server(repo, *, enable_push, quota):
+def create_server(repo, *, enable_push, quota, index_only):
cas = CASCache(os.path.abspath(repo), cache_quota=quota, protect_session_blobs=False)
try:
@@ -67,11 +68,12 @@ def create_server(repo, *, enable_push, quota):
max_workers = (os.cpu_count() or 1) * 5
server = grpc.server(futures.ThreadPoolExecutor(max_workers))
- bytestream_pb2_grpc.add_ByteStreamServicer_to_server(
- _ByteStreamServicer(cas, enable_push=enable_push), server)
+ if not index_only:
+ bytestream_pb2_grpc.add_ByteStreamServicer_to_server(
+ _ByteStreamServicer(cas, enable_push=enable_push), server)
- remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
- _ContentAddressableStorageServicer(cas, enable_push=enable_push), server)
+ remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
+ _ContentAddressableStorageServicer(cas, enable_push=enable_push), server)
remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(
_CapabilitiesServicer(), server)
@@ -80,7 +82,7 @@ def create_server(repo, *, enable_push, quota):
_ReferenceStorageServicer(cas, enable_push=enable_push), server)
artifact_pb2_grpc.add_ArtifactServiceServicer_to_server(
- _ArtifactServicer(cas, artifactdir), server)
+ _ArtifactServicer(cas, artifactdir, update_cas=not index_only), server)
source_pb2_grpc.add_SourceServiceServicer_to_server(
_SourceServicer(sourcedir), server)
@@ -110,9 +112,12 @@ def create_server(repo, *, enable_push, quota):
@click.option('--quota', type=click.INT,
help="Maximum disk usage in bytes",
default=10e9)
+@click.option('--index-only', type=click.BOOL,
+ help="Only provide the BuildStream artifact and source services (\"index\"), not the CAS (\"storage\")",
+ default=False)
@click.argument('repo')
def server_main(repo, port, server_key, server_cert, client_certs, enable_push,
- quota):
+ quota, index_only):
# Handle SIGTERM by calling sys.exit(0), which will raise a SystemExit exception,
# properly executing cleanup code in `finally` clauses and context managers.
# This is required to terminate buildbox-casd on SIGTERM.
@@ -120,7 +125,8 @@ def server_main(repo, port, server_key, server_cert, client_certs, enable_push,
with create_server(repo,
quota=quota,
- enable_push=enable_push) as server:
+ enable_push=enable_push,
+ index_only=index_only) as server:
use_tls = bool(server_key)
@@ -434,10 +440,11 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
- def __init__(self, cas, artifactdir):
+ def __init__(self, cas, artifactdir, *, update_cas=True):
super().__init__()
self.cas = cas
self.artifactdir = artifactdir
+ self.update_cas = update_cas
os.makedirs(artifactdir, exist_ok=True)
def GetArtifact(self, request, context):
@@ -449,6 +456,20 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
with open(artifact_path, 'rb') as f:
artifact.ParseFromString(f.read())
+ # Artifact-only servers will not have blobs on their system,
+ # so we can't reasonably update their mtimes. Instead, we exit
+ # early, and let the CAS server deal with its blobs.
+ #
+ # FIXME: We could try to run FindMissingBlobs on the other
+ # server. This is tricky to do from here, of course,
+ # because we don't know who the other server is, but
+ # the client could be smart about it - but this might
+ # make things slower.
+ #
+ # It needs some more thought...
+ if not self.update_cas:
+ return artifact
+
# Now update mtimes of files present.
try:
@@ -481,16 +502,17 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
def UpdateArtifact(self, request, context):
artifact = request.artifact
- # Check that the files specified are in the CAS
- self._check_directory("files", artifact.files, context)
+ if self.update_cas:
+ # Check that the files specified are in the CAS
+ self._check_directory("files", artifact.files, context)
- # Unset protocol buffers don't evaluate to False but do return empty
- # strings, hence str()
- if str(artifact.public_data):
- self._check_file("public data", artifact.public_data, context)
+ # Unset protocol buffers don't evaluate to False but do return empty
+ # strings, hence str()
+ if str(artifact.public_data):
+ self._check_file("public data", artifact.public_data, context)
- for log_file in artifact.logs:
- self._check_file("log digest", log_file.digest, context)
+ for log_file in artifact.logs:
+ self._check_file("log digest", log_file.digest, context)
# Add the artifact proto to the cas
artifact_path = os.path.join(self.artifactdir, request.cache_key)
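The FIXME in GetArtifact() above suggests letting the client verify blob presence itself. A hedged sketch of what that client-side check could look like, using the standard REAPI FindMissingBlobs call against the storage remote; the digests argument would be the blob digests referenced by the artifact proto:

    def blobs_missing_on_storage(cas_stub, instance_name, digests):
        # Ask the storage-side CAS which of the artifact's blobs it
        # lacks; only those need to be fetched or re-uploaded before
        # the artifact is actually usable.
        request = remote_execution_pb2.FindMissingBlobsRequest()
        if instance_name:
            request.instance_name = instance_name
        request.blob_digests.extend(digests)
        response = cas_stub.FindMissingBlobs(request)
        return list(response.missing_blob_digests)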