author     Finn <finn.ball@codethink.co.uk>  2018-12-06 17:12:48 +0000
committer  Finn <finn.ball@codethink.co.uk>  2018-12-07 14:38:27 +0000
commit     6e3e34f73c68e36a3eaab6b6dbb8d3b7972561f7 (patch)
tree       b9979609ba96b633da603624e8a815072e2ca1d9
parent     617b581e39e6253d7f43e86cc91e527e5a18c861 (diff)
download   buildstream-finn/refactor-remote-stubs.tar.gz

Refactored CAS stubs. (finn/refactor-remote-stubs)

-rw-r--r--  buildstream/_artifactcache/cascache.py  103
1 file changed, 66 insertions(+), 37 deletions(-)
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py
index 96d938987..0f958fe3b 100644
--- a/buildstream/_artifactcache/cascache.py
+++ b/buildstream/_artifactcache/cascache.py
@@ -901,10 +901,7 @@ class CASCache():
     def _fetch_blob(self, remote, digest, stream):
         resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
 
-        request = bytestream_pb2.ReadRequest()
-        request.resource_name = resource_name
-        request.read_offset = 0
-        for response in remote.bytestream.Read(request):
+        for response in remote.read(resource_name=resource_name, read_offset=0):
             stream.write(response.data)
         stream.flush()
 
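For context, a minimal sketch of the call site after this change. It assumes `remote` is an initialized CASRemote (defined further down in this patch) and `digest` is a remote_execution_pb2.Digest for a blob the server holds; the output path is a placeholder:

    resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
    with open('blob.out', 'wb') as out:
        # remote.read() returns the streaming gRPC call, so iterating it
        # yields one ReadResponse per chunk of the blob.
        for response in remote.read(resource_name=resource_name, read_offset=0):
            out.write(response.data)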
@@ -1059,27 +1056,7 @@ class CASCache():
         resource_name = '/'.join(['uploads', str(u_uid), 'blobs',
                                   digest.hash, str(digest.size_bytes)])
 
-        def request_stream(resname, instream):
-            offset = 0
-            finished = False
-            remaining = digest.size_bytes
-            while not finished:
-                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
-                remaining -= chunk_size
-
-                request = bytestream_pb2.WriteRequest()
-                request.write_offset = offset
-                # max. _MAX_PAYLOAD_BYTES chunks
-                request.data = instream.read(chunk_size)
-                request.resource_name = resname
-                request.finish_write = remaining <= 0
-
-                yield request
-
-                offset += chunk_size
-                finished = request.finish_write
-
-        response = remote.bytestream.Write(request_stream(resource_name, stream))
+        response = remote.write(digest, stream, resource_name)
 
         assert response.committed_size == digest.size_bytes
 
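A corresponding sketch for the upload side, again assuming `remote` and `digest` as above; the upload UUID and input path are placeholders:

    import uuid

    resource_name = '/'.join(['uploads', str(uuid.uuid4()), 'blobs',
                              digest.hash, str(digest.size_bytes)])
    with open('blob.in', 'rb') as stream:
        # write() chunks the stream into WriteRequests internally (see the
        # new CASRemote.write() below) and returns the final WriteResponse.
        response = remote.write(digest, stream, resource_name)
    assert response.committed_size == digest.size_bytes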
@@ -1089,14 +1066,15 @@ class CASCache():
         missing_blobs = dict()
         # Limit size of FindMissingBlobs request
         for required_blobs_group in _grouper(required_blobs, 512):
-            request = remote_execution_pb2.FindMissingBlobsRequest()
+            digests = []
 
             for required_digest in required_blobs_group:
-                d = request.blob_digests.add()
+                d = remote_execution_pb2.Digest()
                 d.hash = required_digest.hash
                 d.size_bytes = required_digest.size_bytes
+                digests.append(d)
 
-            response = remote.cas.FindMissingBlobs(request)
+            response = remote.find_missing_blobs(digests)
             for missing_digest in response.missing_blob_digests:
                 d = remote_execution_pb2.Digest()
                 d.hash = missing_digest.hash
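With the wrapper, a caller only assembles plain Digest messages and never touches the request type. A small sketch, assuming `remote` as above and `known_digests` is any iterable of remote_execution_pb2.Digest:

    response = remote.find_missing_blobs(list(known_digests))
    missing = {(d.hash, d.size_bytes) for d in response.missing_blob_digests}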
@@ -1136,12 +1114,12 @@ class CASRemote():
         self.spec = spec
         self._initialized = False
         self.channel = None
-        self.bytestream = None
-        self.cas = None
         self.batch_update_supported = None
         self.batch_read_supported = None
         self.max_batch_total_size_bytes = None
 
+        self._bytestream_stub = None
+        self._cas_stub = None
         self._capabilities_stub = None
         self._ref_storage_stub = None
 
@@ -1179,8 +1159,8 @@ class CASRemote():
else:
raise CASError("Unsupported URL: {}".format(self.spec.url))
- self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
- self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+ self._bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel)
+ self._cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
self._capabilities_stub = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
self._ref_storage_stub = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
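The stubs built here are ordinary gRPC client handles sharing one channel. A standalone sketch of the same pattern, with a hypothetical endpoint (buildstream vendors the generated modules under buildstream._protos):

    import grpc
    from buildstream._protos.google.bytestream import bytestream_pb2_grpc
    from buildstream._protos.build.bazel.remote.execution.v2 import (
        remote_execution_pb2_grpc,
    )

    channel = grpc.insecure_channel('localhost:50051')  # placeholder address
    bytestream = bytestream_pb2_grpc.ByteStreamStub(channel)
    cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(channel)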
@@ -1200,7 +1178,7 @@ class CASRemote():
             self.batch_read_supported = False
             try:
                 request = remote_execution_pb2.BatchReadBlobsRequest()
-                response = self.cas.BatchReadBlobs(request)
+                response = self._cas_stub.BatchReadBlobs(request)
                 self.batch_read_supported = True
             except grpc.RpcError as e:
                 if e.code() != grpc.StatusCode.UNIMPLEMENTED:
@@ -1210,7 +1188,7 @@ class CASRemote():
             self.batch_update_supported = False
             try:
                 request = remote_execution_pb2.BatchUpdateBlobsRequest()
-                response = self.cas.BatchUpdateBlobs(request)
+                response = self._cas_stub.BatchUpdateBlobs(request)
                 self.batch_update_supported = True
             except grpc.RpcError as e:
                 if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
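Both probes rely on the same convention: an empty batch request either succeeds or fails with UNIMPLEMENTED on servers that lack the RPC. A simplified sketch of that pattern in isolation, assuming an already-created `cas_stub` (the update probe above additionally tolerates a further status code for remotes that refuse pushes):

    import grpc
    from buildstream._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

    batch_update_supported = False
    try:
        cas_stub.BatchUpdateBlobs(remote_execution_pb2.BatchUpdateBlobsRequest())
        batch_update_supported = True
    except grpc.RpcError as e:
        if e.code() != grpc.StatusCode.UNIMPLEMENTED:
            raise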
@@ -1219,6 +1197,52 @@ class CASRemote():
 
             self._initialized = True
 
+    def write(self, digest, stream, resource_name):
+
+        def request_stream(resname, instream):
+            offset = 0
+            finished = False
+            remaining = digest.size_bytes
+            while not finished:
+                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
+                remaining -= chunk_size
+
+                request = bytestream_pb2.WriteRequest()
+                request.write_offset = offset
+                # max. _MAX_PAYLOAD_BYTES per chunk
+                request.data = instream.read(chunk_size)
+                request.resource_name = resname
+                request.finish_write = remaining <= 0
+
+                yield request
+
+                offset += chunk_size
+                finished = request.finish_write
+
+        response = self._bytestream_stub.Write(request_stream(resource_name, stream))
+        return response
+
+    def read(self, resource_name, read_offset):
+        request = bytestream_pb2.ReadRequest()
+        request.resource_name = resource_name
+        request.read_offset = read_offset
+        response = self._bytestream_stub.Read(request)
+        return response
+
+    def batch_read_blobs(self, request):
+        response = self._cas_stub.BatchReadBlobs(request)
+        return response
+
+    def find_missing_blobs(self, digests):
+        request = remote_execution_pb2.FindMissingBlobsRequest()
+        request.blob_digests.extend(digests)
+        response = self._cas_stub.FindMissingBlobs(request)
+        return response
+
+    def batch_update_blobs(self, request):
+        response = self._cas_stub.BatchUpdateBlobs(request)
+        return response
+
     def get_capabilities(self):
         request = remote_execution_pb2.GetCapabilitiesRequest()
         return self._capabilities_stub.GetCapabilities(request)
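The request_stream() generator reproduces the helper removed from _send_blob(): it slices the input into chunks of at most _MAX_PAYLOAD_BYTES and sets finish_write on the last one. A self-contained sketch of just that slicing, with a deliberately tiny chunk size (buildstream derives the real _MAX_PAYLOAD_BYTES from the gRPC message size limit):

    _MAX_PAYLOAD_BYTES = 4  # illustration only

    def chunk_offsets(size_bytes):
        # Yield (write_offset, chunk_size, finish_write) as write() would.
        offset, remaining = 0, size_bytes
        while True:
            chunk = min(remaining, _MAX_PAYLOAD_BYTES)
            remaining -= chunk
            yield offset, chunk, remaining <= 0
            if remaining <= 0:
                break
            offset += chunk

    print(list(chunk_offsets(10)))
    # [(0, 4, False), (4, 4, False), (8, 2, True)]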
@@ -1267,7 +1291,7 @@ class _CASBatchRead():
         if not self._request.digests:
             return
 
-        batch_response = self._remote.cas.BatchReadBlobs(self._request)
+        batch_response = self._remote.batch_read_blobs(self._request)
 
         for response in batch_response.responses:
             if response.status.code == code_pb2.NOT_FOUND:
@@ -1293,6 +1317,9 @@ class _CASBatchUpdate():
         self._size = 0
         self._sent = False
 
+        self._hash = None
+        self._size_bytes = None
+
     def add(self, digest, stream):
         assert not self._sent
 
@@ -1305,7 +1332,9 @@ class _CASBatchUpdate():
         blob_request.digest.hash = digest.hash
         blob_request.digest.size_bytes = digest.size_bytes
         blob_request.data = stream.read(digest.size_bytes)
-        self._size = new_batch_size
+        # Keep tracking the accumulated batch size so that the next add()
+        # can tell when the batch would exceed max_batch_total_size_bytes.
+        self._size = new_batch_size
         return True
 
     def send(self):
@@ -1315,7 +1344,7 @@ class _CASBatchUpdate():
         if not self._request.requests:
             return
 
-        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
+        batch_response = self._remote.batch_update_blobs(self._request)
 
         for response in batch_response.responses:
             if response.status.code != code_pb2.OK:
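For reference, the way callers in this file drive a batch object is unchanged by the refactor: add() returns False once a blob would overflow the batch, at which point the batch is sent and a fresh one started. A sketch, assuming `remote` is an initialized CASRemote with batch support and `blobs` yields (Digest, readable-stream) pairs:

    batch = _CASBatchUpdate(remote)
    for digest, stream in blobs:
        if not batch.add(digest, stream):
            batch.send()                       # flush the full batch
            batch = _CASBatchUpdate(remote)    # then start a new one
            batch.add(digest, stream)
    batch.send()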