Diffstat (limited to 'buildstream/_cas/cascache.py')
-rw-r--r--   buildstream/_cas/cascache.py   275
1 file changed, 19 insertions(+), 256 deletions(-)
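
For orientation, a minimal sketch of how a caller inside buildstream/_cas would pick up the relocated names after this change: CASRemote, the batch helpers and BlobNotFound now live in casremote, while cascache raises CASCacheError for local repository problems. The import paths follow the diff below; the helper function, the pull() call shape and the way `spec` is obtained are illustrative assumptions, not part of this commit.

# Hypothetical caller inside buildstream/_cas after this refactor.
from .cascache import CASCache
from .casremote import CASRemote, BlobNotFound


def try_pull(cascache: CASCache, spec, ref: str) -> bool:
    remote = CASRemote(spec)   # remote-side class is now defined in casremote.py
    remote.init()              # opens the gRPC channel and probes server capabilities
    try:
        # pull() call shape assumed from this version of cascache.py
        return cascache.pull(ref, remote)
    except BlobNotFound:
        # A blob missing on the remote is a cache miss, not a fatal error
        return False
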
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index 482d4006f..5f62e6105 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -17,7 +17,6 @@
# Authors:
# Jürg Billeter <juerg.billeter@codethink.co.uk>
-from collections import namedtuple
import hashlib
import itertools
import io
@@ -26,76 +25,17 @@ import stat
import tempfile
import uuid
import contextlib
-from urllib.parse import urlparse
import grpc
-from .._protos.google.rpc import code_pb2
-from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
-from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
+from .._protos.google.bytestream import bytestream_pb2
+from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
+from .._protos.buildstream.v2 import buildstream_pb2
from .. import utils
-from .._exceptions import CASError, LoadError, LoadErrorReason
-from .. import _yaml
+from .._exceptions import CASCacheError
-
-# The default limit for gRPC messages is 4 MiB.
-# Limit payload to 1 MiB to leave sufficient headroom for metadata.
-_MAX_PAYLOAD_BYTES = 1024 * 1024
-
-
-class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):
-
- # _new_from_config_node
- #
- # Creates an CASRemoteSpec() from a YAML loaded node
- #
- @staticmethod
- def _new_from_config_node(spec_node, basedir=None):
- _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name'])
- url = _yaml.node_get(spec_node, str, 'url')
- push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
- if not url:
- provenance = _yaml.node_get_provenance(spec_node, 'url')
- raise LoadError(LoadErrorReason.INVALID_DATA,
- "{}: empty artifact cache URL".format(provenance))
-
- instance_name = _yaml.node_get(spec_node, str, 'instance-name', default_value=None)
-
- server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
- if server_cert and basedir:
- server_cert = os.path.join(basedir, server_cert)
-
- client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
- if client_key and basedir:
- client_key = os.path.join(basedir, client_key)
-
- client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
- if client_cert and basedir:
- client_cert = os.path.join(basedir, client_cert)
-
- if client_key and not client_cert:
- provenance = _yaml.node_get_provenance(spec_node, 'client-key')
- raise LoadError(LoadErrorReason.INVALID_DATA,
- "{}: 'client-key' was specified without 'client-cert'".format(provenance))
-
- if client_cert and not client_key:
- provenance = _yaml.node_get_provenance(spec_node, 'client-cert')
- raise LoadError(LoadErrorReason.INVALID_DATA,
- "{}: 'client-cert' was specified without 'client-key'".format(provenance))
-
- return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name)
-
-
-CASRemoteSpec.__new__.__defaults__ = (None, None, None, None)
-
-
-class BlobNotFound(CASError):
-
- def __init__(self, blob, msg):
- self.blob = blob
- super().__init__(msg)
+from .casremote import CASRemote, BlobNotFound, _CASBatchRead, _CASBatchUpdate, _MAX_PAYLOAD_BYTES
# A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -120,7 +60,7 @@ class CASCache():
headdir = os.path.join(self.casdir, 'refs', 'heads')
objdir = os.path.join(self.casdir, 'objects')
if not (os.path.isdir(headdir) and os.path.isdir(objdir)):
- raise CASError("CAS repository check failed for '{}'".format(self.casdir))
+ raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir))
# contains():
#
@@ -169,7 +109,7 @@ class CASCache():
# subdir (str): Optional specific dir to extract
#
# Raises:
- # CASError: In cases there was an OSError, or if the ref did not exist.
+ # CASCacheError: In cases there was an OSError, or if the ref did not exist.
#
# Returns: path to extracted directory
#
@@ -201,7 +141,7 @@ class CASCache():
# Another process beat us to rename
pass
except OSError as e:
- raise CASError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e
+ raise CASCacheError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e
return originaldest
@@ -306,7 +246,7 @@ class CASCache():
return True
except grpc.RpcError as e:
if e.code() != grpc.StatusCode.NOT_FOUND:
- raise CASError("Failed to pull ref {}: {}".format(ref, e)) from e
+ raise CASCacheError("Failed to pull ref {}: {}".format(ref, e)) from e
else:
return False
except BlobNotFound as e:
@@ -360,7 +300,7 @@ class CASCache():
# (bool): True if any remote was updated, False if no pushes were required
#
# Raises:
- # (CASError): if there was an error
+ # (CASCacheError): if there was an error
#
def push(self, refs, remote):
skipped_remote = True
@@ -395,7 +335,7 @@ class CASCache():
skipped_remote = False
except grpc.RpcError as e:
if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED:
- raise CASError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e
+ raise CASCacheError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e
return not skipped_remote
@@ -408,7 +348,7 @@ class CASCache():
# directory (Directory): A virtual directory object to push.
#
# Raises:
- # (CASError): if there was an error
+ # (CASCacheError): if there was an error
#
def push_directory(self, remote, directory):
remote.init()
@@ -424,7 +364,7 @@ class CASCache():
# message (Message): A protobuf message to push.
#
# Raises:
- # (CASError): if there was an error
+ # (CASCacheError): if there was an error
#
def push_message(self, remote, message):
@@ -531,7 +471,7 @@ class CASCache():
pass
except OSError as e:
- raise CASError("Failed to hash object: {}".format(e)) from e
+ raise CASCacheError("Failed to hash object: {}".format(e)) from e
return digest
@@ -572,7 +512,7 @@ class CASCache():
return digest
except FileNotFoundError as e:
- raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
+ raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
# update_mtime()
#
@@ -585,7 +525,7 @@ class CASCache():
try:
os.utime(self._refpath(ref))
except FileNotFoundError as e:
- raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
+ raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
# calculate_cache_size()
#
@@ -676,7 +616,7 @@ class CASCache():
# Remove cache ref
refpath = self._refpath(ref)
if not os.path.exists(refpath):
- raise CASError("Could not find ref '{}'".format(ref))
+ raise CASCacheError("Could not find ref '{}'".format(ref))
os.unlink(refpath)
@@ -792,7 +732,7 @@ class CASCache():
# The process serving the socket can't be cached anyway
pass
else:
- raise CASError("Unsupported file type for {}".format(full_path))
+ raise CASCacheError("Unsupported file type for {}".format(full_path))
return self.add_object(digest=dir_digest,
buffer=directory.SerializeToString())
@@ -811,7 +751,7 @@ class CASCache():
if dirnode.name == name:
return dirnode.digest
- raise CASError("Subdirectory {} not found".format(name))
+ raise CASCacheError("Subdirectory {} not found".format(name))
def _diff_trees(self, tree_a, tree_b, *, added, removed, modified, path=""):
dir_a = remote_execution_pb2.Directory()
@@ -1150,183 +1090,6 @@ class CASCache():
batch.send()
-# Represents a single remote CAS cache.
-#
-class CASRemote():
- def __init__(self, spec):
- self.spec = spec
- self._initialized = False
- self.channel = None
- self.bytestream = None
- self.cas = None
- self.ref_storage = None
- self.batch_update_supported = None
- self.batch_read_supported = None
- self.capabilities = None
- self.max_batch_total_size_bytes = None
-
- def init(self):
- if not self._initialized:
- url = urlparse(self.spec.url)
- if url.scheme == 'http':
- port = url.port or 80
- self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
- elif url.scheme == 'https':
- port = url.port or 443
-
- if self.spec.server_cert:
- with open(self.spec.server_cert, 'rb') as f:
- server_cert_bytes = f.read()
- else:
- server_cert_bytes = None
-
- if self.spec.client_key:
- with open(self.spec.client_key, 'rb') as f:
- client_key_bytes = f.read()
- else:
- client_key_bytes = None
-
- if self.spec.client_cert:
- with open(self.spec.client_cert, 'rb') as f:
- client_cert_bytes = f.read()
- else:
- client_cert_bytes = None
-
- credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
- private_key=client_key_bytes,
- certificate_chain=client_cert_bytes)
- self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
- else:
- raise CASError("Unsupported URL: {}".format(self.spec.url))
-
- self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
- self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
- self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
- self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
-
- self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
- try:
- request = remote_execution_pb2.GetCapabilitiesRequest(instance_name=self.spec.instance_name)
- response = self.capabilities.GetCapabilities(request)
- server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
- if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
- self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
- except grpc.RpcError as e:
- # Simply use the defaults for servers that don't implement GetCapabilities()
- if e.code() != grpc.StatusCode.UNIMPLEMENTED:
- raise
-
- # Check whether the server supports BatchReadBlobs()
- self.batch_read_supported = False
- try:
- request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=self.spec.instance_name)
- response = self.cas.BatchReadBlobs(request)
- self.batch_read_supported = True
- except grpc.RpcError as e:
- if e.code() != grpc.StatusCode.UNIMPLEMENTED:
- raise
-
- # Check whether the server supports BatchUpdateBlobs()
- self.batch_update_supported = False
- try:
- request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self.spec.instance_name)
- response = self.cas.BatchUpdateBlobs(request)
- self.batch_update_supported = True
- except grpc.RpcError as e:
- if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
- e.code() != grpc.StatusCode.PERMISSION_DENIED):
- raise
-
- self._initialized = True
-
-
-# Represents a batch of blobs queued for fetching.
-#
-class _CASBatchRead():
- def __init__(self, remote):
- self._remote = remote
- self._max_total_size_bytes = remote.max_batch_total_size_bytes
- self._request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=remote.spec.instance_name)
- self._size = 0
- self._sent = False
-
- def add(self, digest):
- assert not self._sent
-
- new_batch_size = self._size + digest.size_bytes
- if new_batch_size > self._max_total_size_bytes:
- # Not enough space left in current batch
- return False
-
- request_digest = self._request.digests.add()
- request_digest.hash = digest.hash
- request_digest.size_bytes = digest.size_bytes
- self._size = new_batch_size
- return True
-
- def send(self):
- assert not self._sent
- self._sent = True
-
- if not self._request.digests:
- return
-
- batch_response = self._remote.cas.BatchReadBlobs(self._request)
-
- for response in batch_response.responses:
- if response.status.code == code_pb2.NOT_FOUND:
- raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format(
- response.digest.hash, response.status.code))
- if response.status.code != code_pb2.OK:
- raise CASError("Failed to download blob {}: {}".format(
- response.digest.hash, response.status.code))
- if response.digest.size_bytes != len(response.data):
- raise CASError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
- response.digest.hash, response.digest.size_bytes, len(response.data)))
-
- yield (response.digest, response.data)
-
-
-# Represents a batch of blobs queued for upload.
-#
-class _CASBatchUpdate():
- def __init__(self, remote):
- self._remote = remote
- self._max_total_size_bytes = remote.max_batch_total_size_bytes
- self._request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=remote.spec.instance_name)
- self._size = 0
- self._sent = False
-
- def add(self, digest, stream):
- assert not self._sent
-
- new_batch_size = self._size + digest.size_bytes
- if new_batch_size > self._max_total_size_bytes:
- # Not enough space left in current batch
- return False
-
- blob_request = self._request.requests.add()
- blob_request.digest.hash = digest.hash
- blob_request.digest.size_bytes = digest.size_bytes
- blob_request.data = stream.read(digest.size_bytes)
- self._size = new_batch_size
- return True
-
- def send(self):
- assert not self._sent
- self._sent = True
-
- if not self._request.requests:
- return
-
- batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
-
- for response in batch_response.responses:
- if response.status.code != code_pb2.OK:
- raise CASError("Failed to upload blob {}: {}".format(
- response.digest.hash, response.status.code))
-
-
def _grouper(iterable, n):
while True:
try: