From fd46a9f9779781664fa3eef697bc489e5d1cf4cd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=BCrg=20Billeter?=
Date: Thu, 13 Sep 2018 10:07:37 +0200
Subject: _artifactcache/cascache.py: Use BatchReadBlobs

This uses BatchReadBlobs instead of individual blob download to speed
up artifact pulling, if the server supports it.

Fixes #554.
---
 buildstream/_artifactcache/cascache.py | 149 ++++++++++++++++++++++++++++++---
 1 file changed, 137 insertions(+), 12 deletions(-)

diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py
index 7360d548e..e2c0d44b5 100644
--- a/buildstream/_artifactcache/cascache.py
+++ b/buildstream/_artifactcache/cascache.py
@@ -833,6 +833,55 @@ class CASCache(ArtifactCache):
 
         return objpath
 
+    def _batch_download_complete(self, batch):
+        for digest, data in batch.send():
+            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
+                f.write(data)
+                f.flush()
+
+                added_digest = self.add_object(path=f.name)
+                assert added_digest.hash == digest.hash
+
+    # Helper function for _fetch_directory().
+    def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
+        self._batch_download_complete(batch)
+
+        # All previously scheduled directories are now locally available,
+        # move them to the processing queue.
+        fetch_queue.extend(fetch_next_queue)
+        fetch_next_queue.clear()
+        return _CASBatchRead(remote)
+
+    # Helper function for _fetch_directory().
+    def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
+        in_local_cache = os.path.exists(self.objpath(digest))
+
+        if in_local_cache:
+            # Skip download, already in local cache.
+            pass
+        elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
+              not remote.batch_read_supported):
+            # Too large for batch request, download in independent request.
+            self._ensure_blob(remote, digest)
+            in_local_cache = True
+        else:
+            if not batch.add(digest):
+                # Not enough space left in batch request.
+                # Complete pending batch first.
+                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
+                batch.add(digest)
+
+        if recursive:
+            if in_local_cache:
+                # Add directory to processing queue.
+                fetch_queue.append(digest)
+            else:
+                # Directory will be available after completing pending batch.
+                # Add directory to deferred processing queue.
+                fetch_next_queue.append(digest)
+
+        return batch
+
     # _fetch_directory():
     #
     # Fetches remote directory and adds it to content addressable store.
@@ -846,23 +895,32 @@ class CASCache(ArtifactCache):
     #     dir_digest (Digest): Digest object for the directory to fetch.
     #
     def _fetch_directory(self, remote, dir_digest):
-        objpath = self.objpath(dir_digest)
-        if os.path.exists(objpath):
-            # already in local cache
-            return
+        fetch_queue = [dir_digest]
+        fetch_next_queue = []
+        batch = _CASBatchRead(remote)
 
-        objpath = self._ensure_blob(remote, dir_digest)
+        while len(fetch_queue) + len(fetch_next_queue) > 0:
+            if len(fetch_queue) == 0:
+                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
 
-        directory = remote_execution_pb2.Directory()
+            dir_digest = fetch_queue.pop(0)
 
-        with open(objpath, 'rb') as f:
-            directory.ParseFromString(f.read())
+            objpath = self._ensure_blob(remote, dir_digest)
 
-        for filenode in directory.files:
-            self._ensure_blob(remote, filenode.digest)
+            directory = remote_execution_pb2.Directory()
+            with open(objpath, 'rb') as f:
+                directory.ParseFromString(f.read())
 
-        for dirnode in directory.directories:
-            self._fetch_directory(remote, dirnode.digest)
+            for dirnode in directory.directories:
+                batch = self._fetch_directory_node(remote, dirnode.digest, batch,
+                                                   fetch_queue, fetch_next_queue, recursive=True)
+
+            for filenode in directory.files:
+                batch = self._fetch_directory_node(remote, filenode.digest, batch,
+                                                   fetch_queue, fetch_next_queue)
+
+        # Fetch final batch
+        self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
 
 
 # Represents a single remote CAS cache.
@@ -912,11 +970,78 @@ class _CASRemote():
 
             self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
             self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
             self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
 
+            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
+            try:
+                request = remote_execution_pb2.GetCapabilitiesRequest()
+                response = self.capabilities.GetCapabilities(request)
+                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
+                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
+                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
+            except grpc.RpcError as e:
+                # Simply use the defaults for servers that don't implement GetCapabilities()
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
+            # Check whether the server supports BatchReadBlobs()
+            self.batch_read_supported = False
+            try:
+                request = remote_execution_pb2.BatchReadBlobsRequest()
+                response = self.cas.BatchReadBlobs(request)
+                self.batch_read_supported = True
+            except grpc.RpcError as e:
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
             self._initialized = True
 
 
+# Represents a batch of blobs queued for fetching.
+#
+class _CASBatchRead():
+    def __init__(self, remote):
+        self._remote = remote
+        self._max_total_size_bytes = remote.max_batch_total_size_bytes
+        self._request = remote_execution_pb2.BatchReadBlobsRequest()
+        self._size = 0
+        self._sent = False
+
+    def add(self, digest):
+        assert not self._sent
+
+        new_batch_size = self._size + digest.size_bytes
+        if new_batch_size > self._max_total_size_bytes:
+            # Not enough space left in current batch
+            return False
+
+        request_digest = self._request.digests.add()
+        request_digest.hash = digest.hash
+        request_digest.size_bytes = digest.size_bytes
+        self._size = new_batch_size
+        return True
+
+    def send(self):
+        assert not self._sent
+        self._sent = True
+
+        if len(self._request.digests) == 0:
+            return
+
+        batch_response = self._remote.cas.BatchReadBlobs(self._request)
+
+        for response in batch_response.responses:
+            if response.status.code != grpc.StatusCode.OK.value[0]:
+                raise ArtifactError("Failed to download blob {}: {}".format(
+                    response.digest.hash, response.status.code))
+            if response.digest.size_bytes != len(response.data):
+                raise ArtifactError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
+                    response.digest.hash, response.digest.size_bytes, len(response.data)))
+
+            yield (response.digest, response.data)
+
+
 def _grouper(iterable, n):
     while True:
         try:
-- 
cgit v1.2.1
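
For reference, a minimal sketch of the BatchReadBlobs round trip the patch relies on,
kept separate from the patch itself. The fetch_blobs and connect helpers and the
endpoint value are illustrative, not part of the change; the request/response fields
(digests.add(), status.code, data) are the standard remote-execution API surface used
above, and the import path assumes the generated REAPI modules that BuildStream ships
under buildstream._protos.

import grpc

from buildstream._protos.build.bazel.remote.execution.v2 import (
    remote_execution_pb2, remote_execution_pb2_grpc)


def connect(endpoint):
    # Hypothetical endpoint; adjust to the CAS server you are talking to.
    channel = grpc.insecure_channel(endpoint)
    return remote_execution_pb2_grpc.ContentAddressableStorageStub(channel)


def fetch_blobs(cas, digests):
    # Assemble one request from all digests; the caller is expected to keep
    # the total size below the server's max_batch_total_size_bytes.
    request = remote_execution_pb2.BatchReadBlobsRequest()
    for digest in digests:
        request_digest = request.digests.add()
        request_digest.hash = digest.hash
        request_digest.size_bytes = digest.size_bytes

    blobs = {}
    for response in cas.BatchReadBlobs(request).responses:
        # Each blob carries its own google.rpc status; code 0 means OK.
        if response.status.code != 0:
            raise RuntimeError("Failed to download blob {}: {}".format(
                response.digest.hash, response.status.code))
        if response.digest.size_bytes != len(response.data):
            raise RuntimeError("Size mismatch for blob {}".format(
                response.digest.hash))
        blobs[response.digest.hash] = response.data
    return blobs


# Usage with a hypothetical endpoint:
#     cas = connect('cas.example.com:50051')
#     blobs = fetch_blobs(cas, digests)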