From cdc5b6f5bcf604841f6a291349231df60f407d18 Mon Sep 17 00:00:00 2001
From: Raoul Hidalgo Charman
Date: Tue, 11 Dec 2018 11:41:44 +0000
Subject: cas: move remote-only functions to CASRemote

List of methods moved:

* Initialization check: made it a class method that runs the check in a
  subprocess, for use when checking remotes from the main BuildStream process.
* fetch_blobs
* send_blobs
* verify_digest_on_remote
* push_message

Part of #802
---
 buildstream/_artifactcache.py         |  18 +----
 buildstream/_cas/cascache.py          | 127 +-----------------------------
 buildstream/_cas/casremote.py         | 141 +++++++++++++++++++++++++++++++++-
 buildstream/sandbox/_sandboxremote.py |   6 +-
 4 files changed, 148 insertions(+), 144 deletions(-)

diff --git a/buildstream/_artifactcache.py b/buildstream/_artifactcache.py
index 1b2b55da2..cdbf2d9db 100644
--- a/buildstream/_artifactcache.py
+++ b/buildstream/_artifactcache.py
@@ -19,14 +19,12 @@

 import multiprocessing
 import os
-import signal
 import string
 from collections.abc import Mapping

 from .types import _KeyStrength
 from ._exceptions import ArtifactError, CASError, LoadError, LoadErrorReason
 from ._message import Message, MessageType
-from . import _signals
 from . import utils
 from . import _yaml

@@ -375,20 +373,8 @@ class ArtifactCache():
         remotes = {}
         q = multiprocessing.Queue()
         for remote_spec in remote_specs:
-            # Use subprocess to avoid creation of gRPC threads in main BuildStream process
-            # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
-            p = multiprocessing.Process(target=self.cas.initialize_remote, args=(remote_spec, q))
-            try:
-                # Keep SIGINT blocked in the child process
-                with _signals.blocked([signal.SIGINT], ignore=False):
-                    p.start()
-
-                error = q.get()
-                p.join()
-            except KeyboardInterrupt:
-                utils._kill_process_tree(p.pid)
-                raise
+            error = CASRemote.check_remote(remote_spec, q)

             if error and on_failure:
                 on_failure(remote_spec.url, error)

@@ -747,7 +733,7 @@ class ArtifactCache():
                                 "servers are configured as push remotes.")

         for remote in push_remotes:
-            message_digest = self.cas.push_message(remote, message)
+            message_digest = remote.push_message(message)

         return message_digest

diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index 5f62e6105..adbd34c9e 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -19,7 +19,6 @@

 import hashlib
 import itertools
-import io
 import os
 import stat
 import tempfile
@@ -28,14 +27,13 @@
 import contextlib

 import grpc

-from .._protos.google.bytestream import bytestream_pb2
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 from .._protos.buildstream.v2 import buildstream_pb2
 from .. import utils
 from .._exceptions import CASCacheError

-from .casremote import CASRemote, BlobNotFound, _CASBatchRead, _CASBatchUpdate, _MAX_PAYLOAD_BYTES
+from .casremote import BlobNotFound, _CASBatchRead, _CASBatchUpdate


 # A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -185,29 +183,6 @@ class CASCache():

         return modified, removed, added

-    def initialize_remote(self, remote_spec, q):
-        try:
-            remote = CASRemote(remote_spec)
-            remote.init()
-
-            request = buildstream_pb2.StatusRequest(instance_name=remote_spec.instance_name)
-            response = remote.ref_storage.Status(request)
-
-            if remote_spec.push and not response.allow_updates:
-                q.put('CAS server does not allow push')
-            else:
-                # No error
-                q.put(None)
-
-        except grpc.RpcError as e:
-            # str(e) is too verbose for errors reported to the user
-            q.put(e.details())
-
-        except Exception as e:  # pylint: disable=broad-except
-            # Whatever happens, we need to return it to the calling process
-            #
-            q.put(str(e))
-
     # pull():
     #
     # Pull a ref from a remote repository.
@@ -355,50 +330,6 @@ class CASCache():

             self._send_directory(remote, directory.ref)

-    # push_message():
-    #
-    # Push the given protobuf message to a remote.
-    #
-    # Args:
-    #     remote (CASRemote): The remote to push to
-    #     message (Message): A protobuf message to push.
-    #
-    # Raises:
-    #     (CASCacheError): if there was an error
-    #
-    def push_message(self, remote, message):
-
-        message_buffer = message.SerializeToString()
-        message_digest = utils._message_digest(message_buffer)
-
-        remote.init()
-
-        with io.BytesIO(message_buffer) as b:
-            self._send_blob(remote, message_digest, b)
-
-        return message_digest
-
-    # verify_digest_on_remote():
-    #
-    # Check whether the object is already on the server in which case
-    # there is no need to upload it.
-    #
-    # Args:
-    #     remote (CASRemote): The remote to check
-    #     digest (Digest): The object digest.
-    #
-    def verify_digest_on_remote(self, remote, digest):
-        remote.init()
-
-        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=remote.spec.instance_name)
-        request.blob_digests.extend([digest])
-
-        response = remote.cas.FindMissingBlobs(request)
-        if digest in response.missing_blob_digests:
-            return False
-
-        return True
-
     # objpath():
     #
     # Return the path of an object based on its digest.
@@ -849,23 +780,6 @@ class CASCache():
         for dirnode in directory.directories:
             yield from self._required_blobs(dirnode.digest)

-    def _fetch_blob(self, remote, digest, stream):
-        resource_name_components = ['blobs', digest.hash, str(digest.size_bytes)]
-
-        if remote.spec.instance_name:
-            resource_name_components.insert(0, remote.spec.instance_name)
-
-        resource_name = '/'.join(resource_name_components)
-
-        request = bytestream_pb2.ReadRequest()
-        request.resource_name = resource_name
-        request.read_offset = 0
-        for response in remote.bytestream.Read(request):
-            stream.write(response.data)
-        stream.flush()
-
-        assert digest.size_bytes == os.fstat(stream.fileno()).st_size
-
     # _ensure_blob():
     #
     # Fetch and add blob if it's not already local.
@@ -884,7 +798,7 @@ class CASCache():
             return objpath

         with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
-            self._fetch_blob(remote, digest, f)
+            remote._fetch_blob(digest, f)

             added_digest = self.add_object(path=f.name, link_directly=True)
             assert added_digest.hash == digest.hash
@@ -991,7 +905,7 @@ class CASCache():
     def _fetch_tree(self, remote, digest):
         # download but do not store the Tree object
         with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
-            self._fetch_blob(remote, digest, out)
+            remote._fetch_blob(digest, out)

             tree = remote_execution_pb2.Tree()

@@ -1011,39 +925,6 @@ class CASCache():

         return dirdigest

-    def _send_blob(self, remote, digest, stream, u_uid=uuid.uuid4()):
-        resource_name_components = ['uploads', str(u_uid), 'blobs',
-                                    digest.hash, str(digest.size_bytes)]
-
-        if remote.spec.instance_name:
-            resource_name_components.insert(0, remote.spec.instance_name)
-
-        resource_name = '/'.join(resource_name_components)
-
-        def request_stream(resname, instream):
-            offset = 0
-            finished = False
-            remaining = digest.size_bytes
-            while not finished:
-                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
-                remaining -= chunk_size
-
-                request = bytestream_pb2.WriteRequest()
-                request.write_offset = offset
-                # max. _MAX_PAYLOAD_BYTES chunks
-                request.data = instream.read(chunk_size)
-                request.resource_name = resname
-                request.finish_write = remaining <= 0
-
-                yield request
-
-                offset += chunk_size
-                finished = request.finish_write
-
-        response = remote.bytestream.Write(request_stream(resource_name, stream))
-
-        assert response.committed_size == digest.size_bytes
-
     def _send_directory(self, remote, digest, u_uid=uuid.uuid4()):
         required_blobs = self._required_blobs(digest)

@@ -1077,7 +958,7 @@ class CASCache():
                 if (digest.size_bytes >= remote.max_batch_total_size_bytes or
                         not remote.batch_update_supported):
                     # Too large for batch request, upload in independent request.
-                    self._send_blob(remote, digest, f, u_uid=u_uid)
+                    remote._send_blob(digest, f, u_uid=u_uid)
                 else:
                     if not batch.add(digest, f):
                         # Not enough space left in batch request.
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
index 59eb7e363..56ba4c5d8 100644
--- a/buildstream/_cas/casremote.py
+++ b/buildstream/_cas/casremote.py
@@ -1,16 +1,22 @@
 from collections import namedtuple
+import io
 import os
+import multiprocessing
+import signal
 from urllib.parse import urlparse
+import uuid

 import grpc

 from .. import _yaml
 from .._protos.google.rpc import code_pb2
-from .._protos.google.bytestream import bytestream_pb2_grpc
+from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.buildstream.v2 import buildstream_pb2_grpc
+from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
 from .._exceptions import CASRemoteError, LoadError, LoadErrorReason
+from .. import _signals
+from .. import utils


 # The default limit for gRPC messages is 4 MiB.
 # Limit payload to 1 MiB to leave sufficient headroom for metadata.
@@ -159,6 +165,137 @@ class CASRemote():

             self._initialized = True

+    # check_remote():
+    #
+    # Used when checking whether remote specs work from the main BuildStream
+    # process; runs the check in a separate process to avoid creating gRPC
+    # threads in the main BuildStream process.
+    # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
+    @classmethod
+    def check_remote(cls, remote_spec, q):
+
+        def __check_remote():
+            try:
+                remote = cls(remote_spec)
+                remote.init()
+
+                request = buildstream_pb2.StatusRequest()
+                response = remote.ref_storage.Status(request)
+
+                if remote_spec.push and not response.allow_updates:
+                    q.put('CAS server does not allow push')
+                else:
+                    # No error
+                    q.put(None)
+
+            except grpc.RpcError as e:
+                # str(e) is too verbose for errors reported to the user
+                q.put(e.details())
+
+            except Exception as e:  # pylint: disable=broad-except
+                # Whatever happens, we need to return it to the calling process
+                #
+                q.put(str(e))
+
+        p = multiprocessing.Process(target=__check_remote)
+
+        try:
+            # Keep SIGINT blocked in the child process
+            with _signals.blocked([signal.SIGINT], ignore=False):
+                p.start()
+
+            error = q.get()
+            p.join()
+        except KeyboardInterrupt:
+            utils._kill_process_tree(p.pid)
+            raise
+
+        return error
+
+    # verify_digest_on_remote():
+    #
+    # Check whether the object is already on the server in which case
+    # there is no need to upload it.
+    #
+    # Args:
+    #     digest (Digest): The object digest.
+    #
+    def verify_digest_on_remote(self, digest):
+        self.init()
+
+        request = remote_execution_pb2.FindMissingBlobsRequest()
+        request.blob_digests.extend([digest])
+
+        response = self.cas.FindMissingBlobs(request)
+        if digest in response.missing_blob_digests:
+            return False
+
+        return True
+
+    # push_message():
+    #
+    # Push the given protobuf message to a remote.
+    #
+    # Args:
+    #     message (Message): A protobuf message to push.
+    #
+    # Raises:
+    #     (CASRemoteError): if there was an error
+    #
+    def push_message(self, message):
+
+        message_buffer = message.SerializeToString()
+        message_digest = utils._message_digest(message_buffer)
+
+        self.init()
+
+        with io.BytesIO(message_buffer) as b:
+            self._send_blob(message_digest, b)
+
+        return message_digest
+
+    ################################################
+    #             Local Private Methods            #
+    ################################################
+    def _fetch_blob(self, digest, stream):
+        resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
+        request = bytestream_pb2.ReadRequest()
+        request.resource_name = resource_name
+        request.read_offset = 0
+        for response in self.bytestream.Read(request):
+            stream.write(response.data)
+        stream.flush()
+
+        assert digest.size_bytes == os.fstat(stream.fileno()).st_size
+
+    def _send_blob(self, digest, stream, u_uid=uuid.uuid4()):
+        resource_name = '/'.join(['uploads', str(u_uid), 'blobs',
+                                  digest.hash, str(digest.size_bytes)])
+
+        def request_stream(resname, instream):
+            offset = 0
+            finished = False
+            remaining = digest.size_bytes
+            while not finished:
+                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
+                remaining -= chunk_size
+
+                request = bytestream_pb2.WriteRequest()
+                request.write_offset = offset
+                # max. _MAX_PAYLOAD_BYTES chunks
+                request.data = instream.read(chunk_size)
+                request.resource_name = resname
+                request.finish_write = remaining <= 0
+
+                yield request
+
+                offset += chunk_size
+                finished = request.finish_write
+
+        response = self.bytestream.Write(request_stream(resource_name, stream))
+
+        assert response.committed_size == digest.size_bytes
+

 # Represents a batch of blobs queued for fetching.
 #
diff --git a/buildstream/sandbox/_sandboxremote.py b/buildstream/sandbox/_sandboxremote.py
index 8b4c87cf5..6a1f6f2f8 100644
--- a/buildstream/sandbox/_sandboxremote.py
+++ b/buildstream/sandbox/_sandboxremote.py
@@ -348,17 +348,17 @@ class SandboxRemote(Sandbox):
         except grpc.RpcError as e:
             raise SandboxError("Failed to push source directory to remote: {}".format(e)) from e

-        if not cascache.verify_digest_on_remote(casremote, upload_vdir.ref):
+        if not casremote.verify_digest_on_remote(upload_vdir.ref):
             raise SandboxError("Failed to verify that source has been pushed to the remote artifact cache.")

         # Push command and action
         try:
-            cascache.push_message(casremote, command_proto)
+            casremote.push_message(command_proto)
         except grpc.RpcError as e:
             raise SandboxError("Failed to push command to remote: {}".format(e))

         try:
-            cascache.push_message(casremote, action)
+            casremote.push_message(action)
         except grpc.RpcError as e:
             raise SandboxError("Failed to push action to remote: {}".format(e))

-- 
cgit v1.2.1
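
For reviewers: a minimal sketch of how a caller might use the relocated CASRemote
API after this change, based only on the interfaces visible in the diff above
(check_remote(), push_message(), verify_digest_on_remote()). The helper name
check_and_push() and the RuntimeError handling are illustrative assumptions and
not BuildStream code; remote_spec is assumed to be the existing CASRemoteSpec
that CASRemote already accepts.

    import multiprocessing

    from buildstream._cas.casremote import CASRemote

    def check_and_push(remote_spec, message_proto):
        # The subprocess-based availability check now lives on CASRemote itself,
        # so callers no longer go through CASCache.initialize_remote().
        q = multiprocessing.Queue()
        error = CASRemote.check_remote(remote_spec, q)
        if error:
            raise RuntimeError("remote {} is unusable: {}".format(remote_spec.url, error))

        remote = CASRemote(remote_spec)

        # push_message() and verify_digest_on_remote() are instance methods now;
        # both call remote.init() themselves before talking to the server.
        digest = remote.push_message(message_proto)
        assert remote.verify_digest_on_remote(digest)
        return digest

check_remote() keeps all gRPC work in the child process, so the main BuildStream
process never creates gRPC threads before forking, and the instance methods need
no separate setup step because they call init() internally.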