diff options
Diffstat (limited to 'src/buildstream/_sourcecache.py')
-rw-r--r-- | src/buildstream/_sourcecache.py | 195 |
1 files changed, 118 insertions, 77 deletions
diff --git a/src/buildstream/_sourcecache.py b/src/buildstream/_sourcecache.py index 2a6a6e220..76a2e4f39 100644 --- a/src/buildstream/_sourcecache.py +++ b/src/buildstream/_sourcecache.py @@ -20,7 +20,7 @@ import os import grpc -from ._cas import CASRemote +from ._remote import BaseRemote from .storage._casbaseddirectory import CasBasedDirectory from ._basecache import BaseCache from ._exceptions import CASError, CASRemoteError, SourceCacheError @@ -29,43 +29,75 @@ from ._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, \ source_pb2, source_pb2_grpc +class SourceRemote(BaseRemote): - -class SourceRemote(CASRemote): - def __init__(self, *args): - super().__init__(*args) - self.capabilities_service = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.source_service = None - def init(self): - if not self._initialized: - super().init() - - self.capabilities_service = buildstream_pb2_grpc.CapabilitiesStub(self.channel) + def _configure_protocols(self): + capabilities_service = buildstream_pb2_grpc.CapabilitiesStub(self.channel) + # check that the service supports sources + try: + request = buildstream_pb2.GetCapabilitiesRequest() + if self.instance_name: + request.instance_name = self.instance_name - # check that the service supports sources - try: - request = buildstream_pb2.GetCapabilitiesRequest() - if self.instance_name: - request.instance_name = self.instance_name - - response = self.capabilities_service.GetCapabilities(request) - except grpc.RpcError as e: - # Check if this remote has the artifact service - if e.code() == grpc.StatusCode.UNIMPLEMENTED: - raise SourceCacheError( - "Configured remote does not have the BuildStream " - "capabilities service. Please check remote configuration.") - # Else raise exception with details + response = capabilities_service.GetCapabilities(request) + except grpc.RpcError as e: + # Check if this remote has the artifact service + if e.code() == grpc.StatusCode.UNIMPLEMENTED: raise SourceCacheError( - "Remote initialisation failed: {}".format(e.details())) + "Configured remote does not have the BuildStream " + "capabilities service. Please check remote configuration.") + # Else raise exception with details + raise SourceCacheError( + "Remote initialisation failed: {}".format(e.details())) - if not response.source_capabilities: - raise SourceCacheError( - "Configured remote does not support source service") + if not response.source_capabilities: + raise SourceCacheError( + "Configured remote does not support source service") + + # set up source service + self.source_service = source_pb2_grpc.SourceServiceStub(self.channel) + + # get_source(): + # + # Get a source proto for a given source_ref from the remote. + # + # Args: + # source_ref (str): The source ref of the source to pull. + # + # Returns: + # (Source): The source proto + # + # Raises: + # grpc.RpcError: If something goes wrong during the request. + # + def get_source(self, source_ref): + request = source_pb2.GetSourceRequest() + request.cache_key = source_ref + return self.source_service.GetSource(request) - # set up source service - self.source_service = source_pb2_grpc.SourceServiceStub(self.channel) + # update_source(): + # + # Update the source on the remote. + # + # Args: + # source_ref (str): The source ref of the source to update. + # source (Source): The proto to update with. + # + # Returns: + # (bool): Whether the update was successful. + # + # Raises: + # grpc.RpcError: If something goes wrong during the request. + # + def update_source(self, source_ref, source): + request = source_pb2.UpdateSourceRequest() + request.cache_key = source_ref + request.source.CopyFrom(source) + return self.source_service.UpdateSource(request) # Class that keeps config of remotes and deals with caching of sources. @@ -78,7 +110,7 @@ class SourceCache(BaseCache): spec_name = "source_cache_specs" spec_error = SourceCacheError config_node_name = "source-caches" - remote_class = SourceRemote + index_remote_class = SourceRemote def __init__(self, context): super().__init__(context) @@ -172,39 +204,53 @@ class SourceCache(BaseCache): # (bool): True if pull successful, False if not def pull(self, source): ref = source._get_source_name() - project = source._get_project() - display_key = source._get_brief_display_key() - for remote in self._remotes[project]: + index_remotes = self._index_remotes[project] + storage_remotes = self._storage_remotes[project] + + # First fetch the source proto so we know what to pull + source_proto = None + for remote in index_remotes: try: - source.status("Pulling source {} <- {}".format(display_key, remote.spec.url)) + remote.init() + source.status("Pulling source {} <- {}".format(display_key, remote)) - # fetch source proto - response = self._pull_source(ref, remote) - if response is None: + source_proto = self._pull_source(ref, remote) + if source_proto is None: source.info("Remote source service ({}) does not have source {} cached".format( - remote.spec.url, display_key)) + remote, display_key)) continue + except CASError as e: + raise SourceCacheError("Failed to pull source {}: {}".format( + display_key, e)) from e + + if not source_proto: + return False + + for remote in storage_remotes: + try: + remote.init() + source.status("Pulling data for source {} <- {}".format(display_key, remote)) # Fetch source blobs - self.cas._fetch_directory(remote, response.files) - required_blobs = self.cas.required_blobs_for_directory(response.files) + self.cas._fetch_directory(remote, source_proto.files) + required_blobs = self.cas.required_blobs_for_directory(source_proto.files) missing_blobs = self.cas.local_missing_blobs(required_blobs) missing_blobs = self.cas.fetch_blobs(remote, missing_blobs) if missing_blobs: source.info("Remote cas ({}) does not have source {} cached".format( - remote.spec.url, display_key)) + remote, display_key)) continue - source.info("Pulled source {} <- {}".format(display_key, remote.spec.url)) + source.info("Pulled source {} <- {}".format(display_key, remote)) return True - except CASError as e: raise SourceCacheError("Failed to pull source {}: {}".format( display_key, e)) from e + return False # push() @@ -221,41 +267,48 @@ class SourceCache(BaseCache): ref = source._get_source_name() project = source._get_project() + index_remotes = [] + storage_remotes = [] + # find configured push remotes for this source if self._has_push_remotes: - push_remotes = [r for r in self._remotes[project] if r.spec.push] - else: - push_remotes = [] + index_remotes = [r for r in self._index_remotes[project] if r.push] + storage_remotes = [r for r in self._storage_remotes[project] if r.push] - pushed = False + pushed_storage = False + pushed_index = False display_key = source._get_brief_display_key() - for remote in push_remotes: + for remote in storage_remotes: remote.init() - source.status("Pushing source {} -> {}".format(display_key, remote.spec.url)) + source.status("Pushing data for source {} -> {}".format(display_key, remote)) - # check whether cache has files already - if self._pull_source(ref, remote) is not None: - source.info("Remote ({}) already has source {} cached" - .format(remote.spec.url, display_key)) - continue - - # push files to storage source_proto = self._get_source(ref) try: self.cas._send_directory(remote, source_proto.files) + pushed_storage = True except CASRemoteError: - source.info("Failed to push source files {} -> {}".format(display_key, remote.spec.url)) + source.info("Failed to push source files {} -> {}".format(display_key, remote)) + continue + + for remote in index_remotes: + remote.init() + source.status("Pushing source {} -> {}".format(display_key, remote)) + + # check whether cache has files already + if self._pull_source(ref, remote) is not None: + source.info("Remote ({}) already has source {} cached" + .format(remote, display_key)) continue if not self._push_source(ref, remote): - source.info("Failed to push source metadata {} -> {}".format(display_key, remote.spec.url)) + source.info("Failed to push source metadata {} -> {}".format(display_key, remote)) continue - source.info("Pushed source {} -> {}".format(display_key, remote.spec.url)) - pushed = True + source.info("Pushed source {} -> {}".format(display_key, remote)) + pushed_index = True - return pushed + return pushed_index and pushed_storage def _remove_source(self, ref, *, defer_prune=False): return self.cas.remove(ref, basedir=self.sourcerefdir, defer_prune=defer_prune) @@ -304,14 +357,8 @@ class SourceCache(BaseCache): def _pull_source(self, source_ref, remote): try: remote.init() - - request = source_pb2.GetSourceRequest() - request.cache_key = source_ref - - response = remote.source_service.GetSource(request) - + response = remote.get_source(source_ref) self._store_proto(response, source_ref) - return response except grpc.RpcError as e: @@ -322,14 +369,8 @@ class SourceCache(BaseCache): def _push_source(self, source_ref, remote): try: remote.init() - source_proto = self._get_source(source_ref) - - request = source_pb2.UpdateSourceRequest() - request.cache_key = source_ref - request.source.CopyFrom(source_proto) - - return remote.source_service.UpdateSource(request) + return remote.update_source(source_ref, source_proto) except grpc.RpcError as e: if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED: |