diff options
Diffstat (limited to 'src/buildstream/_cas/casremote.py')
-rw-r--r-- | src/buildstream/_cas/casremote.py | 391 |
1 files changed, 391 insertions, 0 deletions
# Origin: src/buildstream/_cas/casremote.py (new file mode 100644, index 000000000..aac0d2802)
from collections import namedtuple
import io
import os
import multiprocessing
import signal
from urllib.parse import urlparse
import uuid

import grpc

from .. import _yaml
from .._protos.google.rpc import code_pb2
from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc

from .._exceptions import CASRemoteError, LoadError, LoadErrorReason
from .. import _signals
from .. import utils

# The default limit for gRPC messages is 4 MiB.
# Limit payload to 1 MiB to leave sufficient headroom for metadata.
_MAX_PAYLOAD_BYTES = 1024 * 1024


# CASRemoteSpec()
#
# Immutable description of a remote CAS: its URL, whether pushing is
# requested, optional TLS material paths and an optional instance name.
#
class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):

    # _new_from_config_node
    #
    # Creates a CASRemoteSpec() from a YAML loaded node
    #
    # Args:
    #    spec_node: The YAML node holding the remote configuration
    #    basedir (str): If given, directory that the certificate and key
    #                   paths are joined onto
    #
    # Returns:
    #    (CASRemoteSpec): The spec described by the node
    #
    # Raises:
    #    (LoadError): If the URL is empty, or if only one of 'client-key'
    #                 and 'client-cert' was specified
    #
    @staticmethod
    def _new_from_config_node(spec_node, basedir=None):
        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name'])
        url = _yaml.node_get(spec_node, str, 'url')
        push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
        if not url:
            provenance = _yaml.node_get_provenance(spec_node, 'url')
            raise LoadError(LoadErrorReason.INVALID_DATA,
                            "{}: empty artifact cache URL".format(provenance))

        instance_name = _yaml.node_get(spec_node, str, 'instance-name', default_value=None)

        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
        if server_cert and basedir:
            server_cert = os.path.join(basedir, server_cert)

        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
        if client_key and basedir:
            client_key = os.path.join(basedir, client_key)

        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
        if client_cert and basedir:
            client_cert = os.path.join(basedir, client_cert)

        # The client key and certificate are only meaningful as a pair
        if client_key and not client_cert:
            provenance = _yaml.node_get_provenance(spec_node, 'client-key')
            raise LoadError(LoadErrorReason.INVALID_DATA,
                            "{}: 'client-key' was specified without 'client-cert'".format(provenance))

        if client_cert and not client_key:
            provenance = _yaml.node_get_provenance(spec_node, 'client-cert')
            raise LoadError(LoadErrorReason.INVALID_DATA,
                            "{}: 'client-cert' was specified without 'client-key'".format(provenance))

        return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name)


# The last four fields (server_cert, client_key, client_cert, instance_name)
# default to None; url and push stay mandatory.
CASRemoteSpec.__new__.__defaults__ = (None, None, None, None)


# BlobNotFound()
#
# Raised when a requested blob does not exist on the remote.
#
# Args:
#    blob: Identifier of the missing blob, kept on the `blob` attribute
#    msg (str): The error message
#
class BlobNotFound(CASRemoteError):

    def __init__(self, blob, msg):
        self.blob = blob
        super().__init__(msg)


# _read_optional_file()
#
# Args:
#    path (str): Path of the file to read, or a false value
#
# Returns:
#    (bytes): The file contents, or None if no path was given
#
def _read_optional_file(path):
    if not path:
        return None
    with open(path, 'rb') as f:
        return f.read()


# Represents a single remote CAS cache.
#
class CASRemote():
    def __init__(self, spec):
        self.spec = spec
        self._initialized = False
        self.channel = None
        self.instance_name = None
        self.bytestream = None
        self.cas = None
        self.ref_storage = None
        self.batch_update_supported = None
        self.batch_read_supported = None
        self.capabilities = None
        self.max_batch_total_size_bytes = None

    # init():
    #
    # Open the gRPC channel, create the service stubs and probe the
    # server for its batch capabilities. Does nothing if the remote has
    # already been initialized.
    #
    # Raises:
    #    (CASRemoteError): if the URL scheme is neither http nor https
    #    (grpc.RpcError): if a capability probe fails with an unexpected status
    #
    def init(self):
        if self._initialized:
            return

        url = urlparse(self.spec.url)
        if url.scheme == 'http':
            port = url.port or 80
            self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
        elif url.scheme == 'https':
            port = url.port or 443

            server_cert_bytes = _read_optional_file(self.spec.server_cert)
            client_key_bytes = _read_optional_file(self.spec.client_key)
            client_cert_bytes = _read_optional_file(self.spec.client_cert)

            credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
                                                       private_key=client_key_bytes,
                                                       certificate_chain=client_cert_bytes)
            self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
        else:
            raise CASRemoteError("Unsupported URL: {}".format(self.spec.url))

        # Normalize an empty instance name to None
        self.instance_name = self.spec.instance_name or None

        self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
        self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
        self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
        self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)

        self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
        try:
            request = remote_execution_pb2.GetCapabilitiesRequest()
            if self.instance_name:
                request.instance_name = self.instance_name
            response = self.capabilities.GetCapabilities(request)
            server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
            # Only ever lower our own limit; a value of 0 is ignored
            if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
                self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
        except grpc.RpcError as e:
            # Simply use the defaults for servers that don't implement GetCapabilities()
            if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                raise

        # Check whether the server supports BatchReadBlobs()
        self.batch_read_supported = False
        try:
            request = remote_execution_pb2.BatchReadBlobsRequest()
            if self.instance_name:
                request.instance_name = self.instance_name
            # An empty request is sent purely as a probe, the reply is not needed
            self.cas.BatchReadBlobs(request)
            self.batch_read_supported = True
        except grpc.RpcError as e:
            if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                raise

        # Check whether the server supports BatchUpdateBlobs()
        self.batch_update_supported = False
        try:
            request = remote_execution_pb2.BatchUpdateBlobsRequest()
            if self.instance_name:
                request.instance_name = self.instance_name
            # An empty request is sent purely as a probe, the reply is not needed
            self.cas.BatchUpdateBlobs(request)
            self.batch_update_supported = True
        except grpc.RpcError as e:
            # PERMISSION_DENIED is tolerated here as well as UNIMPLEMENTED;
            # in both cases batch updates are simply not used
            if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
                    e.code() != grpc.StatusCode.PERMISSION_DENIED):
                raise

        self._initialized = True

    # check_remote
    #
    # Used when checking whether remote_specs work in the buildstream main
    # thread, runs this in a separate process to avoid creation of gRPC threads
    # in the main BuildStream process
    # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
    #
    # Args:
    #    remote_spec (CASRemoteSpec): The remote to check
    #    q: Queue on which the child process reports its result
    #
    # Returns:
    #    The value reported by the child process: None if the remote is
    #    usable, otherwise a message describing the problem
    #
    @classmethod
    def check_remote(cls, remote_spec, q):

        def __check_remote():
            try:
                remote = cls(remote_spec)
                remote.init()

                request = buildstream_pb2.StatusRequest()
                response = remote.ref_storage.Status(request)

                if remote_spec.push and not response.allow_updates:
                    q.put('CAS server does not allow push')
                else:
                    # No error
                    q.put(None)

            except grpc.RpcError as e:
                # str(e) is too verbose for errors reported to the user
                q.put(e.details())

            except Exception as e:  # pylint: disable=broad-except
                # Whatever happens, we need to return it to the calling process
                #
                q.put(str(e))

        p = multiprocessing.Process(target=__check_remote)

        try:
            # Keep SIGINT blocked in the child process
            with _signals.blocked([signal.SIGINT], ignore=False):
                p.start()

            error = q.get()
            p.join()
        except KeyboardInterrupt:
            utils._kill_process_tree(p.pid)
            raise

        return error

    # push_message():
    #
    # Push the given protobuf message to a remote.
    #
    # Args:
    #     message (Message): A protobuf message to push.
    #
    # Returns:
    #     The digest of the serialized message
    #
    # Raises:
    #     (CASRemoteError): if there was an error
    #
    def push_message(self, message):

        message_buffer = message.SerializeToString()
        message_digest = utils._message_digest(message_buffer)

        self.init()

        with io.BytesIO(message_buffer) as b:
            self._send_blob(message_digest, b)

        return message_digest

    ################################################
    #             Local Private Methods            #
    ################################################

    # _fetch_blob():
    #
    # Download a blob with the ByteStream Read() call and write it to a stream.
    #
    # Args:
    #     digest: Digest (hash and size_bytes) of the blob to fetch
    #     stream: Writable file object; it must provide fileno(), as the
    #             resulting file size is checked against the digest
    #
    def _fetch_blob(self, digest, stream):
        if self.instance_name:
            resource_name = '/'.join([self.instance_name, 'blobs',
                                      digest.hash, str(digest.size_bytes)])
        else:
            resource_name = '/'.join(['blobs',
                                      digest.hash, str(digest.size_bytes)])

        request = bytestream_pb2.ReadRequest()
        request.resource_name = resource_name
        request.read_offset = 0
        for response in self.bytestream.Read(request):
            stream.write(response.data)
        stream.flush()

        assert digest.size_bytes == os.fstat(stream.fileno()).st_size

    # _send_blob():
    #
    # Upload a blob with the ByteStream Write() call, in chunks of at
    # most _MAX_PAYLOAD_BYTES.
    #
    # Args:
    #     digest: Digest (hash and size_bytes) of the blob to send
    #     stream: Readable file object positioned at the start of the blob
    #     u_uid: Identifier used in the upload resource name; a fresh
    #            uuid4 is generated for every call when not given
    #
    def _send_blob(self, digest, stream, u_uid=None):
        # A uuid4() default in the signature would be evaluated only once,
        # at definition time, and then shared by every upload.
        if u_uid is None:
            u_uid = uuid.uuid4()

        if self.instance_name:
            resource_name = '/'.join([self.instance_name, 'uploads', str(u_uid), 'blobs',
                                      digest.hash, str(digest.size_bytes)])
        else:
            resource_name = '/'.join(['uploads', str(u_uid), 'blobs',
                                      digest.hash, str(digest.size_bytes)])

        def request_stream(resname, instream):
            offset = 0
            finished = False
            remaining = digest.size_bytes
            while not finished:
                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
                remaining -= chunk_size

                request = bytestream_pb2.WriteRequest()
                request.write_offset = offset
                # max. _MAX_PAYLOAD_BYTES chunks
                request.data = instream.read(chunk_size)
                request.resource_name = resname
                request.finish_write = remaining <= 0

                yield request

                offset += chunk_size
                finished = request.finish_write

        response = self.bytestream.Write(request_stream(resource_name, stream))

        assert response.committed_size == digest.size_bytes


# Represents a batch of blobs queued for fetching.
#
class _CASBatchRead():
    def __init__(self, remote):
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes
        self._request = remote_execution_pb2.BatchReadBlobsRequest()
        if remote.instance_name:
            self._request.instance_name = remote.instance_name
        self._size = 0
        self._sent = False

    # add():
    #
    # Queue a blob for download.
    #
    # Args:
    #     digest: Digest of the blob to fetch
    #
    # Returns:
    #     (bool): False if the blob does not fit in this batch, else True
    #
    def add(self, digest):
        assert not self._sent

        new_batch_size = self._size + digest.size_bytes
        if new_batch_size > self._max_total_size_bytes:
            # Not enough space left in current batch
            return False

        request_digest = self._request.digests.add()
        request_digest.hash = digest.hash
        request_digest.size_bytes = digest.size_bytes
        self._size = new_batch_size
        return True

    # send():
    #
    # Send the batch and yield a (digest, data) tuple for each blob that
    # was downloaded. This is a generator; a batch can only be sent once.
    #
    # Args:
    #     missing_blobs (list): If given, digests the remote reports as
    #                           NOT_FOUND are appended here and skipped
    #                           instead of raising BlobNotFound
    #
    # Raises:
    #     (BlobNotFound): if a blob is missing and missing_blobs is None
    #     (CASRemoteError): if a blob failed to download for another
    #                       reason, or its size does not match its digest
    #
    def send(self, *, missing_blobs=None):
        assert not self._sent
        self._sent = True

        if not self._request.digests:
            return

        batch_response = self._remote.cas.BatchReadBlobs(self._request)

        for response in batch_response.responses:
            if response.status.code == code_pb2.NOT_FOUND:
                if missing_blobs is None:
                    raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format(
                        response.digest.hash, response.status.code))

                # The caller asked to collect missing blobs: record it and
                # move on, rather than falling through to the generic
                # failure check below.
                missing_blobs.append(response.digest)
                continue

            if response.status.code != code_pb2.OK:
                raise CASRemoteError("Failed to download blob {}: {}".format(
                    response.digest.hash, response.status.code))
            if response.digest.size_bytes != len(response.data):
                raise CASRemoteError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
                    response.digest.hash, response.digest.size_bytes, len(response.data)))

            yield (response.digest, response.data)


# Represents a batch of blobs queued for upload.
#
class _CASBatchUpdate():
    def __init__(self, remote):
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes
        self._request = remote_execution_pb2.BatchUpdateBlobsRequest()
        if remote.instance_name:
            self._request.instance_name = remote.instance_name
        self._size = 0
        self._sent = False

    # add():
    #
    # Queue a blob for upload; its data is read from the stream right away.
    #
    # Args:
    #     digest: Digest of the blob to upload
    #     stream: Readable file object positioned at the start of the blob
    #
    # Returns:
    #     (bool): False if the blob does not fit in this batch, else True
    #
    def add(self, digest, stream):
        assert not self._sent

        new_batch_size = self._size + digest.size_bytes
        if new_batch_size > self._max_total_size_bytes:
            # Not enough space left in current batch
            return False

        blob_request = self._request.requests.add()
        blob_request.digest.hash = digest.hash
        blob_request.digest.size_bytes = digest.size_bytes
        blob_request.data = stream.read(digest.size_bytes)
        self._size = new_batch_size
        return True

    # send():
    #
    # Send the batch to the remote. A batch can only be sent once.
    #
    # Raises:
    #     (CASRemoteError): if the remote reports a failure for any blob
    #
    def send(self):
        assert not self._sent
        self._sent = True

        if not self._request.requests:
            return

        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)

        for response in batch_response.responses:
            if response.status.code != code_pb2.OK:
                raise CASRemoteError("Failed to upload blob {}: {}".format(
                    response.digest.hash, response.status.code))