Diffstat (limited to 'buildstream/_artifactcache/cascache.py')
-rw-r--r--  buildstream/_artifactcache/cascache.py | 225
1 file changed, 196 insertions(+), 29 deletions(-)
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py
index b6b1d436d..e2c0d44b5 100644
--- a/buildstream/_artifactcache/cascache.py
+++ b/buildstream/_artifactcache/cascache.py
@@ -43,6 +43,11 @@ from .._exceptions import ArtifactError
from . import ArtifactCache
+# The default limit for gRPC messages is 4 MiB.
+# Limit payload to 1 MiB to leave sufficient headroom for metadata.
+_MAX_PAYLOAD_BYTES = 1024 * 1024
+
+
# A CASCache manages artifacts in a CAS repository as specified in the
# Remote Execution API.
#
@@ -115,7 +120,7 @@ class CASCache(ArtifactCache):
def commit(self, element, content, keys):
refs = [self.get_artifact_fullname(element, key) for key in keys]
- tree = self._create_tree(content)
+ tree = self._commit_directory(content)
for ref in refs:
self.set_ref(ref, tree)
@@ -330,12 +335,12 @@ class CASCache(ArtifactCache):
finished = False
remaining = digest.size_bytes
while not finished:
- chunk_size = min(remaining, 64 * 1024)
+ chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
remaining -= chunk_size
request = bytestream_pb2.WriteRequest()
request.write_offset = offset
- # max. 64 kB chunks
+ # max. _MAX_PAYLOAD_BYTES chunks
request.data = f.read(chunk_size)
request.resource_name = resname
request.finish_write = remaining <= 0
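(The WriteRequest loop above is normally wrapped in a request generator that the client-streaming ByteStream Write() stub consumes directly; a minimal sketch, where request_stream() is a hypothetical name for that generator and the remaining names come from the code above:)

    # Hypothetical sketch: hand the WriteRequest generator to the
    # client-streaming Write() RPC and verify the committed size.
    response = remote.bytestream.Write(request_stream(resname))
    assert response.committed_size == digest.size_bytes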
@@ -623,7 +628,21 @@ class CASCache(ArtifactCache):
def _refpath(self, ref):
return os.path.join(self.casdir, 'refs', 'heads', ref)
- def _create_tree(self, path, *, digest=None):
+ # _commit_directory():
+ #
+ # Adds local directory to content addressable store.
+ #
+ # Adds files, symbolic links and recursively other directories in
+ # a local directory to the content addressable store.
+ #
+ # Args:
+ # path (str): Path to the directory to add.
+ # dir_digest (Digest): An optional Digest object to use.
+ #
+ # Returns:
+ # (Digest): Digest object for the directory added.
+ #
+ def _commit_directory(self, path, *, dir_digest=None):
directory = remote_execution_pb2.Directory()
for name in sorted(os.listdir(path)):
@@ -632,7 +651,7 @@ class CASCache(ArtifactCache):
if stat.S_ISDIR(mode):
dirnode = directory.directories.add()
dirnode.name = name
- self._create_tree(full_path, digest=dirnode.digest)
+ self._commit_directory(full_path, dir_digest=dirnode.digest)
elif stat.S_ISREG(mode):
filenode = directory.files.add()
filenode.name = name
@@ -645,7 +664,8 @@ class CASCache(ArtifactCache):
else:
raise ArtifactError("Unsupported file type for {}".format(full_path))
- return self.add_object(digest=digest, buffer=directory.SerializeToString())
+ return self.add_object(digest=dir_digest,
+ buffer=directory.SerializeToString())
def _get_subdir(self, tree, subdir):
head, name = os.path.split(subdir)
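For reference, the Directory message that _commit_directory() serializes above has this shape in the remote_execution_pb2 protos (a minimal sketch with illustrative names and digests):

    directory = remote_execution_pb2.Directory()

    filenode = directory.files.add()
    filenode.name = 'hello.txt'                # illustrative file entry
    filenode.digest.hash = 'abc123...'         # illustrative content hash
    filenode.digest.size_bytes = 12
    filenode.is_executable = False

    dirnode = directory.directories.add()
    dirnode.name = 'lib'                       # illustrative subdirectory
    dirnode.digest.CopyFrom(subdir_digest)     # Digest of the child Directory blob

    blob = directory.SerializeToString()       # this is what add_object() stores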
@@ -788,39 +808,119 @@ class CASCache(ArtifactCache):
out.flush()
assert digest.size_bytes == os.fstat(out.fileno()).st_size
- def _fetch_directory(self, remote, tree):
- objpath = self.objpath(tree)
+ # _ensure_blob():
+ #
+ # Fetch and add blob if it's not already local.
+ #
+ # Args:
+ # remote (Remote): The remote to use.
+ # digest (Digest): Digest object for the blob to fetch.
+ #
+ # Returns:
+ # (str): The path of the object
+ #
+ def _ensure_blob(self, remote, digest):
+ objpath = self.objpath(digest)
if os.path.exists(objpath):
- # already in local cache
- return
+ # already in local repository
+ return objpath
- with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
- self._fetch_blob(remote, tree, out)
+ with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
+ self._fetch_blob(remote, digest, f)
- directory = remote_execution_pb2.Directory()
+ added_digest = self.add_object(path=f.name)
+ assert added_digest.hash == digest.hash
- with open(out.name, 'rb') as f:
- directory.ParseFromString(f.read())
+ return objpath
- for filenode in directory.files:
- fileobjpath = self.objpath(tree)
- if os.path.exists(fileobjpath):
- # already in local cache
- continue
+ def _batch_download_complete(self, batch):
+ for digest, data in batch.send():
+ with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
+ f.write(data)
+ f.flush()
+
+ added_digest = self.add_object(path=f.name)
+ assert added_digest.hash == digest.hash
+
+ # Helper function for _fetch_directory().
+ def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
+ self._batch_download_complete(batch)
- with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
- self._fetch_blob(remote, filenode.digest, f)
+ # All previously scheduled directories are now locally available,
+ # move them to the processing queue.
+ fetch_queue.extend(fetch_next_queue)
+ fetch_next_queue.clear()
+ return _CASBatchRead(remote)
- digest = self.add_object(path=f.name)
- assert digest.hash == filenode.digest.hash
+ # Helper function for _fetch_directory().
+ def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
+ in_local_cache = os.path.exists(self.objpath(digest))
+
+ if in_local_cache:
+ # Skip download, already in local cache.
+ pass
+ elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
+ not remote.batch_read_supported):
+ # Too large for batch request, download in independent request.
+ self._ensure_blob(remote, digest)
+ in_local_cache = True
+ else:
+ if not batch.add(digest):
+ # Not enough space left in batch request.
+ # Complete pending batch first.
+ batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
+ batch.add(digest)
+
+ if recursive:
+ if in_local_cache:
+ # Add directory to processing queue.
+ fetch_queue.append(digest)
+ else:
+ # Directory will be available after completing pending batch.
+ # Add directory to deferred processing queue.
+ fetch_next_queue.append(digest)
+
+ return batch
+
+ # _fetch_directory():
+ #
+ # Fetches remote directory and adds it to content addressable store.
+ #
+ # Fetches files, symbolic links and recursively other directories in
+ # the remote directory and adds them to the content addressable
+ # store.
+ #
+ # Args:
+ # remote (Remote): The remote to use.
+ # dir_digest (Digest): Digest object for the directory to fetch.
+ #
+ def _fetch_directory(self, remote, dir_digest):
+ fetch_queue = [dir_digest]
+ fetch_next_queue = []
+ batch = _CASBatchRead(remote)
+
+ while len(fetch_queue) + len(fetch_next_queue) > 0:
+ if len(fetch_queue) == 0:
+ batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
+
+ dir_digest = fetch_queue.pop(0)
+
+ objpath = self._ensure_blob(remote, dir_digest)
+
+ directory = remote_execution_pb2.Directory()
+ with open(objpath, 'rb') as f:
+ directory.ParseFromString(f.read())
for dirnode in directory.directories:
- self._fetch_directory(remote, dirnode.digest)
+ batch = self._fetch_directory_node(remote, dirnode.digest, batch,
+ fetch_queue, fetch_next_queue, recursive=True)
+
+ for filenode in directory.files:
+ batch = self._fetch_directory_node(remote, filenode.digest, batch,
+ fetch_queue, fetch_next_queue)
- # place directory blob only in final location when we've downloaded
- # all referenced blobs to avoid dangling references in the repository
- digest = self.add_object(path=out.name)
- assert digest.hash == tree.hash
+ # Fetch final batch
+ self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
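In outline, _fetch_directory() above performs a breadth-first walk of the remote tree while batching small blobs (a simplified summary of the control flow, not additional code):

    # 1. Start with fetch_queue = [dir_digest] for the root directory.
    # 2. Pop a directory digest, _ensure_blob() it, and parse the Directory proto.
    # 3. Child directories that are already local go straight onto fetch_queue;
    #    oversized ones are fetched immediately via _ensure_blob() and also queued;
    #    the rest are added to the current _CASBatchRead and parked on
    #    fetch_next_queue until that batch is sent.
    # 4. File blobs are added to the same batch, which is flushed whenever full.
    # 5. When fetch_queue drains, the pending batch is sent and fetch_next_queue
    #    is promoted; the final partial batch is flushed after the loop ends.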
# Represents a single remote CAS cache.
@@ -870,11 +970,78 @@ class _CASRemote():
self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+ self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
+ self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
+ try:
+ request = remote_execution_pb2.GetCapabilitiesRequest()
+ response = self.capabilities.GetCapabilities(request)
+ server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
+ if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
+ self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
+ except grpc.RpcError as e:
+ # Simply use the defaults for servers that don't implement GetCapabilities()
+ if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+ raise
+
+ # Check whether the server supports BatchReadBlobs()
+ self.batch_read_supported = False
+ try:
+ request = remote_execution_pb2.BatchReadBlobsRequest()
+ response = self.cas.BatchReadBlobs(request)
+ self.batch_read_supported = True
+ except grpc.RpcError as e:
+ if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+ raise
+
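A worked example of the batch-size negotiation above (values are illustrative; the client limit is only ever lowered, never raised):

    # server reports 0 (capability unset)          -> keep the 1 MiB client default
    # server reports 512 * 1024 (512 KiB)          -> lower the limit to 512 KiB
    # server reports 16 * 1024 * 1024 (16 MiB)     -> keep 1 MiB for payload headroom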
self._initialized = True
+# Represents a batch of blobs queued for fetching.
+#
+class _CASBatchRead():
+ def __init__(self, remote):
+ self._remote = remote
+ self._max_total_size_bytes = remote.max_batch_total_size_bytes
+ self._request = remote_execution_pb2.BatchReadBlobsRequest()
+ self._size = 0
+ self._sent = False
+
+ def add(self, digest):
+ assert not self._sent
+
+ new_batch_size = self._size + digest.size_bytes
+ if new_batch_size > self._max_total_size_bytes:
+ # Not enough space left in current batch
+ return False
+
+ request_digest = self._request.digests.add()
+ request_digest.hash = digest.hash
+ request_digest.size_bytes = digest.size_bytes
+ self._size = new_batch_size
+ return True
+
+ def send(self):
+ assert not self._sent
+ self._sent = True
+
+ if len(self._request.digests) == 0:
+ return
+
+ batch_response = self._remote.cas.BatchReadBlobs(self._request)
+
+ for response in batch_response.responses:
+ if response.status.code != grpc.StatusCode.OK.value[0]:
+ raise ArtifactError("Failed to download blob {}: {}".format(
+ response.digest.hash, response.status.code))
+ if response.digest.size_bytes != len(response.data):
+ raise ArtifactError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
+ response.digest.hash, response.digest.size_bytes, len(response.data)))
+
+ yield (response.digest, response.data)
+
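A minimal usage sketch of _CASBatchRead, assuming remote is an initialized _CASRemote; digests_to_fetch and store_blob() are hypothetical placeholders (oversized blobs are handled separately via _ensure_blob() in the code above):

    batch = _CASBatchRead(remote)
    for digest in digests_to_fetch:
        if not batch.add(digest):
            # Current batch is full: send it, store the blobs, start a new one.
            for received_digest, data in batch.send():
                store_blob(received_digest, data)
            batch = _CASBatchRead(remote)
            batch.add(digest)
    # Flush the final partial batch.
    for received_digest, data in batch.send():
        store_blob(received_digest, data)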
+
def _grouper(iterable, n):
while True:
try: