| author | Tristan Maat <tristan.maat@codethink.co.uk> | 2019-08-22 17:48:34 +0100 |
| --- | --- | --- |
| committer | Tristan Maat <tristan.maat@codethink.co.uk> | 2019-09-06 15:55:10 +0100 |
| commit | 47a3f93d9795be6af849c112d4180f0ad50ca23b (patch) | |
| tree | 2d65dd2c24d9d6bd6795f0680811cf95ae3803e4 /src | |
| parent | e71621510de7c55aae4855f8bbb64eb2755346a8 (diff) | |
| download | buildstream-47a3f93d9795be6af849c112d4180f0ad50ca23b.tar.gz | |
Allow splitting artifact caches
The artifact cache is now split into storage and index remotes: the
storage remote is expected to be a CASRemote, while the index remote
is a BuildStream-specific remote with the extensions required to
store BuildStream artifact protos.
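
The user-visible knob for this split is the new `type` field accepted by remote specs (see `RemoteSpec.new_from_config_node` in the diff below), with the values `index`, `storage` and `all` (the default, which keeps the old single-remote behaviour). The following is only an illustrative sketch of how a split cache might be configured; the URLs and the exact shape of the `artifacts` node are assumptions, not part of this commit:

```yaml
# Hypothetical example: one logical artifact cache expressed as two
# remotes. The index remote serves only BuildStream artifact protos,
# the storage remote serves only CAS blobs; both URLs are made up.
artifacts:
- url: https://artifacts.example.com:11001
  push: true
  type: index      # BuildStream-specific artifact/index service
- url: https://storage.example.com:11002
  push: true
  type: storage    # plain CAS remote holding the blobs
```

On the server side the same split is mirrored by the new `--index-only` flag added to the CAS server in this commit, which skips registering the ByteStream and ContentAddressableStorage services and leaves only the artifact and source services running.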
Diffstat (limited to 'src')
| mode | file | lines changed |
| --- | --- | --- |
| -rw-r--r-- | src/buildstream/_artifactcache.py | 373 |
| -rw-r--r-- | src/buildstream/_basecache.py | 150 |
| -rw-r--r-- | src/buildstream/_cas/casremote.py | 255 |
| -rw-r--r-- | src/buildstream/_cas/casserver.py | 56 |
| -rw-r--r-- | src/buildstream/_exceptions.py | 10 |
| -rw-r--r-- | src/buildstream/_remote.py | 121 |
| -rw-r--r-- | src/buildstream/_sourcecache.py | 195 |
7 files changed, 714 insertions, 446 deletions
diff --git a/src/buildstream/_artifactcache.py b/src/buildstream/_artifactcache.py index 73047d376..0e2eb1091 100644 --- a/src/buildstream/_artifactcache.py +++ b/src/buildstream/_artifactcache.py @@ -25,48 +25,89 @@ from ._exceptions import ArtifactError, CASError, CASCacheError, CASRemoteError from ._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, \ artifact_pb2, artifact_pb2_grpc -from ._cas import CASRemote +from ._remote import BaseRemote from .storage._casbaseddirectory import CasBasedDirectory from ._artifact import Artifact from . import utils +# ArtifactRemote(): +# +# Facilitates communication with the BuildStream-specific part of +# artifact remotes. +# +class ArtifactRemote(BaseRemote): + # _configure_protocols(): + # + # Configure the protocols used by this remote as part of the + # remote initialization; Note that this should only be used in + # Remote.init(), and is expected to fail when called by itself. + # + def _configure_protocols(self): + # Add artifact stub + capabilities_service = buildstream_pb2_grpc.CapabilitiesStub(self.channel) + # Check whether the server supports newer proto based artifact. + try: + request = buildstream_pb2.GetCapabilitiesRequest() + if self.instance_name: + request.instance_name = self.instance_name + response = capabilities_service.GetCapabilities(request) + except grpc.RpcError as e: + # Check if this remote has the artifact service + if e.code() == grpc.StatusCode.UNIMPLEMENTED: + raise ArtifactError( + "Configured remote does not have the BuildStream " + "capabilities service. Please check remote configuration.") + # Else raise exception with details + raise ArtifactError( + "Remote initialisation failed: {}".format(e.details())) -# ArtifactRemote extends CASRemote to check during initialisation that there is -# an artifact service -class ArtifactRemote(CASRemote): - def __init__(self, *args): - super().__init__(*args) - self.capabilities_service = None + if not response.artifact_capabilities: + raise ArtifactError( + "Configured remote does not support artifact service") - def init(self): - if not self._initialized: - # do default initialisation - super().init() + # get_artifact(): + # + # Get an artifact proto for a given cache key from the remote. + # + # Args: + # cache_key (str): The artifact cache key. NOTE: This "key" + # is actually the ref/name and its name in + # the protocol is inaccurate. You have been warned. + # + # Returns: + # (Artifact): The artifact proto + # + # Raises: + # grpc.RpcError: If someting goes wrong during the request. + # + def get_artifact(self, cache_key): + artifact_request = artifact_pb2.GetArtifactRequest() + artifact_request.cache_key = cache_key - # Add artifact stub - self.capabilities_service = buildstream_pb2_grpc.CapabilitiesStub(self.channel) + artifact_service = artifact_pb2_grpc.ArtifactServiceStub(self.channel) + return artifact_service.GetArtifact(artifact_request) - # Check whether the server supports newer proto based artifact. - try: - request = buildstream_pb2.GetCapabilitiesRequest() - if self.instance_name: - request.instance_name = self.instance_name - response = self.capabilities_service.GetCapabilities(request) - except grpc.RpcError as e: - # Check if this remote has the artifact service - if e.code() == grpc.StatusCode.UNIMPLEMENTED: - raise ArtifactError( - "Configured remote does not have the BuildStream " - "capabilities service. 
Please check remote configuration.") - # Else raise exception with details - raise ArtifactError( - "Remote initialisation failed: {}".format(e.details())) + # update_artifact(): + # + # Update an artifact with the given cache key on the remote with + # the given proto. + # + # Args: + # cache_key (str): The artifact cache key of the artifact to update. + # artifact (ArtifactProto): The artifact proto to send. + # + # Raises: + # grpc.RpcError: If someting goes wrong during the request. + # + def update_artifact(self, cache_key, artifact): + update_request = artifact_pb2.UpdateArtifactRequest() + update_request.cache_key = cache_key + update_request.artifact.CopyFrom(artifact) - if not response.artifact_capabilities: - raise ArtifactError( - "Configured remote does not support artifact service") + artifact_service = artifact_pb2_grpc.ArtifactServiceStub(self.channel) + artifact_service.UpdateArtifact(update_request) # An ArtifactCache manages artifacts. @@ -79,7 +120,7 @@ class ArtifactCache(BaseCache): spec_name = "artifact_cache_specs" spec_error = ArtifactError config_node_name = "artifacts" - remote_class = ArtifactRemote + index_remote_class = ArtifactRemote def __init__(self, context): super().__init__(context) @@ -187,22 +228,35 @@ class ArtifactCache(BaseCache): # def push(self, element, artifact): project = element._get_project() + display_key = element._get_brief_display_key() - push_remotes = [r for r in self._remotes[project] if r.spec.push] + index_remotes = [r for r in self._index_remotes[project] if r.push] + storage_remotes = [r for r in self._storage_remotes[project] if r.push] pushed = False + # First push our files to all storage remotes, so that they + # can perform file checks on their end + for remote in storage_remotes: + remote.init() + element.status("Pushing data from artifact {} -> {}".format(display_key, remote)) - for remote in push_remotes: + if self._push_artifact_blobs(artifact, remote): + element.info("Pushed data from artifact {} -> {}".format(display_key, remote)) + else: + element.info("Remote ({}) already has all data of artifact {} cached".format( + remote, element._get_brief_display_key() + )) + + for remote in index_remotes: remote.init() - display_key = element._get_brief_display_key() - element.status("Pushing artifact {} -> {}".format(display_key, remote.spec.url)) + element.status("Pushing artifact {} -> {}".format(display_key, remote)) - if self._push_artifact(element, artifact, remote): - element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url)) + if self._push_artifact_proto(element, artifact, remote): + element.info("Pushed artifact {} -> {}".format(display_key, remote)) pushed = True else: element.info("Remote ({}) already has artifact {} cached".format( - remote.spec.url, element._get_brief_display_key() + remote, element._get_brief_display_key() )) return pushed @@ -220,26 +274,59 @@ class ArtifactCache(BaseCache): # (bool): True if pull was successful, False if artifact was not available # def pull(self, element, key, *, pull_buildtrees=False): + artifact = None display_key = key[:self.context.log_key_length] project = element._get_project() - for remote in self._remotes[project]: + errors = [] + # Start by pulling our artifact proto, so that we know which + # blobs to pull + for remote in self._index_remotes[project]: remote.init() try: - element.status("Pulling artifact {} <- {}".format(display_key, remote.spec.url)) - - if self._pull_artifact(element, key, remote, pull_buildtrees=pull_buildtrees): - 
element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url)) - # no need to pull from additional remotes - return True + element.status("Pulling artifact {} <- {}".format(display_key, remote)) + artifact = self._pull_artifact_proto(element, key, remote) + if artifact: + element.info("Pulled artifact {} <- {}".format(display_key, remote)) + break else: element.info("Remote ({}) does not have artifact {} cached".format( - remote.spec.url, display_key + remote, display_key )) + except CASError as e: + element.warn("Could not pull from remote {}: {}".format(remote, e)) + errors.append(e) + + if errors and not artifact: + raise ArtifactError("Failed to pull artifact {}".format(display_key), + detail="\n".join(str(e) for e in errors)) + + # If we don't have an artifact, we can't exactly pull our + # artifact + if not artifact: + return False + + errors = [] + # If we do, we can pull it! + for remote in self._storage_remotes[project]: + remote.init() + try: + element.status("Pulling data for artifact {} <- {}".format(display_key, remote)) + if self._pull_artifact_storage(element, artifact, remote, pull_buildtrees=pull_buildtrees): + element.info("Pulled data for artifact {} <- {}".format(display_key, remote)) + return True + + element.info("Remote ({}) does not have artifact {} cached".format( + remote, display_key + )) except CASError as e: - raise ArtifactError("Failed to pull artifact {}: {}".format( - display_key, e)) from e + element.warn("Could not pull from remote {}: {}".format(remote, e)) + errors.append(e) + + if errors: + raise ArtifactError("Failed to pull artifact {}".format(display_key), + detail="\n".join(str(e) for e in errors)) return False @@ -253,7 +340,7 @@ class ArtifactCache(BaseCache): # digest (Digest): The digest of the tree # def pull_tree(self, project, digest): - for remote in self._remotes[project]: + for remote in self._storage_remotes[project]: digest = self.cas.pull_tree(remote, digest) if digest: @@ -276,7 +363,7 @@ class ArtifactCache(BaseCache): def push_message(self, project, message): if self._has_push_remotes: - push_remotes = [r for r in self._remotes[project] if r.spec.push] + push_remotes = [r for r in self._storage_remotes[project] if r.spec.push] else: push_remotes = [] @@ -330,7 +417,7 @@ class ArtifactCache(BaseCache): # missing_blobs (list): The Digests of the blobs to fetch # def fetch_missing_blobs(self, project, missing_blobs): - for remote in self._remotes[project]: + for remote in self._index_remotes[project]: if not missing_blobs: break @@ -357,7 +444,7 @@ class ArtifactCache(BaseCache): if not missing_blobs: return [] - push_remotes = [r for r in self._remotes[project] if r.spec.push] + push_remotes = [r for r in self._storage_remotes[project] if r.spec.push] remote_missing_blobs_list = [] @@ -384,12 +471,12 @@ class ArtifactCache(BaseCache): # def check_remotes_for_element(self, element): # If there are no remotes - if not self._remotes: + if not self._index_remotes: return False project = element._get_project() ref = element.get_artifact_name() - for remote in self._remotes[project]: + for remote in self._index_remotes[project]: remote.init() if self._query_remote(ref, remote): @@ -401,40 +488,59 @@ class ArtifactCache(BaseCache): # Local Private Methods # ################################################ - # _push_artifact() + # _reachable_directories() # - # Pushes relevant directories and then artifact proto to remote. + # Returns: + # (iter): Iterator over directories digests available from artifacts. 
# - # Args: - # element (Element): The element - # artifact (Artifact): The related artifact being pushed - # remote (CASRemote): Remote to push to + def _reachable_directories(self): + for root, _, files in os.walk(self.artifactdir): + for artifact_file in files: + artifact = artifact_pb2.Artifact() + with open(os.path.join(root, artifact_file), 'r+b') as f: + artifact.ParseFromString(f.read()) + + if str(artifact.files): + yield artifact.files + + if str(artifact.buildtree): + yield artifact.buildtree + + # _reachable_digests() # # Returns: - # (bool): whether the push was successful + # (iter): Iterator over single file digests in artifacts # - def _push_artifact(self, element, artifact, remote): + def _reachable_digests(self): + for root, _, files in os.walk(self.artifactdir): + for artifact_file in files: + artifact = artifact_pb2.Artifact() + with open(os.path.join(root, artifact_file), 'r+b') as f: + artifact.ParseFromString(f.read()) - artifact_proto = artifact._get_proto() + if str(artifact.public_data): + yield artifact.public_data - keys = list(utils._deduplicate([artifact_proto.strong_key, artifact_proto.weak_key])) + for log_file in artifact.logs: + yield log_file.digest - # Check whether the artifact is on the server - present = False - for key in keys: - get_artifact = artifact_pb2.GetArtifactRequest() - get_artifact.cache_key = element.get_artifact_name(key) - try: - artifact_service = artifact_pb2_grpc.ArtifactServiceStub(remote.channel) - artifact_service.GetArtifact(get_artifact) - except grpc.RpcError as e: - if e.code() != grpc.StatusCode.NOT_FOUND: - raise ArtifactError("Error checking artifact cache: {}" - .format(e.details())) - else: - present = True - if present: - return False + # _push_artifact_blobs() + # + # Push the blobs that make up an artifact to the remote server. + # + # Args: + # artifact (Artifact): The artifact whose blobs to push. + # remote (CASRemote): The remote to push the blobs to. + # + # Returns: + # (bool) - True if we uploaded anything, False otherwise. + # + # Raises: + # ArtifactError: If we fail to push blobs (*unless* they're + # already there or we run out of space on the server). + # + def _push_artifact_blobs(self, artifact, remote): + artifact_proto = artifact._get_proto() try: self.cas._send_directory(remote, artifact_proto.files) @@ -463,33 +569,68 @@ class ArtifactCache(BaseCache): raise ArtifactError("Failed to push artifact blobs: {}".format(e.details())) return False - # finally need to send the artifact proto + return True + + # _push_artifact_proto() + # + # Pushes the artifact proto to remote. + # + # Args: + # element (Element): The element + # artifact (Artifact): The related artifact being pushed + # remote (ArtifactRemote): Remote to push to + # + # Returns: + # (bool): Whether we pushed the artifact. + # + # Raises: + # ArtifactError: If the push fails for any reason except the + # artifact already existing. 
+ # + def _push_artifact_proto(self, element, artifact, remote): + + artifact_proto = artifact._get_proto() + + keys = list(utils._deduplicate([artifact_proto.strong_key, artifact_proto.weak_key])) + + # Check whether the artifact is on the server for key in keys: - update_artifact = artifact_pb2.UpdateArtifactRequest() - update_artifact.cache_key = element.get_artifact_name(key) - update_artifact.artifact.CopyFrom(artifact_proto) + try: + remote.get_artifact(element.get_artifact_name(key=key)) + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.NOT_FOUND: + raise ArtifactError("Error checking artifact cache: {}" + .format(e.details())) + else: + return False + # If not, we send the artifact proto + for key in keys: try: - artifact_service = artifact_pb2_grpc.ArtifactServiceStub(remote.channel) - artifact_service.UpdateArtifact(update_artifact) + remote.update_artifact(element.get_artifact_name(key=key), artifact_proto) except grpc.RpcError as e: raise ArtifactError("Failed to push artifact: {}".format(e.details())) return True - # _pull_artifact() + # _pull_artifact_storage(): + # + # Pull artifact blobs from the given remote. # # Args: - # element (Element): element to pull - # key (str): specific key of element to pull - # remote (CASRemote): remote to pull from - # pull_buildtree (bool): whether to pull buildtrees or not + # element (Element): element to pull + # key (str): The specific key for the artifact to pull + # remote (CASRemote): remote to pull from + # pull_buildtree (bool): whether to pull buildtrees or not # # Returns: - # (bool): whether the pull was successful + # (bool): True if we pulled any blobs. # - def _pull_artifact(self, element, key, remote, pull_buildtrees=False): - + # Raises: + # ArtifactError: If the pull failed for any reason except the + # blobs not existing on the server. + # + def _pull_artifact_storage(self, element, artifact, remote, pull_buildtrees=False): def __pull_digest(digest): self.cas._fetch_directory(remote, digest) required_blobs = self.cas.required_blobs_for_directory(digest) @@ -497,16 +638,6 @@ class ArtifactCache(BaseCache): if missing_blobs: self.cas.fetch_blobs(remote, missing_blobs) - request = artifact_pb2.GetArtifactRequest() - request.cache_key = element.get_artifact_name(key=key) - try: - artifact_service = artifact_pb2_grpc.ArtifactServiceStub(remote.channel) - artifact = artifact_service.GetArtifact(request) - except grpc.RpcError as e: - if e.code() != grpc.StatusCode.NOT_FOUND: - raise ArtifactError("Failed to pull artifact: {}".format(e.details())) - return False - try: if str(artifact.files): __pull_digest(artifact.files) @@ -527,13 +658,41 @@ class ArtifactCache(BaseCache): raise ArtifactError("Failed to pull artifact: {}".format(e.details())) return False + return True + + # _pull_artifact_proto(): + # + # Pull an artifact proto from a remote server. + # + # Args: + # element (Element): The element whose artifact to pull. + # key (str): The specific key for the artifact to pull. + # remote (ArtifactRemote): The remote to pull from. + # + # Returns: + # (Artifact|None): The artifact proto, or None if the server + # doesn't have it. + # + # Raises: + # ArtifactError: If the pull fails. 
+ # + def _pull_artifact_proto(self, element, key, remote): + artifact_name = element.get_artifact_name(key=key) + + try: + artifact = remote.get_artifact(artifact_name) + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.NOT_FOUND: + raise ArtifactError("Failed to pull artifact: {}".format(e.details())) + return None + # Write the artifact proto to cache - artifact_path = os.path.join(self.artifactdir, request.cache_key) + artifact_path = os.path.join(self.artifactdir, artifact_name) os.makedirs(os.path.dirname(artifact_path), exist_ok=True) with utils.save_file_atomic(artifact_path, mode='wb') as f: f.write(artifact.SerializeToString()) - return True + return artifact # _query_remote() # diff --git a/src/buildstream/_basecache.py b/src/buildstream/_basecache.py index 9197c91b0..df50bfb62 100644 --- a/src/buildstream/_basecache.py +++ b/src/buildstream/_basecache.py @@ -16,21 +16,23 @@ # Authors: # Raoul Hidalgo Charman <raoul.hidalgocharman@codethink.co.uk> # -import multiprocessing import os from fnmatch import fnmatch +from itertools import chain from typing import TYPE_CHECKING from . import utils from . import _yaml from ._cas import CASRemote from ._message import Message, MessageType -from ._exceptions import LoadError -from ._remote import RemoteSpec +from ._exceptions import LoadError, RemoteError +from ._remote import RemoteSpec, RemoteType + if TYPE_CHECKING: from typing import Optional, Type from ._exceptions import BstError + from ._remote import BaseRemote # Base Cache for Caches to derive from @@ -39,18 +41,20 @@ class BaseCache(): # None of these should ever be called in the base class, but this appeases # pylint to some degree - spec_name = None # type: Type[RemoteSpec] - spec_error = None # type: Type[BstError] - config_node_name = None # type: str - remote_class = CASRemote # type: Type[CASRemote] + spec_name = None # type: str + spec_error = None # type: Type[BstError] + config_node_name = None # type: str + index_remote_class = None # type: Type[BaseRemote] + storage_remote_class = CASRemote # type: Type[BaseRemote] def __init__(self, context): self.context = context self.cas = context.get_cascache() self._remotes_setup = False # Check to prevent double-setup of remotes - # Per-project list of _CASRemote instances. - self._remotes = {} + # Per-project list of Remote instances. + self._storage_remotes = {} + self._index_remotes = {} self.global_remote_specs = [] self.project_remote_specs = {} @@ -64,7 +68,7 @@ class BaseCache(): # against fork() with open gRPC channels. # def has_open_grpc_channels(self): - for project_remotes in self._remotes.values(): + for project_remotes in chain(self._index_remotes.values(), self._storage_remotes.values()): for remote in project_remotes: if remote.channel: return True @@ -76,7 +80,7 @@ class BaseCache(): # def release_resources(self): # Close all remotes and their gRPC channels - for project_remotes in self._remotes.values(): + for project_remotes in chain(self._index_remotes.values(), self._storage_remotes.values()): for remote in project_remotes: remote.close() @@ -157,7 +161,6 @@ class BaseCache(): # the user config in some cases (for example `bst artifact push --remote=...`). 
has_remote_caches = False if remote_url: - # pylint: disable=not-callable self._set_remotes([RemoteSpec(remote_url, push=True)]) has_remote_caches = True if use_config: @@ -169,6 +172,15 @@ class BaseCache(): if has_remote_caches: self._initialize_remotes() + # Notify remotes that forking is disabled + def notify_fork_disabled(self): + for project in self._index_remotes: + for remote in self._index_remotes[project]: + remote.notify_fork_disabled() + for project in self._storage_remotes: + for remote in self._storage_remotes[project]: + remote.notify_fork_disabled() + # initialize_remotes(): # # This will contact each remote cache. @@ -177,7 +189,7 @@ class BaseCache(): # on_failure (callable): Called if we fail to contact one of the caches. # def initialize_remotes(self, *, on_failure=None): - remotes = self._create_remote_instances(on_failure=on_failure) + index_remotes, storage_remotes = self._create_remote_instances(on_failure=on_failure) # Assign remote instances to their respective projects for project in self.context.get_projects(): @@ -188,21 +200,20 @@ class BaseCache(): remote_specs.extend(self.project_remote_specs[project]) # De-duplicate the list - remote_specs = utils._deduplicate(remote_specs) + remote_specs = list(utils._deduplicate(remote_specs)) - project_remotes = [] + def get_remotes(remote_list, remote_specs): + for remote_spec in remote_specs: + # If a remote_spec didn't make it into the remotes + # dict, that means we can't access it, and it has been + # disabled for this session. + if remote_spec not in remote_list: + continue - for remote_spec in remote_specs: - # If a remote_spec didn't make it into the remotes - # dict, that means we can't access it, and it has been - # disabled for this session. - if remote_spec not in remotes: - continue + yield remote_list[remote_spec] - remote = remotes[remote_spec] - project_remotes.append(remote) - - self._remotes[project] = project_remotes + self._index_remotes[project] = list(get_remotes(index_remotes, remote_specs)) + self._storage_remotes[project] = list(get_remotes(storage_remotes, remote_specs)) # has_fetch_remotes(): # @@ -222,8 +233,9 @@ class BaseCache(): return True else: # Check whether the specified element's project has fetch remotes - remotes_for_project = self._remotes[plugin._get_project()] - return bool(remotes_for_project) + index_remotes = self._index_remotes[plugin._get_project()] + storage_remotes = self._storage_remotes[plugin._get_project()] + return index_remotes and storage_remotes # has_push_remotes(): # @@ -243,8 +255,10 @@ class BaseCache(): return True else: # Check whether the specified element's project has push remotes - remotes_for_project = self._remotes[plugin._get_project()] - return any(remote.spec.push for remote in remotes_for_project) + index_remotes = self._index_remotes[plugin._get_project()] + storage_remotes = self._storage_remotes[plugin._get_project()] + return (any(remote.spec.push for remote in index_remotes) and + any(remote.spec.push for remote in storage_remotes)) ################################################ # Local Private Methods # @@ -261,8 +275,9 @@ class BaseCache(): # What do do when a remote doesn't respond. # # Returns: - # (Dict[RemoteSpec, self.remote_class]) - - # The created remote instances. + # (Dict[RemoteSpec, self.remote_class], Dict[RemoteSpec, + # self.remote_class]) - + # The created remote instances, index first, storage last. 
# def _create_remote_instances(self, *, on_failure=None): # Create a flat list of all remote specs, global or @@ -278,30 +293,63 @@ class BaseCache(): # Now let's create a dict of this, indexed by their specs, so # that we can later assign them to the right projects. - remotes = {} - q = multiprocessing.Queue() + index_remotes = {} + storage_remotes = {} for remote_spec in remote_specs: - # First, let's check if the remote works - error = self.remote_class.check_remote(remote_spec, self.cas, q) - - # If it doesn't, report the error in some way - if error and on_failure: - on_failure(remote_spec.url, error) - continue - elif error: - raise self.spec_error(error) # pylint: disable=not-callable - - # If it does, we have fetch remotes, and potentially push remotes - self._has_fetch_remotes = True - if remote_spec.push: - self._has_push_remotes = True + try: + index, storage = self._instantiate_remote(remote_spec) + except RemoteError as err: + if on_failure: + on_failure(remote_spec, str(err)) + continue + else: + raise # Finally, we can instantiate the remote. Note that # NamedTuples are hashable, so we can use them as pretty # low-overhead keys. - remotes[remote_spec] = self.remote_class(remote_spec, self.cas) + if index: + index_remotes[remote_spec] = index + if storage: + storage_remotes[remote_spec] = storage + + self._has_fetch_remotes = storage_remotes and index_remotes + self._has_push_remotes = (any(spec.push for spec in storage_remotes) and + any(spec.push for spec in index_remotes)) - return remotes + return index_remotes, storage_remotes + + # _instantiate_remote() + # + # Instantiate a remote given its spec, asserting that it is + # reachable - this may produce two remote instances (a storage and + # an index remote as specified by the class variables). + # + # Args: + # + # remote_spec (RemoteSpec): The spec of the remote to + # instantiate. + # + # Returns: + # + # (Tuple[Remote|None, Remote|None]) - The remotes, index remote + # first, storage remote second. One must always be specified, + # the other may be None. + # + def _instantiate_remote(self, remote_spec): + # Our remotes can be index, storage or both. 
In either case, + # we need to use a different type of Remote for our calls, so + # we create two objects here + index = None + storage = None + if remote_spec.type in [RemoteType.INDEX, RemoteType.ALL]: + index = self.index_remote_class(remote_spec) # pylint: disable=not-callable + index.check() + if remote_spec.type in [RemoteType.STORAGE, RemoteType.ALL]: + storage = self.storage_remote_class(remote_spec, self.cas) + storage.check() + + return (index, storage) # _message() # @@ -334,8 +382,8 @@ class BaseCache(): # reports takes care of messaging # def _initialize_remotes(self): - def remote_failed(url, error): - self._message(MessageType.WARN, "Failed to initialize remote {}: {}".format(url, error)) + def remote_failed(remote, error): + self._message(MessageType.WARN, "Failed to initialize remote {}: {}".format(remote.url, error)) with self.context.messenger.timed_activity("Initializing remote caches", silent_nested=True): self.initialize_remotes(on_failure=remote_failed) diff --git a/src/buildstream/_cas/casremote.py b/src/buildstream/_cas/casremote.py index 35bbb68ec..1efed22e6 100644 --- a/src/buildstream/_cas/casremote.py +++ b/src/buildstream/_cas/casremote.py @@ -1,9 +1,3 @@ -from collections import namedtuple -import os -import multiprocessing -import signal -from urllib.parse import urlparse - import grpc from .._protos.google.rpc import code_pb2 @@ -11,9 +5,8 @@ from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remo from .._protos.build.buildgrid import local_cas_pb2 from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc -from .._exceptions import CASRemoteError, LoadError, LoadErrorReason -from .. import _signals -from .. import utils +from .._remote import BaseRemote +from .._exceptions import CASRemoteError # The default limit for gRPC messages is 4 MiB. # Limit payload to 1 MiB to leave sufficient headroom for metadata. @@ -29,13 +22,12 @@ class BlobNotFound(CASRemoteError): # Represents a single remote CAS cache. 
# -class CASRemote(): - def __init__(self, spec, cascache): - self.spec = spec - self._initialized = False +class CASRemote(BaseRemote): + + def __init__(self, spec, cascache, **kwargs): + super().__init__(spec, **kwargs) + self.cascache = cascache - self.channel = None - self.instance_name = None self.cas = None self.ref_storage = None self.batch_update_supported = None @@ -44,157 +36,102 @@ class CASRemote(): self.max_batch_total_size_bytes = None self.local_cas_instance_name = None - def init(self): - if not self._initialized: - server_cert_bytes = None - client_key_bytes = None - client_cert_bytes = None - - url = urlparse(self.spec.url) - if url.scheme == 'http': - port = url.port or 80 - self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port)) - elif url.scheme == 'https': - port = url.port or 443 - - if self.spec.server_cert: - with open(self.spec.server_cert, 'rb') as f: - server_cert_bytes = f.read() - - if self.spec.client_key: - with open(self.spec.client_key, 'rb') as f: - client_key_bytes = f.read() - - if self.spec.client_cert: - with open(self.spec.client_cert, 'rb') as f: - client_cert_bytes = f.read() - - credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes, - private_key=client_key_bytes, - certificate_chain=client_cert_bytes) - self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials) - else: - raise CASRemoteError("Unsupported URL: {}".format(self.spec.url)) - - self.instance_name = self.spec.instance_name or None - - self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) - self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel) - self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel) - - self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES - try: - request = remote_execution_pb2.GetCapabilitiesRequest() - if self.instance_name: - request.instance_name = self.instance_name - response = self.capabilities.GetCapabilities(request) - server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes - if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes: - self.max_batch_total_size_bytes = server_max_batch_total_size_bytes - except grpc.RpcError as e: - # Simply use the defaults for servers that don't implement GetCapabilities() - if e.code() != grpc.StatusCode.UNIMPLEMENTED: - raise - - # Check whether the server supports BatchReadBlobs() - self.batch_read_supported = False - try: - request = remote_execution_pb2.BatchReadBlobsRequest() - if self.instance_name: - request.instance_name = self.instance_name - response = self.cas.BatchReadBlobs(request) - self.batch_read_supported = True - except grpc.RpcError as e: - if e.code() != grpc.StatusCode.UNIMPLEMENTED: - raise - - # Check whether the server supports BatchUpdateBlobs() - self.batch_update_supported = False - try: - request = remote_execution_pb2.BatchUpdateBlobsRequest() - if self.instance_name: - request.instance_name = self.instance_name - response = self.cas.BatchUpdateBlobs(request) - self.batch_update_supported = True - except grpc.RpcError as e: - if (e.code() != grpc.StatusCode.UNIMPLEMENTED and - e.code() != grpc.StatusCode.PERMISSION_DENIED): - raise - - local_cas = self.cascache._get_local_cas() - request = local_cas_pb2.GetInstanceNameForRemoteRequest() - request.url = self.spec.url - if self.spec.instance_name: - request.instance_name = self.spec.instance_name - if server_cert_bytes: - request.server_cert = server_cert_bytes - 
if client_key_bytes: - request.client_key = client_key_bytes - if client_cert_bytes: - request.client_cert = client_cert_bytes - response = local_cas.GetInstanceNameForRemote(request) - self.local_cas_instance_name = response.instance_name - - self._initialized = True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - return False - - def close(self): - if self.channel: - self.channel.close() - self.channel = None - # check_remote + # _configure_protocols(): # - # Used when checking whether remote_specs work in the buildstream main - # thread, runs this in a seperate process to avoid creation of gRPC threads - # in the main BuildStream process - # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details - @classmethod - def check_remote(cls, remote_spec, cascache, q): - - def __check_remote(): - try: - remote = cls(remote_spec, cascache) - remote.init() - - request = buildstream_pb2.StatusRequest() - response = remote.ref_storage.Status(request) - - if remote_spec.push and not response.allow_updates: - q.put('CAS server does not allow push') - else: - # No error - q.put(None) - - except grpc.RpcError as e: - # str(e) is too verbose for errors reported to the user - q.put(e.details()) + # Configure remote-specific protocols. This method should *never* + # be called outside of init(). + # + def _configure_protocols(self): + self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) + self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel) + self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel) + + # Figure out what batch sizes the server will accept, falling + # back to our _MAX_PAYLOAD_BYTES + self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES + try: + request = remote_execution_pb2.GetCapabilitiesRequest() + if self.instance_name: + request.instance_name = self.instance_name + response = self.capabilities.GetCapabilities(request) + server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes + if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes: + self.max_batch_total_size_bytes = server_max_batch_total_size_bytes + except grpc.RpcError as e: + # Simply use the defaults for servers that don't implement + # GetCapabilities() + if e.code() != grpc.StatusCode.UNIMPLEMENTED: + raise + + # Check whether the server supports BatchReadBlobs() + self.batch_read_supported = self._check_support( + remote_execution_pb2.BatchReadBlobsRequest, + self.cas.BatchReadBlobs + ) + + # Check whether the server supports BatchUpdateBlobs() + self.batch_update_supported = self._check_support( + remote_execution_pb2.BatchUpdateBlobsRequest, + self.cas.BatchUpdateBlobs + ) + + local_cas = self.cascache._get_local_cas() + request = local_cas_pb2.GetInstanceNameForRemoteRequest() + request.url = self.spec.url + if self.spec.instance_name: + request.instance_name = self.spec.instance_name + if self.server_cert: + request.server_cert = self.server_cert + if self.client_key: + request.client_key = self.client_key + if self.client_cert: + request.client_cert = self.client_cert + response = local_cas.GetInstanceNameForRemote(request) + self.local_cas_instance_name = response.instance_name + + # _check(): + # + # Check if this remote provides everything required for the + # particular kind of remote. This is expected to be called as part + # of check(), and must be called in a non-main process. 
+ # + # Returns: + # (str|None): An error message, or None if no error message. + # + def _check(self): + request = buildstream_pb2.StatusRequest() + response = self.ref_storage.Status(request) - except Exception as e: # pylint: disable=broad-except - # Whatever happens, we need to return it to the calling process - # - q.put(str(e)) + if self.spec.push and not response.allow_updates: + return 'CAS server does not allow push' - p = multiprocessing.Process(target=__check_remote) + return None + # _check_support(): + # + # Figure out if a remote server supports a given method based on + # grpc.StatusCode.UNIMPLEMENTED and grpc.StatusCode.PERMISSION_DENIED. + # + # Args: + # request_type (callable): The type of request to check. + # invoker (callable): The remote method that will be invoked. + # + # Returns: + # (bool) - Whether the request is supported. + # + def _check_support(self, request_type, invoker): try: - # Keep SIGINT blocked in the child process - with _signals.blocked([signal.SIGINT], ignore=False): - p.start() + request = request_type() + if self.instance_name: + request.instance_name = self.instance_name + invoker(request) + return True + except grpc.RpcError as e: + if not e.code() in (grpc.StatusCode.UNIMPLEMENTED, grpc.StatusCode.PERMISSION_DENIED): + raise - error = q.get() - p.join() - except KeyboardInterrupt: - utils._kill_process_tree(p.pid) - raise - - return error + return False # push_message(): # diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py index ca7a21955..bb011146e 100644 --- a/src/buildstream/_cas/casserver.py +++ b/src/buildstream/_cas/casserver.py @@ -54,9 +54,10 @@ _MAX_PAYLOAD_BYTES = 1024 * 1024 # Args: # repo (str): Path to CAS repository # enable_push (bool): Whether to allow blob uploads and artifact updates +# index_only (bool): Whether to store CAS blobs or only artifacts # @contextmanager -def create_server(repo, *, enable_push, quota): +def create_server(repo, *, enable_push, quota, index_only): cas = CASCache(os.path.abspath(repo), cache_quota=quota, protect_session_blobs=False) try: @@ -67,11 +68,12 @@ def create_server(repo, *, enable_push, quota): max_workers = (os.cpu_count() or 1) * 5 server = grpc.server(futures.ThreadPoolExecutor(max_workers)) - bytestream_pb2_grpc.add_ByteStreamServicer_to_server( - _ByteStreamServicer(cas, enable_push=enable_push), server) + if not index_only: + bytestream_pb2_grpc.add_ByteStreamServicer_to_server( + _ByteStreamServicer(cas, enable_push=enable_push), server) - remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server( - _ContentAddressableStorageServicer(cas, enable_push=enable_push), server) + remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server( + _ContentAddressableStorageServicer(cas, enable_push=enable_push), server) remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server( _CapabilitiesServicer(), server) @@ -80,7 +82,7 @@ def create_server(repo, *, enable_push, quota): _ReferenceStorageServicer(cas, enable_push=enable_push), server) artifact_pb2_grpc.add_ArtifactServiceServicer_to_server( - _ArtifactServicer(cas, artifactdir), server) + _ArtifactServicer(cas, artifactdir, update_cas=not index_only), server) source_pb2_grpc.add_SourceServiceServicer_to_server( _SourceServicer(sourcedir), server) @@ -110,9 +112,12 @@ def create_server(repo, *, enable_push, quota): @click.option('--quota', type=click.INT, help="Maximum disk usage in bytes", default=10e9) +@click.option('--index-only', type=click.BOOL, + help="Only provide the 
BuildStream artifact and source services (\"index\"), not the CAS (\"storage\")", + default=False) @click.argument('repo') def server_main(repo, port, server_key, server_cert, client_certs, enable_push, - quota): + quota, index_only): # Handle SIGTERM by calling sys.exit(0), which will raise a SystemExit exception, # properly executing cleanup code in `finally` clauses and context managers. # This is required to terminate buildbox-casd on SIGTERM. @@ -120,7 +125,8 @@ def server_main(repo, port, server_key, server_cert, client_certs, enable_push, with create_server(repo, quota=quota, - enable_push=enable_push) as server: + enable_push=enable_push, + index_only=index_only) as server: use_tls = bool(server_key) @@ -434,10 +440,11 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer): class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer): - def __init__(self, cas, artifactdir): + def __init__(self, cas, artifactdir, *, update_cas=True): super().__init__() self.cas = cas self.artifactdir = artifactdir + self.update_cas = update_cas os.makedirs(artifactdir, exist_ok=True) def GetArtifact(self, request, context): @@ -449,6 +456,20 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer): with open(artifact_path, 'rb') as f: artifact.ParseFromString(f.read()) + # Artifact-only servers will not have blobs on their system, + # so we can't reasonably update their mtimes. Instead, we exit + # early, and let the CAS server deal with its blobs. + # + # FIXME: We could try to run FindMissingBlobs on the other + # server. This is tricky to do from here, of course, + # because we don't know who the other server is, but + # the client could be smart about it - but this might + # make things slower. + # + # It needs some more thought... + if not self.update_cas: + return artifact + # Now update mtimes of files present. 
try: @@ -481,16 +502,17 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer): def UpdateArtifact(self, request, context): artifact = request.artifact - # Check that the files specified are in the CAS - self._check_directory("files", artifact.files, context) + if self.update_cas: + # Check that the files specified are in the CAS + self._check_directory("files", artifact.files, context) - # Unset protocol buffers don't evaluated to False but do return empty - # strings, hence str() - if str(artifact.public_data): - self._check_file("public data", artifact.public_data, context) + # Unset protocol buffers don't evaluated to False but do return empty + # strings, hence str() + if str(artifact.public_data): + self._check_file("public data", artifact.public_data, context) - for log_file in artifact.logs: - self._check_file("log digest", log_file.digest, context) + for log_file in artifact.logs: + self._check_file("log digest", log_file.digest, context) # Add the artifact proto to the cas artifact_path = os.path.join(self.artifactdir, request.cache_key) diff --git a/src/buildstream/_exceptions.py b/src/buildstream/_exceptions.py index 648742dbb..947b83149 100644 --- a/src/buildstream/_exceptions.py +++ b/src/buildstream/_exceptions.py @@ -96,6 +96,7 @@ class ErrorDomain(Enum): VIRTUAL_FS = 13 CAS = 14 PROG_NOT_FOUND = 15 + REMOTE = 16 # BstError is an internal base exception class for BuildStream @@ -290,6 +291,15 @@ class ArtifactError(BstError): super().__init__(message, detail=detail, domain=ErrorDomain.ARTIFACT, reason=reason, temporary=True) +# RemoteError +# +# Raised when errors are encountered in Remotes +# +class RemoteError(BstError): + def __init__(self, message, *, detail=None, reason=None): + super().__init__(message, detail=detail, domain=ErrorDomain.REMOTE, reason=reason) + + # CASError # # Raised when errors are encountered in the CAS diff --git a/src/buildstream/_remote.py b/src/buildstream/_remote.py index 9761e8238..75c626c47 100644 --- a/src/buildstream/_remote.py +++ b/src/buildstream/_remote.py @@ -25,15 +25,29 @@ import grpc from . import _signals from . import utils -from ._exceptions import LoadError, LoadErrorReason, ImplError +from ._exceptions import LoadError, LoadErrorReason, ImplError, RemoteError from ._protos.google.bytestream import bytestream_pb2_grpc +from .types import FastEnum + + +# RemoteType(): +# +# Defines the different types of remote. +# +class RemoteType(FastEnum): + INDEX = "index" + STORAGE = "storage" + ALL = "all" + + def __str__(self): + return self.name.lower().replace('_', '-') # RemoteSpec(): # # Defines the basic structure of a remote specification. 
# -class RemoteSpec(namedtuple('RemoteSpec', 'url push server_cert client_key client_cert instance_name')): +class RemoteSpec(namedtuple('RemoteSpec', 'url push server_cert client_key client_cert instance_name type')): # new_from_config_node # @@ -51,7 +65,8 @@ class RemoteSpec(namedtuple('RemoteSpec', 'url push server_cert client_key clien # @classmethod def new_from_config_node(cls, spec_node, basedir=None): - spec_node.validate_keys(['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name']) + spec_node.validate_keys(['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name', 'type']) + url = spec_node.get_str('url') if not url: provenance = spec_node.get_node('url').get_provenance() @@ -79,8 +94,9 @@ class RemoteSpec(namedtuple('RemoteSpec', 'url push server_cert client_key clien raise LoadError("{}: 'client-cert' was specified without 'client-key'".format(provenance), LoadErrorReason.INVALID_DATA) - return cls(url, push, server_cert, client_key, client_cert, instance_name) + type_ = spec_node.get_enum('type', RemoteType, default=RemoteType.ALL) + return cls(url, push, server_cert, client_key, client_cert, instance_name, type_) # FIXME: This can be made much nicer in python 3.7 through the use of @@ -90,14 +106,15 @@ class RemoteSpec(namedtuple('RemoteSpec', 'url push server_cert client_key clien # are considered mandatory. # # Disable type-checking since "Callable[...] has no attributes __defaults__" -RemoteSpec.__new__.__defaults__ = ( +RemoteSpec.__new__.__defaults__ = ( # type: ignore # mandatory # url - The url of the remote # mandatory # push - Whether the remote should be used for pushing None, # server_cert - The server certificate None, # client_key - The (private) client key None, # client_cert - The (public) client certificate - None # instance_name - The (grpc) instance name of the remote -) # type: ignore + None, # instance_name - The (grpc) instance name of the remote + RemoteType.ALL # type - The type of the remote (index, storage, both) +) # BaseRemote(): @@ -120,6 +137,10 @@ class BaseRemote(): self.bytestream = None self.channel = None + self.server_cert = None + self.client_key = None + self.client_cert = None + self.instance_name = spec.instance_name self.push = spec.push self.url = spec.url @@ -133,11 +154,6 @@ class BaseRemote(): if self._initialized: return - # gRPC doesn't support fork without exec, which is used in the - # main process. 
- if self.fork_allowed: - assert not utils._is_main_process() - # Set up the communcation channel url = urlparse(self.spec.url) if url.scheme == 'http': @@ -152,9 +168,12 @@ class BaseRemote(): self.spec.client_cert) except FileNotFoundError as e: raise RemoteError("Could not read certificates: {}".format(e)) from e - credentials = grpc.ssl_channel_credentials(root_certificates=server_cert, - private_key=client_key, - certificate_chain=client_cert) + self.server_cert = server_cert + self.client_key = client_key + self.client_cert = client_cert + credentials = grpc.ssl_channel_credentials(root_certificates=self.server_cert, + private_key=self.client_key, + certificate_chain=self.client_cert) self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials) else: raise RemoteError("Unsupported URL: {}".format(self.spec.url)) @@ -166,6 +185,18 @@ class BaseRemote(): self._initialized = True + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + return False + + def close(self): + if self.channel: + self.channel.close() + self.channel = None + # _configure_protocols(): # # An abstract method to configure remote-specific protocols. This @@ -178,46 +209,66 @@ class BaseRemote(): def _configure_protocols(self): raise ImplError("An implementation of a Remote must configure its protocols.") - # check_remote + # check(): # - # Used when checking whether remote_specs work in the buildstream main - # thread, runs this in a seperate process to avoid creation of gRPC threads - # in the main BuildStream process - # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details - @classmethod - def check_remote(cls, remote_spec, q): + # Check if the remote is functional and has all the required + # capabilities. This should be used somewhat like an assertion, + # expecting a RemoteError. + # + # Note that this method runs the calls on a separate process, so + # that we can use grpc calls even if we are on the main process. + # + # Raises: + # RemoteError: If the grpc call fails. + # + def check(self): + queue = multiprocessing.Queue() + def __check_remote(): try: - remote = cls(remote_spec) - remote.init() - remote.check(q) + self.init() + queue.put(self._check()) except grpc.RpcError as e: # str(e) is too verbose for errors reported to the user - q.put(e.details()) + queue.put(e.details()) except Exception as e: # pylint: disable=broad-except # Whatever happens, we need to return it to the calling process # - q.put(str(e)) + queue.put(str(e)) - p = multiprocessing.Process(target=__check_remote) + process = multiprocessing.Process(target=__check_remote) try: # Keep SIGINT blocked in the child process with _signals.blocked([signal.SIGINT], ignore=False): - p.start() + process.start() - error = q.get() - p.join() + error = queue.get() + process.join() except KeyboardInterrupt: - utils._kill_process_tree(p.pid) + utils._kill_process_tree(process.pid) raise + finally: + # Should not be necessary, but let's avoid keeping them + # alive too long + queue.close() - return error + if error: + raise RemoteError(error) - def check(self, q): - q.put("An implementation of BaseCache should set a _check method") + # _check(): + # + # Check if this remote provides everything required for the + # particular kind of remote. This is expected to be called as part + # of check(), and must be called in a non-main process. + # + # Returns: + # (str|None): An error message, or None if no error message. 
+ # + def _check(self): + return None def __str__(self): return self.url diff --git a/src/buildstream/_sourcecache.py b/src/buildstream/_sourcecache.py index 2a6a6e220..76a2e4f39 100644 --- a/src/buildstream/_sourcecache.py +++ b/src/buildstream/_sourcecache.py @@ -20,7 +20,7 @@ import os import grpc -from ._cas import CASRemote +from ._remote import BaseRemote from .storage._casbaseddirectory import CasBasedDirectory from ._basecache import BaseCache from ._exceptions import CASError, CASRemoteError, SourceCacheError @@ -29,43 +29,75 @@ from ._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, \ source_pb2, source_pb2_grpc +class SourceRemote(BaseRemote): - -class SourceRemote(CASRemote): - def __init__(self, *args): - super().__init__(*args) - self.capabilities_service = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.source_service = None - def init(self): - if not self._initialized: - super().init() - - self.capabilities_service = buildstream_pb2_grpc.CapabilitiesStub(self.channel) + def _configure_protocols(self): + capabilities_service = buildstream_pb2_grpc.CapabilitiesStub(self.channel) + # check that the service supports sources + try: + request = buildstream_pb2.GetCapabilitiesRequest() + if self.instance_name: + request.instance_name = self.instance_name - # check that the service supports sources - try: - request = buildstream_pb2.GetCapabilitiesRequest() - if self.instance_name: - request.instance_name = self.instance_name - - response = self.capabilities_service.GetCapabilities(request) - except grpc.RpcError as e: - # Check if this remote has the artifact service - if e.code() == grpc.StatusCode.UNIMPLEMENTED: - raise SourceCacheError( - "Configured remote does not have the BuildStream " - "capabilities service. Please check remote configuration.") - # Else raise exception with details + response = capabilities_service.GetCapabilities(request) + except grpc.RpcError as e: + # Check if this remote has the artifact service + if e.code() == grpc.StatusCode.UNIMPLEMENTED: raise SourceCacheError( - "Remote initialisation failed: {}".format(e.details())) + "Configured remote does not have the BuildStream " + "capabilities service. Please check remote configuration.") + # Else raise exception with details + raise SourceCacheError( + "Remote initialisation failed: {}".format(e.details())) - if not response.source_capabilities: - raise SourceCacheError( - "Configured remote does not support source service") + if not response.source_capabilities: + raise SourceCacheError( + "Configured remote does not support source service") + + # set up source service + self.source_service = source_pb2_grpc.SourceServiceStub(self.channel) + + # get_source(): + # + # Get a source proto for a given source_ref from the remote. + # + # Args: + # source_ref (str): The source ref of the source to pull. + # + # Returns: + # (Source): The source proto + # + # Raises: + # grpc.RpcError: If something goes wrong during the request. + # + def get_source(self, source_ref): + request = source_pb2.GetSourceRequest() + request.cache_key = source_ref + return self.source_service.GetSource(request) - # set up source service - self.source_service = source_pb2_grpc.SourceServiceStub(self.channel) + # update_source(): + # + # Update the source on the remote. + # + # Args: + # source_ref (str): The source ref of the source to update. + # source (Source): The proto to update with. + # + # Returns: + # (bool): Whether the update was successful. 
+ # + # Raises: + # grpc.RpcError: If something goes wrong during the request. + # + def update_source(self, source_ref, source): + request = source_pb2.UpdateSourceRequest() + request.cache_key = source_ref + request.source.CopyFrom(source) + return self.source_service.UpdateSource(request) # Class that keeps config of remotes and deals with caching of sources. @@ -78,7 +110,7 @@ class SourceCache(BaseCache): spec_name = "source_cache_specs" spec_error = SourceCacheError config_node_name = "source-caches" - remote_class = SourceRemote + index_remote_class = SourceRemote def __init__(self, context): super().__init__(context) @@ -172,39 +204,53 @@ class SourceCache(BaseCache): # (bool): True if pull successful, False if not def pull(self, source): ref = source._get_source_name() - project = source._get_project() - display_key = source._get_brief_display_key() - for remote in self._remotes[project]: + index_remotes = self._index_remotes[project] + storage_remotes = self._storage_remotes[project] + + # First fetch the source proto so we know what to pull + source_proto = None + for remote in index_remotes: try: - source.status("Pulling source {} <- {}".format(display_key, remote.spec.url)) + remote.init() + source.status("Pulling source {} <- {}".format(display_key, remote)) - # fetch source proto - response = self._pull_source(ref, remote) - if response is None: + source_proto = self._pull_source(ref, remote) + if source_proto is None: source.info("Remote source service ({}) does not have source {} cached".format( - remote.spec.url, display_key)) + remote, display_key)) continue + except CASError as e: + raise SourceCacheError("Failed to pull source {}: {}".format( + display_key, e)) from e + + if not source_proto: + return False + + for remote in storage_remotes: + try: + remote.init() + source.status("Pulling data for source {} <- {}".format(display_key, remote)) # Fetch source blobs - self.cas._fetch_directory(remote, response.files) - required_blobs = self.cas.required_blobs_for_directory(response.files) + self.cas._fetch_directory(remote, source_proto.files) + required_blobs = self.cas.required_blobs_for_directory(source_proto.files) missing_blobs = self.cas.local_missing_blobs(required_blobs) missing_blobs = self.cas.fetch_blobs(remote, missing_blobs) if missing_blobs: source.info("Remote cas ({}) does not have source {} cached".format( - remote.spec.url, display_key)) + remote, display_key)) continue - source.info("Pulled source {} <- {}".format(display_key, remote.spec.url)) + source.info("Pulled source {} <- {}".format(display_key, remote)) return True - except CASError as e: raise SourceCacheError("Failed to pull source {}: {}".format( display_key, e)) from e + return False # push() @@ -221,41 +267,48 @@ class SourceCache(BaseCache): ref = source._get_source_name() project = source._get_project() + index_remotes = [] + storage_remotes = [] + # find configured push remotes for this source if self._has_push_remotes: - push_remotes = [r for r in self._remotes[project] if r.spec.push] - else: - push_remotes = [] + index_remotes = [r for r in self._index_remotes[project] if r.push] + storage_remotes = [r for r in self._storage_remotes[project] if r.push] - pushed = False + pushed_storage = False + pushed_index = False display_key = source._get_brief_display_key() - for remote in push_remotes: + for remote in storage_remotes: remote.init() - source.status("Pushing source {} -> {}".format(display_key, remote.spec.url)) + source.status("Pushing data for source {} -> {}".format(display_key, 
remote)) - # check whether cache has files already - if self._pull_source(ref, remote) is not None: - source.info("Remote ({}) already has source {} cached" - .format(remote.spec.url, display_key)) - continue - - # push files to storage source_proto = self._get_source(ref) try: self.cas._send_directory(remote, source_proto.files) + pushed_storage = True except CASRemoteError: - source.info("Failed to push source files {} -> {}".format(display_key, remote.spec.url)) + source.info("Failed to push source files {} -> {}".format(display_key, remote)) + continue + + for remote in index_remotes: + remote.init() + source.status("Pushing source {} -> {}".format(display_key, remote)) + + # check whether cache has files already + if self._pull_source(ref, remote) is not None: + source.info("Remote ({}) already has source {} cached" + .format(remote, display_key)) continue if not self._push_source(ref, remote): - source.info("Failed to push source metadata {} -> {}".format(display_key, remote.spec.url)) + source.info("Failed to push source metadata {} -> {}".format(display_key, remote)) continue - source.info("Pushed source {} -> {}".format(display_key, remote.spec.url)) - pushed = True + source.info("Pushed source {} -> {}".format(display_key, remote)) + pushed_index = True - return pushed + return pushed_index and pushed_storage def _remove_source(self, ref, *, defer_prune=False): return self.cas.remove(ref, basedir=self.sourcerefdir, defer_prune=defer_prune) @@ -304,14 +357,8 @@ class SourceCache(BaseCache): def _pull_source(self, source_ref, remote): try: remote.init() - - request = source_pb2.GetSourceRequest() - request.cache_key = source_ref - - response = remote.source_service.GetSource(request) - + response = remote.get_source(source_ref) self._store_proto(response, source_ref) - return response except grpc.RpcError as e: @@ -322,14 +369,8 @@ class SourceCache(BaseCache): def _push_source(self, source_ref, remote): try: remote.init() - source_proto = self._get_source(source_ref) - - request = source_pb2.UpdateSourceRequest() - request.cache_key = source_ref - request.source.CopyFrom(source_proto) - - return remote.source_service.UpdateSource(request) + return remote.update_source(source_ref, source_proto) except grpc.RpcError as e: if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED: |