diff options
Diffstat (limited to 'buildstream/_cas/cascache.py')
-rw-r--r-- | buildstream/_cas/cascache.py | 275 |
1 files changed, 19 insertions, 256 deletions
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py index 482d4006f..5f62e6105 100644 --- a/buildstream/_cas/cascache.py +++ b/buildstream/_cas/cascache.py @@ -17,7 +17,6 @@ # Authors: # Jürg Billeter <juerg.billeter@codethink.co.uk> -from collections import namedtuple import hashlib import itertools import io @@ -26,76 +25,17 @@ import stat import tempfile import uuid import contextlib -from urllib.parse import urlparse import grpc -from .._protos.google.rpc import code_pb2 -from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc -from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc -from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc +from .._protos.google.bytestream import bytestream_pb2 +from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 +from .._protos.buildstream.v2 import buildstream_pb2 from .. import utils -from .._exceptions import CASError, LoadError, LoadErrorReason -from .. import _yaml +from .._exceptions import CASCacheError - -# The default limit for gRPC messages is 4 MiB. -# Limit payload to 1 MiB to leave sufficient headroom for metadata. -_MAX_PAYLOAD_BYTES = 1024 * 1024 - - -class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')): - - # _new_from_config_node - # - # Creates an CASRemoteSpec() from a YAML loaded node - # - @staticmethod - def _new_from_config_node(spec_node, basedir=None): - _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name']) - url = _yaml.node_get(spec_node, str, 'url') - push = _yaml.node_get(spec_node, bool, 'push', default_value=False) - if not url: - provenance = _yaml.node_get_provenance(spec_node, 'url') - raise LoadError(LoadErrorReason.INVALID_DATA, - "{}: empty artifact cache URL".format(provenance)) - - instance_name = _yaml.node_get(spec_node, str, 'instance-name', default_value=None) - - server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None) - if server_cert and basedir: - server_cert = os.path.join(basedir, server_cert) - - client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None) - if client_key and basedir: - client_key = os.path.join(basedir, client_key) - - client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None) - if client_cert and basedir: - client_cert = os.path.join(basedir, client_cert) - - if client_key and not client_cert: - provenance = _yaml.node_get_provenance(spec_node, 'client-key') - raise LoadError(LoadErrorReason.INVALID_DATA, - "{}: 'client-key' was specified without 'client-cert'".format(provenance)) - - if client_cert and not client_key: - provenance = _yaml.node_get_provenance(spec_node, 'client-cert') - raise LoadError(LoadErrorReason.INVALID_DATA, - "{}: 'client-cert' was specified without 'client-key'".format(provenance)) - - return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name) - - -CASRemoteSpec.__new__.__defaults__ = (None, None, None, None) - - -class BlobNotFound(CASError): - - def __init__(self, blob, msg): - self.blob = blob - super().__init__(msg) +from .casremote import CASRemote, BlobNotFound, _CASBatchRead, _CASBatchUpdate, _MAX_PAYLOAD_BYTES # A CASCache manages a CAS repository as specified in the Remote Execution API. @@ -120,7 +60,7 @@ class CASCache(): headdir = os.path.join(self.casdir, 'refs', 'heads') objdir = os.path.join(self.casdir, 'objects') if not (os.path.isdir(headdir) and os.path.isdir(objdir)): - raise CASError("CAS repository check failed for '{}'".format(self.casdir)) + raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir)) # contains(): # @@ -169,7 +109,7 @@ class CASCache(): # subdir (str): Optional specific dir to extract # # Raises: - # CASError: In cases there was an OSError, or if the ref did not exist. + # CASCacheError: In cases there was an OSError, or if the ref did not exist. # # Returns: path to extracted directory # @@ -201,7 +141,7 @@ class CASCache(): # Another process beat us to rename pass except OSError as e: - raise CASError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e + raise CASCacheError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e return originaldest @@ -306,7 +246,7 @@ class CASCache(): return True except grpc.RpcError as e: if e.code() != grpc.StatusCode.NOT_FOUND: - raise CASError("Failed to pull ref {}: {}".format(ref, e)) from e + raise CASCacheError("Failed to pull ref {}: {}".format(ref, e)) from e else: return False except BlobNotFound as e: @@ -360,7 +300,7 @@ class CASCache(): # (bool): True if any remote was updated, False if no pushes were required # # Raises: - # (CASError): if there was an error + # (CASCacheError): if there was an error # def push(self, refs, remote): skipped_remote = True @@ -395,7 +335,7 @@ class CASCache(): skipped_remote = False except grpc.RpcError as e: if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED: - raise CASError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e + raise CASCacheError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e return not skipped_remote @@ -408,7 +348,7 @@ class CASCache(): # directory (Directory): A virtual directory object to push. # # Raises: - # (CASError): if there was an error + # (CASCacheError): if there was an error # def push_directory(self, remote, directory): remote.init() @@ -424,7 +364,7 @@ class CASCache(): # message (Message): A protobuf message to push. # # Raises: - # (CASError): if there was an error + # (CASCacheError): if there was an error # def push_message(self, remote, message): @@ -531,7 +471,7 @@ class CASCache(): pass except OSError as e: - raise CASError("Failed to hash object: {}".format(e)) from e + raise CASCacheError("Failed to hash object: {}".format(e)) from e return digest @@ -572,7 +512,7 @@ class CASCache(): return digest except FileNotFoundError as e: - raise CASError("Attempt to access unavailable ref: {}".format(e)) from e + raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e # update_mtime() # @@ -585,7 +525,7 @@ class CASCache(): try: os.utime(self._refpath(ref)) except FileNotFoundError as e: - raise CASError("Attempt to access unavailable ref: {}".format(e)) from e + raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e # calculate_cache_size() # @@ -676,7 +616,7 @@ class CASCache(): # Remove cache ref refpath = self._refpath(ref) if not os.path.exists(refpath): - raise CASError("Could not find ref '{}'".format(ref)) + raise CASCacheError("Could not find ref '{}'".format(ref)) os.unlink(refpath) @@ -792,7 +732,7 @@ class CASCache(): # The process serving the socket can't be cached anyway pass else: - raise CASError("Unsupported file type for {}".format(full_path)) + raise CASCacheError("Unsupported file type for {}".format(full_path)) return self.add_object(digest=dir_digest, buffer=directory.SerializeToString()) @@ -811,7 +751,7 @@ class CASCache(): if dirnode.name == name: return dirnode.digest - raise CASError("Subdirectory {} not found".format(name)) + raise CASCacheError("Subdirectory {} not found".format(name)) def _diff_trees(self, tree_a, tree_b, *, added, removed, modified, path=""): dir_a = remote_execution_pb2.Directory() @@ -1150,183 +1090,6 @@ class CASCache(): batch.send() -# Represents a single remote CAS cache. -# -class CASRemote(): - def __init__(self, spec): - self.spec = spec - self._initialized = False - self.channel = None - self.bytestream = None - self.cas = None - self.ref_storage = None - self.batch_update_supported = None - self.batch_read_supported = None - self.capabilities = None - self.max_batch_total_size_bytes = None - - def init(self): - if not self._initialized: - url = urlparse(self.spec.url) - if url.scheme == 'http': - port = url.port or 80 - self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port)) - elif url.scheme == 'https': - port = url.port or 443 - - if self.spec.server_cert: - with open(self.spec.server_cert, 'rb') as f: - server_cert_bytes = f.read() - else: - server_cert_bytes = None - - if self.spec.client_key: - with open(self.spec.client_key, 'rb') as f: - client_key_bytes = f.read() - else: - client_key_bytes = None - - if self.spec.client_cert: - with open(self.spec.client_cert, 'rb') as f: - client_cert_bytes = f.read() - else: - client_cert_bytes = None - - credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes, - private_key=client_key_bytes, - certificate_chain=client_cert_bytes) - self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials) - else: - raise CASError("Unsupported URL: {}".format(self.spec.url)) - - self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel) - self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) - self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel) - self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel) - - self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES - try: - request = remote_execution_pb2.GetCapabilitiesRequest(instance_name=self.spec.instance_name) - response = self.capabilities.GetCapabilities(request) - server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes - if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes: - self.max_batch_total_size_bytes = server_max_batch_total_size_bytes - except grpc.RpcError as e: - # Simply use the defaults for servers that don't implement GetCapabilities() - if e.code() != grpc.StatusCode.UNIMPLEMENTED: - raise - - # Check whether the server supports BatchReadBlobs() - self.batch_read_supported = False - try: - request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=self.spec.instance_name) - response = self.cas.BatchReadBlobs(request) - self.batch_read_supported = True - except grpc.RpcError as e: - if e.code() != grpc.StatusCode.UNIMPLEMENTED: - raise - - # Check whether the server supports BatchUpdateBlobs() - self.batch_update_supported = False - try: - request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self.spec.instance_name) - response = self.cas.BatchUpdateBlobs(request) - self.batch_update_supported = True - except grpc.RpcError as e: - if (e.code() != grpc.StatusCode.UNIMPLEMENTED and - e.code() != grpc.StatusCode.PERMISSION_DENIED): - raise - - self._initialized = True - - -# Represents a batch of blobs queued for fetching. -# -class _CASBatchRead(): - def __init__(self, remote): - self._remote = remote - self._max_total_size_bytes = remote.max_batch_total_size_bytes - self._request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=remote.spec.instance_name) - self._size = 0 - self._sent = False - - def add(self, digest): - assert not self._sent - - new_batch_size = self._size + digest.size_bytes - if new_batch_size > self._max_total_size_bytes: - # Not enough space left in current batch - return False - - request_digest = self._request.digests.add() - request_digest.hash = digest.hash - request_digest.size_bytes = digest.size_bytes - self._size = new_batch_size - return True - - def send(self): - assert not self._sent - self._sent = True - - if not self._request.digests: - return - - batch_response = self._remote.cas.BatchReadBlobs(self._request) - - for response in batch_response.responses: - if response.status.code == code_pb2.NOT_FOUND: - raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format( - response.digest.hash, response.status.code)) - if response.status.code != code_pb2.OK: - raise CASError("Failed to download blob {}: {}".format( - response.digest.hash, response.status.code)) - if response.digest.size_bytes != len(response.data): - raise CASError("Failed to download blob {}: expected {} bytes, received {} bytes".format( - response.digest.hash, response.digest.size_bytes, len(response.data))) - - yield (response.digest, response.data) - - -# Represents a batch of blobs queued for upload. -# -class _CASBatchUpdate(): - def __init__(self, remote): - self._remote = remote - self._max_total_size_bytes = remote.max_batch_total_size_bytes - self._request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=remote.spec.instance_name) - self._size = 0 - self._sent = False - - def add(self, digest, stream): - assert not self._sent - - new_batch_size = self._size + digest.size_bytes - if new_batch_size > self._max_total_size_bytes: - # Not enough space left in current batch - return False - - blob_request = self._request.requests.add() - blob_request.digest.hash = digest.hash - blob_request.digest.size_bytes = digest.size_bytes - blob_request.data = stream.read(digest.size_bytes) - self._size = new_batch_size - return True - - def send(self): - assert not self._sent - self._sent = True - - if not self._request.requests: - return - - batch_response = self._remote.cas.BatchUpdateBlobs(self._request) - - for response in batch_response.responses: - if response.status.code != code_pb2.OK: - raise CASError("Failed to upload blob {}: {}".format( - response.digest.hash, response.status.code)) - - def _grouper(iterable, n): while True: try: |