# # Copyright (C) 2018 Codethink Limited # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see . # # Authors: # Jürg Billeter import hashlib import itertools import io import multiprocessing import os import signal import stat import tempfile import uuid import errno from urllib.parse import urlparse import grpc from .. import _yaml from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc from .._message import MessageType, Message from .. import _signals, utils from .._exceptions import ArtifactError from . import ArtifactCache # The default limit for gRPC messages is 4 MiB. # Limit payload to 1 MiB to leave sufficient headroom for metadata. _MAX_PAYLOAD_BYTES = 1024 * 1024 # A CASCache manages artifacts in a CAS repository as specified in the # Remote Execution API. # # Args: # context (Context): The BuildStream context # # Pushing is explicitly disabled by the platform in some cases, # like when we are falling back to functioning without using # user namespaces. 
#
class CASCache(ArtifactCache):

    def __init__(self, context):
        super().__init__(context)

        # The local repository lives in the 'cas' subdirectory of the
        # artifact directory; make sure the ref and object stores exist.
        self.casdir = os.path.join(context.artifactdir, 'cas')
        os.makedirs(os.path.join(self.casdir, 'refs', 'heads'), exist_ok=True)
        os.makedirs(os.path.join(self.casdir, 'objects'), exist_ok=True)

        self._calculate_cache_quota()

        # Per-project list of _CASRemote instances.
        self._remotes = {}

        self._has_fetch_remotes = False
        self._has_push_remotes = False

    ################################################
    #     Implementation of abstract methods       #
    ################################################

    # preflight():
    #
    # Check that the ref and object directories of the local
    # repository exist.
    #
    # Raises:
    #    (ArtifactError): If either directory is missing
    #
    def preflight(self):
        if (not os.path.isdir(os.path.join(self.casdir, 'refs', 'heads')) or
                not os.path.isdir(os.path.join(self.casdir, 'objects'))):
            raise ArtifactError("CAS repository check failed for '{}'"
                                .format(self.casdir))

    # contains():
    #
    # Check whether a ref exists locally for the given element and key.
    #
    # Returns:
    #    (bool): True if the ref file exists
    #
    def contains(self, element, key):
        refpath = self._refpath(self.get_artifact_fullname(element, key))

        # This assumes that the repository doesn't have any dangling pointers
        return os.path.exists(refpath)

    # extract():
    #
    # Check out the artifact for the given element and key below the
    # extract directory, unless that has been done before.
    #
    # Returns:
    #    (str): The path of the extracted artifact
    #
    # Raises:
    #    (ArtifactError): If moving the checkout into place fails for a
    #                     reason other than the destination existing
    #
    def extract(self, element, key):
        ref = self.get_artifact_fullname(element, key)

        # Resolving the ref also refreshes its mtime
        tree = self.resolve_ref(ref, update_mtime=True)

        dest = os.path.join(self.extractdir, element._get_project().name, element.normal_name, tree.hash)
        if os.path.isdir(dest):
            # artifact has already been extracted
            return dest

        with tempfile.TemporaryDirectory(prefix='tmp', dir=self.extractdir) as tmpdir:
            checkoutdir = os.path.join(tmpdir, ref)
            self._checkout(checkoutdir, tree)

            os.makedirs(os.path.dirname(dest), exist_ok=True)
            try:
                os.rename(checkoutdir, dest)
            except OSError as e:
                # With rename it's possible to get either ENOTEMPTY or EEXIST
                # in the case that the destination path is a not empty directory.
                #
                # If rename fails with these errors, another process beat
                # us to it so just ignore.
                if e.errno not in [errno.ENOTEMPTY, errno.EEXIST]:
                    raise ArtifactError("Failed to extract artifact for ref '{}': {}"
                                        .format(ref, e)) from e

        return dest

    # commit():
    #
    # Add the directory `content` to the local repository and point the
    # ref of every key in `keys` at the resulting tree.
    #
    def commit(self, element, content, keys):
        refs = [self.get_artifact_fullname(element, key) for key in keys]

        tree = self._commit_directory(content)

        for ref in refs:
            self.set_ref(ref, tree)

    # diff():
    #
    # Compare the file lists of two artifacts of the same element,
    # optionally restricted to a subdirectory.
    #
    # Returns:
    #    (tuple): Three lists of paths: (modified, removed, added)
    #
    def diff(self, element, key_a, key_b, *, subdir=None):
        ref_a = self.get_artifact_fullname(element, key_a)
        ref_b = self.get_artifact_fullname(element, key_b)

        tree_a = self.resolve_ref(ref_a)
        tree_b = self.resolve_ref(ref_b)

        if subdir:
            tree_a = self._get_subdir(tree_a, subdir)
            tree_b = self._get_subdir(tree_b, subdir)

        added = []
        removed = []
        modified = []

        self._diff_trees(tree_a, tree_b, added=added, removed=removed, modified=modified)

        return modified, removed, added

    # initialize_remotes():
    #
    # Probe every configured remote in a subprocess and record the
    # reachable ones per project.
    #
    # Args:
    #    on_failure (callable): Optional callback invoked with
    #                           (url, error) for a failing remote; when
    #                           not given, a failure raises ArtifactError
    #
    def initialize_remotes(self, *, on_failure=None):
        # Work on a copy: the `+=` below extends the list in place and
        # must not leak project specs into the global spec list.
        remote_specs = list(self.global_remote_specs)

        for project in self.project_remote_specs:
            remote_specs += self.project_remote_specs[project]

        remote_specs = list(utils._deduplicate(remote_specs))

        remotes = {}
        q = multiprocessing.Queue()
        for remote_spec in remote_specs:
            # Use subprocess to avoid creation of gRPC threads in main BuildStream process
            # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
            p = multiprocessing.Process(target=self._initialize_remote, args=(remote_spec, q))

            try:
                # Keep SIGINT blocked in the child process
                with _signals.blocked([signal.SIGINT], ignore=False):
                    p.start()

                error = q.get()
                p.join()
            except KeyboardInterrupt:
                utils._kill_process_tree(p.pid)
                raise

            if error and on_failure:
                on_failure(remote_spec.url, error)
            elif error:
                raise ArtifactError(error)
            else:
                self._has_fetch_remotes = True
                if remote_spec.push:
                    self._has_push_remotes = True

                remotes[remote_spec.url] = _CASRemote(remote_spec)

        for project in self.context.get_projects():
            remote_specs = self.global_remote_specs
            if project in self.project_remote_specs:
                remote_specs = list(utils._deduplicate(remote_specs + self.project_remote_specs[project]))

            project_remotes = []

            for remote_spec in remote_specs:
                # Errors are already handled in the loop above,
                # skip unreachable remotes here.
                if remote_spec.url not in remotes:
                    continue

                remote = remotes[remote_spec.url]
                project_remotes.append(remote)

            self._remotes[project] = project_remotes

    # has_fetch_remotes():
    #
    # Check whether any remote is available, either at all or for the
    # project of the given element.
    #
    def has_fetch_remotes(self, *, element=None):
        if not self._has_fetch_remotes:
            # No project has fetch remotes
            return False
        elif element is None:
            # At least one (sub)project has fetch remotes
            return True
        else:
            # Check whether the specified element's project has fetch remotes
            remotes_for_project = self._remotes[element._get_project()]
            return bool(remotes_for_project)

    # has_push_remotes():
    #
    # Check whether any push remote is available, either at all or for
    # the project of the given element.
    #
    def has_push_remotes(self, *, element=None):
        if not self._has_push_remotes:
            # No project has push remotes
            return False
        elif element is None:
            # At least one (sub)project has push remotes
            return True
        else:
            # Check whether the specified element's project has push remotes
            remotes_for_project = self._remotes[element._get_project()]
            return any(remote.spec.push for remote in remotes_for_project)

    # pull():
    #
    # Try the remotes of the element's project in order and fetch the
    # artifact from the first one that has it.
    #
    # Returns:
    #    (bool): True if the artifact was pulled, False if no remote has it
    #
    # Raises:
    #    (ArtifactError): On any gRPC error other than NOT_FOUND
    #
    def pull(self, element, key, *, progress=None):
        ref = self.get_artifact_fullname(element, key)

        project = element._get_project()

        for remote in self._remotes[project]:
            try:
                remote.init()
                display_key = element._get_brief_display_key()
                element.status("Pulling artifact {} <- {}".format(display_key, remote.spec.url))

                request = buildstream_pb2.GetReferenceRequest()
                request.key = ref
                response = remote.ref_storage.GetReference(request)

                tree = remote_execution_pb2.Digest()
                tree.hash = response.digest.hash
                tree.size_bytes = response.digest.size_bytes

                self._fetch_directory(remote, tree)

                self.set_ref(ref, tree)

                element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url))
                # no need to pull from additional remotes
                return True

            except grpc.RpcError as e:
                if e.code() != grpc.StatusCode.NOT_FOUND:
                    raise ArtifactError("Failed to pull artifact {}: {}".format(
                        element._get_brief_display_key(), e)) from e

                element.info("Remote ({}) does not have {} cached".format(
                    remote.spec.url, element._get_brief_display_key()
                ))

        return False

    def pull_tree(self, project, digest):
        """ Pull a single Tree rather than an artifact.
        Does not update local refs. """

        for remote in self._remotes[project]:
            try:
                remote.init()

                digest = self._fetch_tree(remote, digest)

                # no need to pull from additional remotes
                return digest

            except grpc.RpcError as e:
                # A remote that lacks the tree is skipped; anything else is fatal
                if e.code() != grpc.StatusCode.NOT_FOUND:
                    raise

        return None

    # link_key():
    #
    # Make the ref for `newkey` point at the tree of the ref for `oldkey`.
    #
    def link_key(self, element, oldkey, newkey):
        oldref = self.get_artifact_fullname(element, oldkey)
        newref = self.get_artifact_fullname(element, newkey)

        tree = self.resolve_ref(oldref)

        self.set_ref(newref, tree)

    # _push_refs_to_remote():
    #
    # Upload the trees of the given refs to a remote and update the
    # refs there, skipping refs the remote already has.
    #
    # Returns:
    #    (bool): True if at least one ref was updated on the remote
    #
    # Raises:
    #    (ArtifactError): On any gRPC error other than RESOURCE_EXHAUSTED
    #
    def _push_refs_to_remote(self, refs, remote):
        skipped_remote = True
        try:
            for ref in refs:
                tree = self.resolve_ref(ref)

                # Check whether ref is already on the server in which case
                # there is no need to push the artifact
                try:
                    request = buildstream_pb2.GetReferenceRequest()
                    request.key = ref
                    response = remote.ref_storage.GetReference(request)

                    if response.digest.hash == tree.hash and response.digest.size_bytes == tree.size_bytes:
                        # ref is already on the server with the same tree
                        continue

                except grpc.RpcError as e:
                    if e.code() != grpc.StatusCode.NOT_FOUND:
                        # Intentionally re-raise RpcError for outer except block.
                        raise

                self._send_directory(remote, tree)

                request = buildstream_pb2.UpdateReferenceRequest()
                request.keys.append(ref)
                request.digest.hash = tree.hash
                request.digest.size_bytes = tree.size_bytes
                remote.ref_storage.UpdateReference(request)

                skipped_remote = False
        except grpc.RpcError as e:
            if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED:
                raise ArtifactError("Failed to push artifact {}: {}".format(refs, e), temporary=True) from e

        return not skipped_remote

    # push():
    #
    # Push the artifact refs for the given keys to every push remote of
    # the element's project.
    #
    # Returns:
    #    (bool): True if at least one remote was updated
    #
    def push(self, element, keys):

        refs = [self.get_artifact_fullname(element, key) for key in list(keys)]

        project = element._get_project()

        push_remotes = [r for r in self._remotes[project] if r.spec.push]

        pushed = False

        for remote in push_remotes:
            remote.init()
            display_key = element._get_brief_display_key()
            element.status("Pushing artifact {} -> {}".format(display_key, remote.spec.url))

            if self._push_refs_to_remote(refs, remote):
                element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url))
                pushed = True
            else:
                element.info("Remote ({}) already has {} cached".format(
                    remote.spec.url, element._get_brief_display_key()
                ))

        return pushed

    def push_directory(self, project, directory):
        """ Push the given virtual directory to all remotes.

        Args:
            project (Project): The current project
            directory (Directory): A virtual directory object to push.

        Raises: ArtifactError if no push remotes are configured.
        """

        if self._has_push_remotes:
            push_remotes = [r for r in self._remotes[project] if r.spec.push]
        else:
            push_remotes = []

        if not push_remotes:
            raise ArtifactError("CASCache: push_directory was called, but no remote artifact " +
                                "servers are configured as push remotes.")

        # Nothing to upload for a directory without a ref
        if directory.ref is None:
            return

        for remote in push_remotes:
            remote.init()

            self._send_directory(remote, directory.ref)

    # push_message():
    #
    # Serialize a message and upload it as a blob to every push remote
    # of the project.
    #
    # Returns:
    #    (Digest): The sha256 digest and size of the serialized message
    #
    def push_message(self, project, message):

        push_remotes = [r for r in self._remotes[project] if r.spec.push]

        message_buffer = message.SerializeToString()
        message_sha = hashlib.sha256(message_buffer)
        message_digest = remote_execution_pb2.Digest()
        message_digest.hash = message_sha.hexdigest()
        message_digest.size_bytes = len(message_buffer)

        for remote in push_remotes:
            remote.init()

            with io.BytesIO(message_buffer) as b:
                self._send_blob(remote, message_digest, b)

        return message_digest

    # _verify_digest_on_remote():
    #
    # Ask a remote whether it holds the blob with the given digest.
    #
    # Returns:
    #    (bool): False if the remote reports the blob as missing
    #
    def _verify_digest_on_remote(self, remote, digest):
        request = remote_execution_pb2.FindMissingBlobsRequest()
        request.blob_digests.extend([digest])

        response = remote.cas.FindMissingBlobs(request)
        return digest not in response.missing_blob_digests

    # verify_digest_pushed():
    #
    # Check whether at least one push remote of the project holds the
    # blob with the given digest. Every push remote is queried.
    #
    def verify_digest_pushed(self, project, digest):

        push_remotes = [r for r in self._remotes[project] if r.spec.push]

        pushed = False

        for remote in push_remotes:
            remote.init()

            if self._verify_digest_on_remote(remote, digest):
                pushed = True

        return pushed

    ################################################
    #                API Private Methods           #
    ################################################

    # objpath():
    #
    # Return the path of an object based on its digest.
    #
    # Args:
    #     digest (Digest): The digest of the object
    #
    # Returns:
    #     (str): The path of the object
    #
    def objpath(self, digest):
        return os.path.join(self.casdir, 'objects', digest.hash[:2], digest.hash[2:])

    # add_object():
    #
    # Hash and write object to CAS.
    #
    # Args:
    #     digest (Digest): An optional Digest object to populate
    #     path (str): Path to file to add
    #     buffer (bytes): Byte buffer to add
    #
    # Returns:
    #     (Digest): The digest of the added object
    #
    # Raises:
    #     (ArtifactError): If reading, writing or linking the object fails
    #
    # Either `path` or `buffer` must be passed, but not both.
    #
    def add_object(self, *, digest=None, path=None, buffer=None):
        # Exactly one of the two parameters has to be specified
        assert (path is None) != (buffer is None)

        if digest is None:
            digest = remote_execution_pb2.Digest()

        try:
            h = hashlib.sha256()
            # Always write out new file to avoid corruption if input file is modified
            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
                # Set mode bits to 0644
                os.chmod(out.name, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH)

                if path:
                    with open(path, 'rb') as f:
                        for chunk in iter(lambda: f.read(4096), b""):
                            h.update(chunk)
                            out.write(chunk)
                else:
                    h.update(buffer)
                    out.write(buffer)

                out.flush()

                digest.hash = h.hexdigest()
                digest.size_bytes = os.fstat(out.fileno()).st_size

                # Place file at final location
                objpath = self.objpath(digest)
                os.makedirs(os.path.dirname(objpath), exist_ok=True)
                os.link(out.name, objpath)

        except FileExistsError:
            # We can ignore the failed link() if the object is already in the repo.
            pass

        except OSError as e:
            raise ArtifactError("Failed to hash object: {}".format(e)) from e

        return digest

    # set_ref():
    #
    # Create or replace a ref.
    #
    # Args:
    #     ref (str): The name of the ref
    #     tree (Digest): The digest to store in the ref
    #
    def set_ref(self, ref, tree):
        refpath = self._refpath(ref)
        os.makedirs(os.path.dirname(refpath), exist_ok=True)
        with utils.save_file_atomic(refpath, 'wb', tempdir=self.tmpdir) as f:
            f.write(tree.SerializeToString())

    # resolve_ref():
    #
    # Resolve a ref to a digest.
    #
    # Args:
    #     ref (str): The name of the ref
    #     update_mtime (bool): Whether to update the mtime of the ref
    #
    # Returns:
    #     (Digest): The digest stored in the ref
    #
    # Raises:
    #     (ArtifactError): If the ref file does not exist
    #
    def resolve_ref(self, ref, *, update_mtime=False):
        refpath = self._refpath(ref)

        try:
            with open(refpath, 'rb') as f:
                if update_mtime:
                    os.utime(refpath)

                digest = remote_execution_pb2.Digest()
                digest.ParseFromString(f.read())
                return digest

        except FileNotFoundError as e:
            raise ArtifactError("Attempt to access unavailable artifact: {}".format(e)) from e

    # update_mtime():
    #
    # Set the access and modification time of the ref for the given
    # element and key to the current time.
    #
    # Raises:
    #     (ArtifactError): If the ref file does not exist
    #
    def update_mtime(self, element, key):
        try:
            ref = self.get_artifact_fullname(element, key)
            os.utime(self._refpath(ref))
        except FileNotFoundError as e:
            raise ArtifactError("Attempt to access unavailable artifact: {}".format(e)) from e

    # calculate_cache_size():
    #
    # Return the size of the local repository directory as reported
    # by utils._get_dir_size().
    #
    def calculate_cache_size(self):
        return utils._get_dir_size(self.casdir)

    # list_artifacts():
    #
    # List cached artifacts in Least Recently Modified (LRM) order.
    #
    # Returns:
    #     (list) - A list of refs in LRM order
    #
    def list_artifacts(self):
        # string of: /path/to/repo/refs/heads
        ref_heads = os.path.join(self.casdir, 'refs', 'heads')

        refs = []
        mtimes = []

        for root, _, files in os.walk(ref_heads):
            for filename in files:
                ref_path = os.path.join(root, filename)
                refs.append(os.path.relpath(ref_path, ref_heads))
                # Obtain the mtime (the time a file was last modified)
                mtimes.append(os.path.getmtime(ref_path))

        # NOTE: Sorted will sort from earliest to latest, thus the
        # first element of this list will be the file modified earliest.
        return [ref for _, ref in sorted(zip(mtimes, refs))]

    # remove():
    #
    # Removes the given symbolic ref from the repo.
    #
    # Args:
    #    ref (str): A symbolic ref
    #    defer_prune (bool): Whether to defer pruning to the caller. NOTE:
    #                        The space won't be freed until you manually
    #                        call prune.
    #
    # Returns:
    #    (int|None) The amount of space pruned from the repository in
    #               Bytes, or None if defer_prune is True
    #
    # Raises:
    #    (ArtifactError): If the ref does not exist
    #
    def remove(self, ref, *, defer_prune=False):

        # Remove extract if not used by other ref
        tree = self.resolve_ref(ref)
        ref_name, ref_hash = os.path.split(ref)
        extract = os.path.join(self.extractdir, ref_name, tree.hash)
        keys_file = os.path.join(extract, 'meta', 'keys.yaml')
        if os.path.exists(keys_file):
            keys_meta = _yaml.load(keys_file)
            keys = [keys_meta['strong'], keys_meta['weak']]

            # The extract is only removed when both recorded keys are
            # the key of the ref being removed.
            if all(other_hash == ref_hash for other_hash in keys):
                utils._force_rmtree(extract)

        # Remove cache ref
        refpath = self._refpath(ref)
        if not os.path.exists(refpath):
            raise ArtifactError("Could not find artifact for ref '{}'".format(ref))

        os.unlink(refpath)

        if not defer_prune:
            pruned = self.prune()
            return pruned

        return None

    # prune():
    #
    # Prune unreachable objects from the repo.
    #
    # Returns:
    #    (int): The total size in bytes of the removed objects
    #
    def prune(self):
        ref_heads = os.path.join(self.casdir, 'refs', 'heads')

        pruned = 0
        reachable = set()

        # Check which objects are reachable
        for root, _, files in os.walk(ref_heads):
            for filename in files:
                ref_path = os.path.join(root, filename)
                ref = os.path.relpath(ref_path, ref_heads)

                tree = self.resolve_ref(ref)
                self._reachable_refs_dir(reachable, tree)

        # Prune unreachable objects
        for root, _, files in os.walk(os.path.join(self.casdir, 'objects')):
            for filename in files:
                # The hash is split into a two character directory name
                # and the file name, see objpath().
                objhash = os.path.basename(root) + filename
                if objhash not in reachable:
                    obj_path = os.path.join(root, filename)
                    pruned += os.stat(obj_path).st_size
                    os.unlink(obj_path)

        return pruned

    ################################################
    #             Local Private Methods            #
    ################################################

    # _checkout():
    #
    # Recreate the directory stored under digest `tree` at `dest`,
    # hardlinking files from the object store.
    #
    def _checkout(self, dest, tree):
        os.makedirs(dest, exist_ok=True)

        directory = remote_execution_pb2.Directory()

        with open(self.objpath(tree), 'rb') as f:
            directory.ParseFromString(f.read())

        for filenode in directory.files:
            # regular file, create hardlink
            fullpath = os.path.join(dest, filenode.name)
            os.link(self.objpath(filenode.digest), fullpath)

            if filenode.is_executable:
                os.chmod(fullpath, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR |
                         stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH)

        for dirnode in directory.directories:
            fullpath = os.path.join(dest, dirnode.name)
            self._checkout(fullpath, dirnode.digest)

        for symlinknode in directory.symlinks:
            # symlink
            fullpath = os.path.join(dest, symlinknode.name)
            os.symlink(symlinknode.target, fullpath)

    # _refpath():
    #
    # Return the path of the file that stores the given ref.
    #
    def _refpath(self, ref):
        return os.path.join(self.casdir, 'refs', 'heads', ref)

    # _commit_directory():
    #
    # Adds local directory to content addressable store.
    #
    # Adds files, symbolic links and recursively other directories in
    # a local directory to the content addressable store.
    #
    # Args:
    #     path (str): Path to the directory to add.
    #     dir_digest (Digest): An optional Digest object to use.
    #
    # Returns:
    #     (Digest): Digest object for the directory added.
    #
    # Raises:
    #     (ArtifactError): If an entry is neither a directory, regular
    #                      file, symlink nor socket
    #
    def _commit_directory(self, path, *, dir_digest=None):
        directory = remote_execution_pb2.Directory()

        for name in sorted(os.listdir(path)):
            full_path = os.path.join(path, name)
            mode = os.lstat(full_path).st_mode
            if stat.S_ISDIR(mode):
                dirnode = directory.directories.add()
                dirnode.name = name
                self._commit_directory(full_path, dir_digest=dirnode.digest)
            elif stat.S_ISREG(mode):
                filenode = directory.files.add()
                filenode.name = name
                self.add_object(path=full_path, digest=filenode.digest)
                filenode.is_executable = (mode & stat.S_IXUSR) == stat.S_IXUSR
            elif stat.S_ISLNK(mode):
                symlinknode = directory.symlinks.add()
                symlinknode.name = name
                symlinknode.target = os.readlink(full_path)
            elif stat.S_ISSOCK(mode):
                # The process serving the socket can't be cached anyway
                pass
            else:
                raise ArtifactError("Unsupported file type for {}".format(full_path))

        return self.add_object(digest=dir_digest,
                               buffer=directory.SerializeToString())

    # _get_subdir():
    #
    # Return the digest of the subdirectory `subdir` below the
    # directory stored under digest `tree`.
    #
    # Raises:
    #     (ArtifactError): If a path component is not a subdirectory
    #
    def _get_subdir(self, tree, subdir):
        head, name = os.path.split(subdir)
        if head:
            # Resolve the leading path components first
            tree = self._get_subdir(tree, head)

        directory = remote_execution_pb2.Directory()

        with open(self.objpath(tree), 'rb') as f:
            directory.ParseFromString(f.read())

        for dirnode in directory.directories:
            if dirnode.name == name:
                return dirnode.digest

        raise ArtifactError("Subdirectory {} not found".format(name))

    # _diff_trees():
    #
    # Recursively compare two stored directories and append the paths
    # of differing files to `added`, `removed` and `modified`.
    #
    # Either digest may be None, which is treated as an empty directory.
    # The merge below relies on entries being sorted by name, which is
    # how _commit_directory() writes them.
    #
    def _diff_trees(self, tree_a, tree_b, *, added, removed, modified, path=""):
        dir_a = remote_execution_pb2.Directory()
        dir_b = remote_execution_pb2.Directory()

        if tree_a:
            with open(self.objpath(tree_a), 'rb') as f:
                dir_a.ParseFromString(f.read())
        if tree_b:
            with open(self.objpath(tree_b), 'rb') as f:
                dir_b.ParseFromString(f.read())

        a = 0
        b = 0
        while a < len(dir_a.files) or b < len(dir_b.files):
            if b < len(dir_b.files) and (a >= len(dir_a.files) or
                                         dir_a.files[a].name > dir_b.files[b].name):
                added.append(os.path.join(path, dir_b.files[b].name))
                b += 1
            elif a < len(dir_a.files) and (b >= len(dir_b.files) or
                                           dir_b.files[b].name > dir_a.files[a].name):
                removed.append(os.path.join(path, dir_a.files[a].name))
                a += 1
            else:
                # File exists in both directories
                if dir_a.files[a].digest.hash != dir_b.files[b].digest.hash:
                    modified.append(os.path.join(path, dir_a.files[a].name))
                a += 1
                b += 1

        a = 0
        b = 0
        while a < len(dir_a.directories) or b < len(dir_b.directories):
            if b < len(dir_b.directories) and (a >= len(dir_a.directories) or
                                               dir_a.directories[a].name > dir_b.directories[b].name):
                self._diff_trees(None, dir_b.directories[b].digest,
                                 added=added, removed=removed, modified=modified,
                                 path=os.path.join(path, dir_b.directories[b].name))
                b += 1
            elif a < len(dir_a.directories) and (b >= len(dir_b.directories) or
                                                 dir_b.directories[b].name > dir_a.directories[a].name):
                self._diff_trees(dir_a.directories[a].digest, None,
                                 added=added, removed=removed, modified=modified,
                                 path=os.path.join(path, dir_a.directories[a].name))
                a += 1
            else:
                # Subdirectory exists in both directories
                if dir_a.directories[a].digest.hash != dir_b.directories[b].digest.hash:
                    self._diff_trees(dir_a.directories[a].digest, dir_b.directories[b].digest,
                                     added=added, removed=removed, modified=modified,
                                     path=os.path.join(path, dir_a.directories[a].name))
                a += 1
                b += 1

    # _reachable_refs_dir():
    #
    # Add the hashes of the directory `tree`, its files and,
    # recursively, its subdirectories to the set `reachable`.
    #
    def _reachable_refs_dir(self, reachable, tree):
        if tree.hash in reachable:
            return

        reachable.add(tree.hash)

        directory = remote_execution_pb2.Directory()

        with open(self.objpath(tree), 'rb') as f:
            directory.ParseFromString(f.read())

        for filenode in directory.files:
            reachable.add(filenode.digest.hash)

        for dirnode in directory.directories:
            self._reachable_refs_dir(reachable, dirnode.digest)

    # _initialize_remote():
    #
    # Subprocess entry point used by initialize_remotes(): contact the
    # remote and report the outcome on the queue `q`, either None on
    # success or an error string.
    #
    def _initialize_remote(self, remote_spec, q):
        try:
            remote = _CASRemote(remote_spec)
            remote.init()

            request = buildstream_pb2.StatusRequest()
            response = remote.ref_storage.Status(request)

            if remote_spec.push and not response.allow_updates:
                q.put('Artifact server does not allow push')
            else:
                # No error
                q.put(None)

        except grpc.RpcError as e:
            # str(e) is too verbose for errors reported to the user
            q.put(e.details())

        except Exception as e:               # pylint: disable=broad-except

            # Whatever happens, we need to return it to the calling process
            #
            q.put(str(e))

    # _required_blobs():
    #
    # Generator yielding a Digest for the given directory, for each of
    # its files and, recursively, for everything in its subdirectories.
    #
    def _required_blobs(self, directory_digest):
        # parse directory, and recursively add blobs
        d = remote_execution_pb2.Digest()
        d.hash = directory_digest.hash
        d.size_bytes = directory_digest.size_bytes
        yield d

        directory = remote_execution_pb2.Directory()

        with open(self.objpath(directory_digest), 'rb') as f:
            directory.ParseFromString(f.read())

        for filenode in directory.files:
            d = remote_execution_pb2.Digest()
            d.hash = filenode.digest.hash
            d.size_bytes = filenode.digest.size_bytes
            yield d

        for dirnode in directory.directories:
            yield from self._required_blobs(dirnode.digest)

    # _fetch_blob():
    #
    # Download the blob with the given digest from the remote through
    # the ByteStream Read call and write it to the file object `stream`.
    #
    def _fetch_blob(self, remote, digest, stream):
        resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
        request = bytestream_pb2.ReadRequest()
        request.resource_name = resource_name
        request.read_offset = 0
        for response in remote.bytestream.Read(request):
            stream.write(response.data)
        stream.flush()

        assert digest.size_bytes == os.fstat(stream.fileno()).st_size

    # _ensure_blob():
    #
    # Fetch and add blob if it's not already local.
    #
    # Args:
    #     remote (Remote): The remote to use.
    #     digest (Digest): Digest object for the blob to fetch.
    #
    # Returns:
    #     (str): The path of the object
    #
    def _ensure_blob(self, remote, digest):
        objpath = self.objpath(digest)
        if os.path.exists(objpath):
            # already in local repository
            return objpath

        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
            self._fetch_blob(remote, digest, f)

            added_digest = self.add_object(path=f.name)
            assert added_digest.hash == digest.hash

        return objpath

    # _batch_download_complete():
    #
    # Send the pending batch read request and add every received blob
    # to the local repository.
    #
    def _batch_download_complete(self, batch):
        for digest, data in batch.send():
            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
                f.write(data)
                f.flush()

                added_digest = self.add_object(path=f.name)
                assert added_digest.hash == digest.hash

    # Helper function for _fetch_directory().
    #
    # Completes the pending batch and returns a fresh, empty one.
    def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
        self._batch_download_complete(batch)

        # All previously scheduled directories are now locally available,
        # move them to the processing queue.
        fetch_queue.extend(fetch_next_queue)
        fetch_next_queue.clear()
        return _CASBatchRead(remote)

    # Helper function for _fetch_directory().
    #
    # Makes sure the blob is local or scheduled for download and, for
    # directories (recursive=True), queues it for processing. Returns
    # the batch to continue with, which may be a new one.
    def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
        in_local_cache = os.path.exists(self.objpath(digest))

        if in_local_cache:
            # Skip download, already in local cache.
            pass
        elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
                not remote.batch_read_supported):
            # Too large for batch request, download in independent request.
            self._ensure_blob(remote, digest)
            in_local_cache = True
        else:
            if not batch.add(digest):
                # Not enough space left in batch request.
                # Complete pending batch first.
                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
                batch.add(digest)

        if recursive:
            if in_local_cache:
                # Add directory to processing queue.
                fetch_queue.append(digest)
            else:
                # Directory will be available after completing pending batch.
                # Add directory to deferred processing queue.
                fetch_next_queue.append(digest)

        return batch

    # _fetch_directory():
    #
    # Fetches remote directory and adds it to content addressable store.
    #
    # Fetches files, symbolic links and recursively other directories in
    # the remote directory and adds them to the content addressable
    # store.
    #
    # Args:
    #     remote (Remote): The remote to use.
    #     dir_digest (Digest): Digest object for the directory to fetch.
    #
    def _fetch_directory(self, remote, dir_digest):
        fetch_queue = [dir_digest]
        fetch_next_queue = []
        batch = _CASBatchRead(remote)

        while len(fetch_queue) + len(fetch_next_queue) > 0:
            if len(fetch_queue) == 0:
                # Only deferred directories are left; download them now
                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)

            dir_digest = fetch_queue.pop(0)

            objpath = self._ensure_blob(remote, dir_digest)

            directory = remote_execution_pb2.Directory()
            with open(objpath, 'rb') as f:
                directory.ParseFromString(f.read())

            for dirnode in directory.directories:
                batch = self._fetch_directory_node(remote, dirnode.digest, batch,
                                                   fetch_queue, fetch_next_queue, recursive=True)

            for filenode in directory.files:
                batch = self._fetch_directory_node(remote, filenode.digest, batch,
                                                   fetch_queue, fetch_next_queue)

        # Fetch final batch
        self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)

    # _fetch_tree():
    #
    # Download a Tree message, fetch the files referenced by its
    # directories and add each directory to the local repository.
    #
    # Returns:
    #     (Digest): The digest of the directory added last, which is
    #               the root of the tree
    #
    def _fetch_tree(self, remote, digest):
        # download but do not store the Tree object
        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
            self._fetch_blob(remote, digest, out)

            tree = remote_execution_pb2.Tree()

            with open(out.name, 'rb') as f:
                tree.ParseFromString(f.read())

            # Process the root last, after all children
            tree.children.extend([tree.root])
            for directory in tree.children:
                for filenode in directory.files:
                    self._ensure_blob(remote, filenode.digest)

                # place directory blob only in final location when we've downloaded
                # all referenced blobs to avoid dangling references in the repository
                dirbuffer = directory.SerializeToString()
                dirdigest = self.add_object(buffer=dirbuffer)
                assert dirdigest.size_bytes == len(dirbuffer)

        return dirdigest

    # _send_blob():
    #
    # Upload a single blob to the remote through the ByteStream Write
    # call, in chunks of at most _MAX_PAYLOAD_BYTES.
    #
    # Args:
    #     remote (Remote): The remote to use.
    #     digest (Digest): Digest of the blob to upload.
    #     stream (file): Binary file object to read the blob from.
    #     u_uid (UUID): Optional upload identifier used in the resource
    #                   name; a new one is generated per call if omitted.
    #
    def _send_blob(self, remote, digest, stream, u_uid=None):
        # The default must be created per call: a `uuid.uuid4()` default
        # argument is evaluated only once, when the function is defined.
        if u_uid is None:
            u_uid = uuid.uuid4()

        resource_name = '/'.join(['uploads', str(u_uid), 'blobs',
                                  digest.hash, str(digest.size_bytes)])

        def request_stream(resname, instream):
            offset = 0
            finished = False
            remaining = digest.size_bytes
            while not finished:
                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
                remaining -= chunk_size

                request = bytestream_pb2.WriteRequest()
                request.write_offset = offset
                # max. _MAX_PAYLOAD_BYTES chunks
                request.data = instream.read(chunk_size)
                request.resource_name = resname
                request.finish_write = remaining <= 0

                yield request

                offset += chunk_size
                finished = request.finish_write

        response = remote.bytestream.Write(request_stream(resource_name, stream))

        assert response.committed_size == digest.size_bytes

    # _send_directory():
    #
    # Upload every blob of the given directory tree that the remote
    # reports as missing.
    #
    # Args:
    #     remote (Remote): The remote to use.
    #     digest (Digest): Digest of the root directory.
    #     u_uid (UUID): Optional upload identifier; a new one is
    #                   generated per call if omitted.
    #
    def _send_directory(self, remote, digest, u_uid=None):
        if u_uid is None:
            u_uid = uuid.uuid4()

        required_blobs = self._required_blobs(digest)

        missing_blobs = dict()
        # Limit size of FindMissingBlobs request
        for required_blobs_group in _grouper(required_blobs, 512):
            request = remote_execution_pb2.FindMissingBlobsRequest()

            for required_digest in required_blobs_group:
                d = request.blob_digests.add()
                d.hash = required_digest.hash
                d.size_bytes = required_digest.size_bytes

            response = remote.cas.FindMissingBlobs(request)
            for missing_digest in response.missing_blob_digests:
                d = remote_execution_pb2.Digest()
                d.hash = missing_digest.hash
                d.size_bytes = missing_digest.size_bytes
                # Keyed by hash so that a blob is uploaded only once
                missing_blobs[d.hash] = d

        # Upload any blobs missing on the server
        self._send_blobs(remote, missing_blobs.values(), u_uid)

    # _send_blobs():
    #
    # Upload the given blobs from the local repository, batching them
    # where the remote supports it and the blob is small enough.
    #
    # Args:
    #     remote (Remote): The remote to use.
    #     digests (iterable): Digests of the blobs to upload.
    #     u_uid (UUID): Optional upload identifier for blobs sent
    #                   individually; a new one is generated per call
    #                   if omitted.
    #
    def _send_blobs(self, remote, digests, u_uid=None):
        if u_uid is None:
            u_uid = uuid.uuid4()

        batch = _CASBatchUpdate(remote)

        for digest in digests:
            with open(self.objpath(digest), 'rb') as f:
                assert os.fstat(f.fileno()).st_size == digest.size_bytes

                if (digest.size_bytes >= remote.max_batch_total_size_bytes or
                        not remote.batch_update_supported):
                    # Too large for batch request, upload in independent request.
                    self._send_blob(remote, digest, f, u_uid=u_uid)
                else:
                    if not batch.add(digest, f):
                        # Not enough space left in batch request.
                        # Complete pending batch first.
                        batch.send()
                        batch = _CASBatchUpdate(remote)
                        batch.add(digest, f)

        # Send final batch
        batch.send()


# Represents a single remote CAS cache.
#
class _CASRemote():
    def __init__(self, spec):
        self.spec = spec
        self._initialized = False
        self.channel = None
        self.bytestream = None
        self.cas = None
        self.capabilities = None
        self.ref_storage = None

        # Filled in with what the server reports by init()
        self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
        self.batch_read_supported = False
        self.batch_update_supported = False

    # Return the content of the file at `path`, or None if no path is given.
    @staticmethod
    def _read_optional_file(path):
        if not path:
            return None

        with open(path, 'rb') as f:
            return f.read()

    # init():
    #
    # Open the gRPC channel, create the service stubs and probe the
    # server for batch support. Does nothing after the first
    # successful call.
    #
    # Raises:
    #     (ArtifactError): If the URL scheme is neither http nor https
    #     (grpc.RpcError): If a capability probe fails with an
    #                      unexpected status code
    #
    def init(self):
        if self._initialized:
            return

        url = urlparse(self.spec.url)
        if url.scheme == 'http':
            port = url.port or 80
            self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
        elif url.scheme == 'https':
            port = url.port or 443

            server_cert_bytes = self._read_optional_file(self.spec.server_cert)
            client_key_bytes = self._read_optional_file(self.spec.client_key)
            client_cert_bytes = self._read_optional_file(self.spec.client_cert)

            credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
                                                       private_key=client_key_bytes,
                                                       certificate_chain=client_cert_bytes)
            self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
        else:
            raise ArtifactError("Unsupported URL: {}".format(self.spec.url))

        self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
        self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
        self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
        self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)

        self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
        try:
            request = remote_execution_pb2.GetCapabilitiesRequest()
            response = self.capabilities.GetCapabilities(request)
            server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
            # Only ever lower the limit; a value of 0 is ignored
            if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
                self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
        except grpc.RpcError as e:
            # Simply use the defaults for servers that don't implement GetCapabilities()
            if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                raise

        # Check whether the server supports BatchReadBlobs()
        self.batch_read_supported = False
        try:
            request = remote_execution_pb2.BatchReadBlobsRequest()
            self.cas.BatchReadBlobs(request)
            self.batch_read_supported = True
        except grpc.RpcError as e:
            if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                raise

        # Check whether the server supports BatchUpdateBlobs()
        self.batch_update_supported = False
        try:
            request = remote_execution_pb2.BatchUpdateBlobsRequest()
            self.cas.BatchUpdateBlobs(request)
            self.batch_update_supported = True
        except grpc.RpcError as e:
            if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
                    e.code() != grpc.StatusCode.PERMISSION_DENIED):
                raise

        self._initialized = True


# Represents a batch of blobs queued for fetching.
#
class _CASBatchRead():
    def __init__(self, remote):
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes
        self._request = remote_execution_pb2.BatchReadBlobsRequest()
        self._size = 0
        self._sent = False

    # add():
    #
    # Queue a blob for download.
    #
    # Args:
    #     digest (Digest): Digest of the blob to fetch.
    #
    # Returns:
    #     (bool): False if the blob would push the batch over its size
    #             limit, in which case nothing is queued
    #
    def add(self, digest):
        assert not self._sent

        grown_size = digest.size_bytes + self._size
        if grown_size > self._max_total_size_bytes:
            # Not enough space left in current batch
            return False

        entry = self._request.digests.add()
        entry.hash = digest.hash
        entry.size_bytes = digest.size_bytes
        self._size = grown_size
        return True

    # send():
    #
    # Generator which issues the batch request and yields a
    # (digest, data) tuple for every blob received. A batch can be
    # sent only once; an empty batch issues no request.
    #
    # Raises:
    #     (ArtifactError): If a blob failed to download or its size
    #                      does not match its digest
    #
    def send(self):
        assert not self._sent
        self._sent = True

        if len(self._request.digests) == 0:
            return

        reply = self._remote.cas.BatchReadBlobs(self._request)

        for blob in reply.responses:
            if blob.status.code != grpc.StatusCode.OK.value[0]:
                raise ArtifactError("Failed to download blob {}: {}".format(
                    blob.digest.hash, blob.status.code))

            received = len(blob.data)
            if blob.digest.size_bytes != received:
                raise ArtifactError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
                    blob.digest.hash, blob.digest.size_bytes, received))

            yield (blob.digest, blob.data)


# Represents a batch of blobs queued for upload.
#
class _CASBatchUpdate():
    def __init__(self, remote):
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes
        self._request = remote_execution_pb2.BatchUpdateBlobsRequest()
        self._size = 0
        self._sent = False

    # add():
    #
    # Queue a blob for upload.
    #
    # Args:
    #     digest (Digest): Digest of the blob to upload.
    #     stream (file): Binary file object to read the blob from.
    #
    # Returns:
    #     (bool): False if the blob would push the batch over its size
    #             limit; the stream is not read in that case
    #
    def add(self, digest, stream):
        assert not self._sent

        new_batch_size = self._size + digest.size_bytes
        if new_batch_size > self._max_total_size_bytes:
            # Not enough space left in current batch
            return False

        blob_request = self._request.requests.add()
        blob_request.digest.hash = digest.hash
        blob_request.digest.size_bytes = digest.size_bytes
        blob_request.data = stream.read(digest.size_bytes)
        self._size = new_batch_size
        return True

    # send():
    #
    # Issue the batch upload request. A batch can be sent only once;
    # an empty batch issues no request.
    #
    # Raises:
    #     (ArtifactError): If the server reports a failure for a blob
    #
    def send(self):
        assert not self._sent
        self._sent = True

        if len(self._request.requests) == 0:
            return

        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)

        for response in batch_response.responses:
            if response.status.code != grpc.StatusCode.OK.value[0]:
                raise ArtifactError("Failed to upload blob {}: {}".format(
                    response.digest.hash, response.status.code))


# _grouper():
#
# Split an iterable into consecutive groups of at most `n` items.
#
# The groups are lazy iterators sharing one underlying iterator, so
# each group has to be consumed before the next one is requested.
#
# Args:
#     iterable (iterable): The items to group; any iterable is accepted,
#                          not only iterators.
#     n (int): The maximum number of items per group
#
# Yields:
#     (iterator): An iterator over the items of one group
#
def _grouper(iterable, n):
    # next() needs an iterator; iter() is a no-op for one and makes
    # lists and other plain iterables work as well.
    iterator = iter(iterable)
    while True:
        try:
            current = next(iterator)
        except StopIteration:
            return
        yield itertools.chain([current], itertools.islice(iterator, n - 1))