author    Jürg Billeter <j@bitron.ch>  2018-08-14 17:24:46 +0200
committer Jürg Billeter <j@bitron.ch>  2018-08-15 15:02:32 +0200
commit    6a9d737e56077ba735b83fc94040e5707ce10d84 (patch)
tree      a8958e438a2d5c4ef802fcc0f0fd55ff20df6684
parent    1202ef8af64d328bd5db2ebf19d9246b9bfec195 (diff)
download  buildstream-6a9d737e56077ba735b83fc94040e5707ce10d84.tar.gz
_artifactcache/casserver.py: Improve ByteStream error handling
Replace assertions with gRPC error responses.
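
The idea throughout: a malformed request from a client is an expected
runtime condition, not a programmer error, so the server should report
it through the gRPC status mechanism rather than dying on an
AssertionError. A minimal sketch of the pattern as it might appear in
any Python gRPC servicer (the servicer and field names here are
illustrative, not taken from casserver.py):

    import grpc

    class ExampleServicer:
        def Read(self, request, context):
            # Validate client input up front; on failure, attach a
            # status code to the call context and end the response
            # instead of asserting.
            if request.read_offset < 0:
                context.set_code(grpc.StatusCode.OUT_OF_RANGE)
                context.set_details('read_offset must be non-negative')
                return
            # ... normal handling ...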
-rw-r--r--  buildstream/_artifactcache/casserver.py  71
1 file changed, 52 insertions(+), 19 deletions(-)
diff --git a/buildstream/_artifactcache/casserver.py b/buildstream/_artifactcache/casserver.py
index 3cb8944ad..0af65729b 100644
--- a/buildstream/_artifactcache/casserver.py
+++ b/buildstream/_artifactcache/casserver.py
@@ -132,11 +132,20 @@ class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
def Read(self, request, context):
resource_name = request.resource_name
client_digest = _digest_from_download_resource_name(resource_name)
- assert request.read_offset <= client_digest.size_bytes
+ if client_digest is None:
+ context.set_code(grpc.StatusCode.NOT_FOUND)
+ return
+
+ if request.read_offset > client_digest.size_bytes:
+ context.set_code(grpc.StatusCode.OUT_OF_RANGE)
+ return
try:
with open(self.cas.objpath(client_digest), 'rb') as f:
- assert os.fstat(f.fileno()).st_size == client_digest.size_bytes
+ if os.fstat(f.fileno()).st_size != client_digest.size_bytes:
+ context.set_code(grpc.StatusCode.NOT_FOUND)
+ return
+
if request.read_offset > 0:
f.seek(request.read_offset)
@@ -164,12 +173,18 @@ class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
resource_name = None
with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
for request in request_iterator:
- assert not finished
- assert request.write_offset == offset
+ if finished or request.write_offset != offset:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ return response
+
if resource_name is None:
# First request
resource_name = request.resource_name
client_digest = _digest_from_upload_resource_name(resource_name)
+ if client_digest is None:
+ context.set_code(grpc.StatusCode.NOT_FOUND)
+ return response
+
try:
_clean_up_cache(self.cas, client_digest.size_bytes)
except ArtifactTooLargeException as e:
@@ -178,14 +193,20 @@ class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
return response
elif request.resource_name:
# If it is set on subsequent calls, it **must** match the value of the first request.
- assert request.resource_name == resource_name
+ if request.resource_name != resource_name:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ return response
out.write(request.data)
offset += len(request.data)
if request.finish_write:
- assert client_digest.size_bytes == offset
+ if client_digest.size_bytes != offset:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ return response
out.flush()
digest = self.cas.add_object(path=out.name)
- assert digest.hash == client_digest.hash
+ if digest.hash != client_digest.hash:
+ context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+ return response
finished = True
assert finished
@@ -255,11 +276,16 @@ def _digest_from_download_resource_name(resource_name):
if len(parts) == 2:
parts.insert(0, 'blobs')
- assert len(parts) == 3 and parts[0] == 'blobs'
- digest = remote_execution_pb2.Digest()
- digest.hash = parts[1]
- digest.size_bytes = int(parts[2])
- return digest
+ if len(parts) != 3 or parts[0] != 'blobs':
+ return None
+
+ try:
+ digest = remote_execution_pb2.Digest()
+ digest.hash = parts[1]
+ digest.size_bytes = int(parts[2])
+ return digest
+ except ValueError:
+ return None
def _digest_from_upload_resource_name(resource_name):
@@ -271,13 +297,20 @@ def _digest_from_upload_resource_name(resource_name):
parts.insert(1, str(uuid.uuid4()))
parts.insert(2, 'blobs')
- assert len(parts) >= 5 and parts[0] == 'uploads' and parts[2] == 'blobs'
- uuid_ = uuid.UUID(hex=parts[1])
- assert uuid_.version == 4
- digest = remote_execution_pb2.Digest()
- digest.hash = parts[3]
- digest.size_bytes = int(parts[4])
- return digest
+ if len(parts) < 5 or parts[0] != 'uploads' or parts[2] != 'blobs':
+ return None
+
+ try:
+ uuid_ = uuid.UUID(hex=parts[1])
+ if uuid_.version != 4:
+ return None
+
+ digest = remote_execution_pb2.Digest()
+ digest.hash = parts[3]
+ digest.size_bytes = int(parts[4])
+ return digest
+ except ValueError:
+ return None
def _has_object(cas, digest):
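
With the parsing helpers returning None on malformed input, callers
can map a bad ByteStream resource name to a clean gRPC status instead
of an AssertionError. A rough illustration of the two name formats the
helpers accept after this change (the hash, size, and uuid values
below are made up):

    # Download names: "blobs/{hash}/{size}"
    # (the "blobs/" prefix may be omitted)
    digest = _digest_from_download_resource_name('blobs/deadbeef/1024')

    # Upload names: "uploads/{uuid4}/blobs/{hash}/{size}"
    digest = _digest_from_upload_resource_name(
        'uploads/8c4e3b9e-3f9b-4b2e-9c1d-123456789abc/blobs/deadbeef/1024')

    # Anything malformed now yields None instead of raising
    assert _digest_from_download_resource_name('no/such/format/here') is None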