diff options
author | Finn <finn.ball@codethink.co.uk> | 2018-12-06 17:12:48 +0000 |
---|---|---|
committer | Finn <finn.ball@codethink.co.uk> | 2018-12-07 14:38:27 +0000 |
commit | 6e3e34f73c68e36a3eaab6b6dbb8d3b7972561f7 (patch) | |
tree | b9979609ba96b633da603624e8a815072e2ca1d9 | |
parent | 617b581e39e6253d7f43e86cc91e527e5a18c861 (diff) | |
download | buildstream-finn/refactor-remote-stubs.tar.gz |
Refactored cas stubs.finn/refactor-remote-stubs
-rw-r--r-- | buildstream/_artifactcache/cascache.py | 103 |
1 files changed, 66 insertions, 37 deletions
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py index 96d938987..0f958fe3b 100644 --- a/buildstream/_artifactcache/cascache.py +++ b/buildstream/_artifactcache/cascache.py @@ -901,10 +901,8 @@ class CASCache(): def _fetch_blob(self, remote, digest, stream): resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)]) - request = bytestream_pb2.ReadRequest() - request.resource_name = resource_name - request.read_offset = 0 - for response in remote.bytestream.Read(request): + for response in remote.read(resource_name=resource_name, read_offset=0): + print(response) stream.write(response.data) stream.flush() @@ -1059,27 +1057,7 @@ class CASCache(): resource_name = '/'.join(['uploads', str(u_uid), 'blobs', digest.hash, str(digest.size_bytes)]) - def request_stream(resname, instream): - offset = 0 - finished = False - remaining = digest.size_bytes - while not finished: - chunk_size = min(remaining, _MAX_PAYLOAD_BYTES) - remaining -= chunk_size - - request = bytestream_pb2.WriteRequest() - request.write_offset = offset - # max. _MAX_PAYLOAD_BYTES chunks - request.data = instream.read(chunk_size) - request.resource_name = resname - request.finish_write = remaining <= 0 - - yield request - - offset += chunk_size - finished = request.finish_write - - response = remote.bytestream.Write(request_stream(resource_name, stream)) + response = remote.write(digest, stream, resource_name) assert response.committed_size == digest.size_bytes @@ -1089,14 +1067,16 @@ class CASCache(): missing_blobs = dict() # Limit size of FindMissingBlobs request for required_blobs_group in _grouper(required_blobs, 512): - request = remote_execution_pb2.FindMissingBlobsRequest() + digests = [] for required_digest in required_blobs_group: - d = request.blob_digests.add() + # d = request.blob_digests.add() + d = remote_execution_pb2.Digest() d.hash = required_digest.hash d.size_bytes = required_digest.size_bytes + digests.append(d) + response = remote.find_missing_blobs(digests) - response = remote.cas.FindMissingBlobs(request) for missing_digest in response.missing_blob_digests: d = remote_execution_pb2.Digest() d.hash = missing_digest.hash @@ -1136,12 +1116,12 @@ class CASRemote(): self.spec = spec self._initialized = False self.channel = None - self.bytestream = None - self.cas = None self.batch_update_supported = None self.batch_read_supported = None self.max_batch_total_size_bytes = None + self._bytestream_stub = None + self._cas_stub = None self._capabilities_stub = None self._ref_storage_stub = None @@ -1179,8 +1159,8 @@ class CASRemote(): else: raise CASError("Unsupported URL: {}".format(self.spec.url)) - self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel) - self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) + self._bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel) + self._cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) self._capabilities_stub = remote_execution_pb2_grpc.CapabilitiesStub(self.channel) self._ref_storage_stub = buildstream_pb2_grpc.ReferenceStorageStub(self.channel) @@ -1200,7 +1180,7 @@ class CASRemote(): self.batch_read_supported = False try: request = remote_execution_pb2.BatchReadBlobsRequest() - response = self.cas.BatchReadBlobs(request) + response = self._cas_stub.BatchReadBlobs(request) self.batch_read_supported = True except grpc.RpcError as e: if e.code() != grpc.StatusCode.UNIMPLEMENTED: @@ -1210,7 +1190,7 @@ class CASRemote(): self.batch_update_supported = False try: request = remote_execution_pb2.BatchUpdateBlobsRequest() - response = self.cas.BatchUpdateBlobs(request) + response = self._cas_stub.BatchUpdateBlobs(request) self.batch_update_supported = True except grpc.RpcError as e: if (e.code() != grpc.StatusCode.UNIMPLEMENTED and @@ -1219,6 +1199,52 @@ class CASRemote(): self._initialized = True + def write(self, digest, stream, resource_name): + + def __request_stream(resname, instream): + offset = 0 + finished = False + remaining = digest.size_bytes + while not finished: + chunk_size = min(remaining, _MAX_PAYLOAD_BYTES) + remaining -= chunk_size + + request = bytestream_pb2.WriteRequest() + request.write_offset = offset + # max. _MAX_PAYLOAD_BYTES chunks + request.data = instream.read(chunk_size) + request.resource_name = resname + request.finish_write = remaining <= 0 + + yield request + + offset += chunk_size + finished = request.finish_write + + response = self._bytestream_stub.Write(__request_stream(resource_name, stream)) + return response + + def read(self, resource_name, read_offset): + request = bytestream_pb2.ReadRequest() + request.resource_name = resource_name + request.read_offset = read_offset + response = self._bytestream_stub.Read(request) + return response + + def batch_read_blobs(self, request): + response = self._cas_stub.BatchReadBlobs(request) + return response + + def find_missing_blobs(self, digests): + request = remote_execution_pb2.FindMissingBlobsRequest() + request.blob_digests.extend(digests) + response = self._cas_stub.FindMissingBlobs(request) + return response + + def batch_update_blobs(self, request): + response = self._cas_stub.BatchUpdateBlobs(request) + return response + def get_capabilities(self): request = remote_execution_pb2.GetCapabilitiesRequest() return self._capabilities_stub.GetCapabilities(request) @@ -1267,7 +1293,7 @@ class _CASBatchRead(): if not self._request.digests: return - batch_response = self._remote.cas.BatchReadBlobs(self._request) + batch_response = self._remote.batch_read_blobs(self._request) for response in batch_response.responses: if response.status.code == code_pb2.NOT_FOUND: @@ -1293,6 +1319,9 @@ class _CASBatchUpdate(): self._size = 0 self._sent = False + self._hash = None + self._size_bytes = None + def add(self, digest, stream): assert not self._sent @@ -1305,7 +1334,7 @@ class _CASBatchUpdate(): blob_request.digest.hash = digest.hash blob_request.digest.size_bytes = digest.size_bytes blob_request.data = stream.read(digest.size_bytes) - self._size = new_batch_size + return True def send(self): @@ -1315,7 +1344,7 @@ class _CASBatchUpdate(): if not self._request.requests: return - batch_response = self._remote.cas.BatchUpdateBlobs(self._request) + batch_response = self._remote.batch_update_blobs(self._request) for response in batch_response.responses: if response.status.code != code_pb2.OK: |