Diffstat (limited to 'buildstream/_artifactcache/cascache.py')
-rw-r--r--  buildstream/_artifactcache/cascache.py | 225
1 file changed, 196 insertions(+), 29 deletions(-)
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py
index b6b1d436d..e2c0d44b5 100644
--- a/buildstream/_artifactcache/cascache.py
+++ b/buildstream/_artifactcache/cascache.py
@@ -43,6 +43,11 @@ from .._exceptions import ArtifactError
 from . import ArtifactCache
 
 
+# The default limit for gRPC messages is 4 MiB.
+# Limit payload to 1 MiB to leave sufficient headroom for metadata.
+_MAX_PAYLOAD_BYTES = 1024 * 1024
+
+
 # A CASCache manages artifacts in a CAS repository as specified in the
 # Remote Execution API.
 #
@@ -115,7 +120,7 @@ class CASCache(ArtifactCache):
     def commit(self, element, content, keys):
         refs = [self.get_artifact_fullname(element, key) for key in keys]
 
-        tree = self._create_tree(content)
+        tree = self._commit_directory(content)
 
         for ref in refs:
             self.set_ref(ref, tree)
@@ -330,12 +335,12 @@ class CASCache(ArtifactCache):
                                 finished = False
                                 remaining = digest.size_bytes
                                 while not finished:
-                                    chunk_size = min(remaining, 64 * 1024)
+                                    chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
                                     remaining -= chunk_size
 
                                     request = bytestream_pb2.WriteRequest()
                                     request.write_offset = offset
-                                    # max. 64 kB chunks
+                                    # max. _MAX_PAYLOAD_BYTES chunks
                                     request.data = f.read(chunk_size)
                                     request.resource_name = resname
                                     request.finish_write = remaining <= 0
@@ -623,7 +628,21 @@ class CASCache(ArtifactCache):
     def _refpath(self, ref):
         return os.path.join(self.casdir, 'refs', 'heads', ref)
 
-    def _create_tree(self, path, *, digest=None):
+    # _commit_directory():
+    #
+    # Adds local directory to content addressable store.
+    #
+    # Adds files, symbolic links and recursively other directories in
+    # a local directory to the content addressable store.
+    #
+    # Args:
+    #     path (str): Path to the directory to add.
+    #     dir_digest (Digest): An optional Digest object to use.
+    #
+    # Returns:
+    #     (Digest): Digest object for the directory added.
+    #
+    def _commit_directory(self, path, *, dir_digest=None):
         directory = remote_execution_pb2.Directory()
 
         for name in sorted(os.listdir(path)):
@@ -632,7 +651,7 @@ class CASCache(ArtifactCache):
             if stat.S_ISDIR(mode):
                 dirnode = directory.directories.add()
                 dirnode.name = name
-                self._create_tree(full_path, digest=dirnode.digest)
+                self._commit_directory(full_path, dir_digest=dirnode.digest)
             elif stat.S_ISREG(mode):
                 filenode = directory.files.add()
                 filenode.name = name
@@ -645,7 +664,8 @@ class CASCache(ArtifactCache):
         else:
             raise ArtifactError("Unsupported file type for {}".format(full_path))
 
-        return self.add_object(digest=digest, buffer=directory.SerializeToString())
+        return self.add_object(digest=dir_digest,
+                               buffer=directory.SerializeToString())
 
     def _get_subdir(self, tree, subdir):
         head, name = os.path.split(subdir)
@@ -788,39 +808,119 @@ class CASCache(ArtifactCache):
                 out.flush()
                 assert digest.size_bytes == os.fstat(out.fileno()).st_size
 
-    def _fetch_directory(self, remote, tree):
-        objpath = self.objpath(tree)
+    # _ensure_blob():
+    #
+    # Fetch and add blob if it's not already local.
+    #
+    # Args:
+    #     remote (Remote): The remote to use.
+    #     digest (Digest): Digest object for the blob to fetch.
+    #
+    # Returns:
+    #     (str): The path of the object
+    #
+    def _ensure_blob(self, remote, digest):
+        objpath = self.objpath(digest)
         if os.path.exists(objpath):
-            # already in local cache
-            return
+            # already in local repository
+            return objpath
 
-        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
-            self._fetch_blob(remote, tree, out)
+        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
+            self._fetch_blob(remote, digest, f)
 
-            directory = remote_execution_pb2.Directory()
+            added_digest = self.add_object(path=f.name)
+            assert added_digest.hash == digest.hash
 
-            with open(out.name, 'rb') as f:
-                directory.ParseFromString(f.read())
+        return objpath
 
-            for filenode in directory.files:
-                fileobjpath = self.objpath(filenode.digest)
-                if os.path.exists(fileobjpath):
-                    # already in local cache
-                    continue
+    def _batch_download_complete(self, batch):
+        for digest, data in batch.send():
+            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
+                f.write(data)
+                f.flush()
+
+                added_digest = self.add_object(path=f.name)
+                assert added_digest.hash == digest.hash
+
+    # Helper function for _fetch_directory().
+    def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
+        self._batch_download_complete(batch)
 
-            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
-                self._fetch_blob(remote, filenode.digest, f)
+        # All previously scheduled directories are now locally available,
+        # move them to the processing queue.
+        fetch_queue.extend(fetch_next_queue)
+        fetch_next_queue.clear()
+        return _CASBatchRead(remote)
 
-                digest = self.add_object(path=f.name)
-                assert digest.hash == filenode.digest.hash
+    # Helper function for _fetch_directory().
+    def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
+        in_local_cache = os.path.exists(self.objpath(digest))
+
+        if in_local_cache:
+            # Skip download, already in local cache.
+            pass
+        elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
+              not remote.batch_read_supported):
+            # Too large for batch request, download in independent request.
+            self._ensure_blob(remote, digest)
+            in_local_cache = True
+        else:
+            if not batch.add(digest):
+                # Not enough space left in batch request.
+                # Complete pending batch first.
+                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
+                batch.add(digest)
+
+        if recursive:
+            if in_local_cache:
+                # Add directory to processing queue.
+                fetch_queue.append(digest)
+            else:
+                # Directory will be available after completing pending batch.
+                # Add directory to deferred processing queue.
+                fetch_next_queue.append(digest)
+
+        return batch
+
+    # _fetch_directory():
+    #
+    # Fetches remote directory and adds it to content addressable store.
+    #
+    # Fetches files, symbolic links and recursively other directories in
+    # the remote directory and adds them to the content addressable
+    # store.
+    #
+    # Args:
+    #     remote (Remote): The remote to use.
+    #     dir_digest (Digest): Digest object for the directory to fetch.
+    #
+    def _fetch_directory(self, remote, dir_digest):
+        fetch_queue = [dir_digest]
+        fetch_next_queue = []
+        batch = _CASBatchRead(remote)
+
+        while len(fetch_queue) + len(fetch_next_queue) > 0:
+            if len(fetch_queue) == 0:
+                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
+
+            dir_digest = fetch_queue.pop(0)
+
+            objpath = self._ensure_blob(remote, dir_digest)
+
+            directory = remote_execution_pb2.Directory()
+            with open(objpath, 'rb') as f:
+                directory.ParseFromString(f.read())
 
             for dirnode in directory.directories:
-                self._fetch_directory(remote, dirnode.digest)
+                batch = self._fetch_directory_node(remote, dirnode.digest, batch,
+                                                   fetch_queue, fetch_next_queue, recursive=True)
+
+            for filenode in directory.files:
+                batch = self._fetch_directory_node(remote, filenode.digest, batch,
+                                                   fetch_queue, fetch_next_queue)
 
-            # place directory blob only in final location when we've downloaded
-            # all referenced blobs to avoid dangling references in the repository
-            digest = self.add_object(path=out.name)
-            assert digest.hash == tree.hash
+        # Fetch final batch
+        self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
 
 
 # Represents a single remote CAS cache.
@@ -870,11 +970,78 @@ class _CASRemote():
             self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
             self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
             self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
 
+            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
+            try:
+                request = remote_execution_pb2.GetCapabilitiesRequest()
+                response = self.capabilities.GetCapabilities(request)
+                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
+                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
+                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
+            except grpc.RpcError as e:
+                # Simply use the defaults for servers that don't implement GetCapabilities()
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
+            # Check whether the server supports BatchReadBlobs()
+            self.batch_read_supported = False
+            try:
+                request = remote_execution_pb2.BatchReadBlobsRequest()
+                response = self.cas.BatchReadBlobs(request)
+                self.batch_read_supported = True
+            except grpc.RpcError as e:
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
             self._initialized = True
 
 
+# Represents a batch of blobs queued for fetching.
+#
+class _CASBatchRead():
+    def __init__(self, remote):
+        self._remote = remote
+        self._max_total_size_bytes = remote.max_batch_total_size_bytes
+        self._request = remote_execution_pb2.BatchReadBlobsRequest()
+        self._size = 0
+        self._sent = False
+
+    def add(self, digest):
+        assert not self._sent
+
+        new_batch_size = self._size + digest.size_bytes
+        if new_batch_size > self._max_total_size_bytes:
+            # Not enough space left in current batch
+            return False
+
+        request_digest = self._request.digests.add()
+        request_digest.hash = digest.hash
+        request_digest.size_bytes = digest.size_bytes
+        self._size = new_batch_size
+        return True
+
+    def send(self):
+        assert not self._sent
+        self._sent = True
+
+        if len(self._request.digests) == 0:
+            return
+
+        batch_response = self._remote.cas.BatchReadBlobs(self._request)
+
+        for response in batch_response.responses:
+            if response.status.code != grpc.StatusCode.OK.value[0]:
+                raise ArtifactError("Failed to download blob {}: {}".format(
+                    response.digest.hash, response.status.code))
+            if response.digest.size_bytes != len(response.data):
+                raise ArtifactError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
+                    response.digest.hash, response.digest.size_bytes, len(response.data)))
+
+            yield (response.digest, response.data)
+
+
 def _grouper(iterable, n):
     while True:
         try: