author     Tristan Maat <tm@tlater.net>    2019-12-03 11:38:54 +0000
committer  Tristan Maat <tm@tlater.net>    2019-12-03 11:38:54 +0000
commit     476a53975286364088576df270a3be13fe1bec4a (patch)
tree       6534dd727c43f20a2381ae0df9d430716e4b93c7
parent     c35a8eb80e4a069f3ba6fe1f3719c7c74ebf7fd1 (diff)
parent     845a2fdb4ff786c9e9c7441ba321387124a25354 (diff)
download   buildstream-476a53975286364088576df270a3be13fe1bec4a.tar.gz
Merge branch 'tlater/artifactserver-casd' into 'master'
Refactor casserver.py: Stop relying on the buildstream-internal `CASCache` implementation
Closes #1167
See merge request BuildStream/buildstream!1645
-rw-r--r--  src/buildstream/_artifactcache.py            26
-rw-r--r--  src/buildstream/_basecache.py                26
-rw-r--r--  src/buildstream/_cas/cascache.py            165
-rw-r--r--  src/buildstream/_cas/casdprocessmanager.py    8
-rw-r--r--  src/buildstream/_cas/casserver.py            491
-rw-r--r--  src/buildstream/_exceptions.py                 9
-rw-r--r--  src/buildstream/_sourcecache.py               12
-rw-r--r--  src/buildstream/utils.py                      38
-rw-r--r--  tests/sourcecache/fetch.py                    10
-rw-r--r--  tests/sourcecache/push.py                      6
-rw-r--r--  tests/testutils/artifactshare.py              16

11 files changed, 364 insertions(+), 443 deletions(-)
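The heart of this refactor is visible in casserver.py further below: rather than reimplementing CAS storage on top of the internal CASCache, each gRPC servicer now forwards its RPCs to a buildbox-casd instance over a local channel. The following is a minimal sketch of that proxy pattern, not code from this commit; the socket path and the ProxyCASServicer name are illustrative assumptions.

import grpc

from buildstream._protos.build.bazel.remote.execution.v2 import (
    remote_execution_pb2_grpc,
)


class ProxyCASServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
    def __init__(self, casd_channel):
        # A stub pointed at the local buildbox-casd endpoint; casd owns
        # the object store, so the servicer holds no storage logic itself.
        self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(casd_channel)

    def FindMissingBlobs(self, request, context):
        # Forward the client's request verbatim and relay casd's answer
        return self.cas.FindMissingBlobs(request)


# Hypothetical wiring: the real code obtains the channel via CASDProcessManager
channel = grpc.insecure_channel("unix:/tmp/casd/casd.sock")
servicer = ProxyCASServicer(channel)

The design consequence is that request/response protos pass through unchanged, so the server shrinks to thin logging wrappers, which is what the casserver.py hunks below show.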
diff --git a/src/buildstream/_artifactcache.py b/src/buildstream/_artifactcache.py
index 10ccf1527..02dd21d41 100644
--- a/src/buildstream/_artifactcache.py
+++ b/src/buildstream/_artifactcache.py
@@ -22,7 +22,7 @@ import os
 import grpc
 
 from ._basecache import BaseCache
-from ._exceptions import ArtifactError, CASError, CASCacheError, CASRemoteError, RemoteError
+from ._exceptions import ArtifactError, CASError, CacheError, CASRemoteError, RemoteError
 from ._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, artifact_pb2, artifact_pb2_grpc
 from ._remote import BaseRemote
@@ -146,12 +146,12 @@ class ArtifactCache(BaseCache):
         super().__init__(context)
 
         # create artifact directory
-        self.artifactdir = context.artifactdir
-        os.makedirs(self.artifactdir, exist_ok=True)
+        self._basedir = context.artifactdir
+        os.makedirs(self._basedir, exist_ok=True)
 
     def update_mtime(self, ref):
         try:
-            os.utime(os.path.join(self.artifactdir, ref))
+            os.utime(os.path.join(self._basedir, ref))
         except FileNotFoundError as e:
             raise ArtifactError("Couldn't find artifact: {}".format(ref)) from e
@@ -176,7 +176,7 @@ class ArtifactCache(BaseCache):
     def contains(self, element, key):
         ref = element.get_artifact_name(key)
 
-        return os.path.exists(os.path.join(self.artifactdir, ref))
+        return os.path.exists(os.path.join(self._basedir, ref))
 
     # list_artifacts():
     #
@@ -189,7 +189,7 @@ class ArtifactCache(BaseCache):
     #     ([str]) - A list of artifact names as generated in LRU order
     #
     def list_artifacts(self, *, glob=None):
-        return [ref for _, ref in sorted(list(self._list_refs_mtimes(self.artifactdir, glob_expr=glob)))]
+        return [ref for _, ref in sorted(list(self._list_refs_mtimes(self._basedir, glob_expr=glob)))]
 
     # remove():
     #
@@ -202,8 +202,8 @@ class ArtifactCache(BaseCache):
     #
     def remove(self, ref):
         try:
-            self.cas.remove(ref, basedir=self.artifactdir)
-        except CASCacheError as e:
+            self._remove_ref(ref)
+        except CacheError as e:
             raise ArtifactError("{}".format(e)) from e
 
     # diff():
@@ -410,8 +410,8 @@ class ArtifactCache(BaseCache):
         oldref = element.get_artifact_name(oldkey)
         newref = element.get_artifact_name(newkey)
 
-        if not os.path.exists(os.path.join(self.artifactdir, newref)):
-            os.link(os.path.join(self.artifactdir, oldref), os.path.join(self.artifactdir, newref))
+        if not os.path.exists(os.path.join(self._basedir, newref)):
+            os.link(os.path.join(self._basedir, oldref), os.path.join(self._basedir, newref))
 
     # get_artifact_logs():
     #
@@ -514,7 +514,7 @@ class ArtifactCache(BaseCache):
     #     (iter): Iterator over directories digests available from artifacts.
     #
     def _reachable_directories(self):
-        for root, _, files in os.walk(self.artifactdir):
+        for root, _, files in os.walk(self._basedir):
             for artifact_file in files:
                 artifact = artifact_pb2.Artifact()
                 with open(os.path.join(root, artifact_file), "r+b") as f:
@@ -532,7 +532,7 @@ class ArtifactCache(BaseCache):
     #     (iter): Iterator over single file digests in artifacts
     #
     def _reachable_digests(self):
-        for root, _, files in os.walk(self.artifactdir):
+        for root, _, files in os.walk(self._basedir):
             for artifact_file in files:
                 artifact = artifact_pb2.Artifact()
                 with open(os.path.join(root, artifact_file), "r+b") as f:
@@ -707,7 +707,7 @@ class ArtifactCache(BaseCache):
             return None
 
         # Write the artifact proto to cache
-        artifact_path = os.path.join(self.artifactdir, artifact_name)
+        artifact_path = os.path.join(self._basedir, artifact_name)
         os.makedirs(os.path.dirname(artifact_path), exist_ok=True)
         with utils.save_file_atomic(artifact_path, mode="wb") as f:
             f.write(artifact.SerializeToString())
diff --git a/src/buildstream/_basecache.py b/src/buildstream/_basecache.py
index 15b1d5389..dff7742e7 100644
--- a/src/buildstream/_basecache.py
+++ b/src/buildstream/_basecache.py
@@ -25,7 +25,7 @@ from . import utils
 from . import _yaml
 from ._cas import CASRemote
 from ._message import Message, MessageType
-from ._exceptions import LoadError, RemoteError
+from ._exceptions import LoadError, RemoteError, CacheError
 from ._remote import RemoteSpec, RemoteType
@@ -62,6 +62,8 @@ class BaseCache:
         self._has_fetch_remotes = False
         self._has_push_remotes = False
 
+        self._basedir = None
+
     # has_open_grpc_channels():
     #
     # Return whether there are gRPC channel instances. This is used to safeguard
@@ -429,3 +431,25 @@ class BaseCache:
                 if not glob_expr or fnmatch(relative_path, glob_expr):
                     # Obtain the mtime (the time a file was last modified)
                     yield (os.path.getmtime(ref_path), relative_path)
+
+    # _remove_ref()
+    #
+    # Removes a ref.
+    #
+    # This also takes care of pruning away directories which can
+    # be removed after having removed the given ref.
+    #
+    # Args:
+    #    ref (str): The ref to remove
+    #
+    # Raises:
+    #    (CASCacheError): If the ref didnt exist, or a system error
+    #                     occurred while removing it
+    #
+    def _remove_ref(self, ref):
+        try:
+            utils._remove_path_with_parents(self._basedir, ref)
+        except FileNotFoundError as e:
+            raise CacheError("Could not find ref '{}'".format(ref)) from e
+        except OSError as e:
+            raise CacheError("System error while removing ref '{}': {}".format(ref, e)) from e
diff --git a/src/buildstream/_cas/cascache.py b/src/buildstream/_cas/cascache.py
index 98581d351..c45a199fe 100644
--- a/src/buildstream/_cas/cascache.py
+++ b/src/buildstream/_cas/cascache.py
@@ -21,7 +21,6 @@ import itertools
 import os
 import stat
-import errno
 import contextlib
 import ctypes
 import multiprocessing
@@ -69,7 +68,6 @@ class CASCache:
     ):
         self.casdir = os.path.join(path, "cas")
         self.tmpdir = os.path.join(path, "tmp")
-        os.makedirs(os.path.join(self.casdir, "refs", "heads"), exist_ok=True)
         os.makedirs(os.path.join(self.casdir, "objects"), exist_ok=True)
         os.makedirs(self.tmpdir, exist_ok=True)
@@ -134,9 +132,7 @@ class CASCache:
     # Preflight check.
     #
     def preflight(self):
-        headdir = os.path.join(self.casdir, "refs", "heads")
-        objdir = os.path.join(self.casdir, "objects")
-        if not (os.path.isdir(headdir) and os.path.isdir(objdir)):
+        if not os.path.join(self.casdir, "objects"):
             raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir))
 
     # close_grpc_channels():
@@ -160,21 +156,6 @@ class CASCache:
             self._casd_process_manager.release_resources(messenger)
             self._casd_process_manager = None
 
-    # contains():
-    #
-    # Check whether the specified ref is already available in the local CAS cache.
-    #
-    # Args:
-    #     ref (str): The ref to check
-    #
-    # Returns: True if the ref is in the cache, False otherwise
-    #
-    def contains(self, ref):
-        refpath = self._refpath(ref)
-
-        # This assumes that the repository doesn't have any dangling pointers
-        return os.path.exists(refpath)
-
     # contains_file():
     #
     # Check whether a digest corresponds to a file which exists in CAS
@@ -261,28 +242,6 @@ class CASCache:
                 fullpath = os.path.join(dest, symlinknode.name)
                 os.symlink(symlinknode.target, fullpath)
 
-    # diff():
-    #
-    # Return a list of files that have been added or modified between
-    # the refs described by ref_a and ref_b.
-    #
-    # Args:
-    #     ref_a (str): The first ref
-    #     ref_b (str): The second ref
-    #     subdir (str): A subdirectory to limit the comparison to
-    #
-    def diff(self, ref_a, ref_b):
-        tree_a = self.resolve_ref(ref_a)
-        tree_b = self.resolve_ref(ref_b)
-
-        added = []
-        removed = []
-        modified = []
-
-        self.diff_trees(tree_a, tree_b, added=added, removed=removed, modified=modified)
-
-        return modified, removed, added
-
     # pull_tree():
     #
     # Pull a single Tree rather than a ref.
@@ -409,74 +368,6 @@ class CASCache:
 
         return utils._message_digest(root_directory)
 
-    # set_ref():
-    #
-    # Create or replace a ref.
-    #
-    # Args:
-    #     ref (str): The name of the ref
-    #
-    def set_ref(self, ref, tree):
-        refpath = self._refpath(ref)
-        os.makedirs(os.path.dirname(refpath), exist_ok=True)
-        with utils.save_file_atomic(refpath, "wb", tempdir=self.tmpdir) as f:
-            f.write(tree.SerializeToString())
-
-    # resolve_ref():
-    #
-    # Resolve a ref to a digest.
-    #
-    # Args:
-    #     ref (str): The name of the ref
-    #     update_mtime (bool): Whether to update the mtime of the ref
-    #
-    # Returns:
-    #     (Digest): The digest stored in the ref
-    #
-    def resolve_ref(self, ref, *, update_mtime=False):
-        refpath = self._refpath(ref)
-
-        try:
-            with open(refpath, "rb") as f:
-                if update_mtime:
-                    os.utime(refpath)
-
-                digest = remote_execution_pb2.Digest()
-                digest.ParseFromString(f.read())
-                return digest
-
-        except FileNotFoundError as e:
-            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
-
-    # update_mtime()
-    #
-    # Update the mtime of a ref.
-    #
-    # Args:
-    #     ref (str): The ref to update
-    #
-    def update_mtime(self, ref):
-        try:
-            os.utime(self._refpath(ref))
-        except FileNotFoundError as e:
-            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
-
-    # remove():
-    #
-    # Removes the given symbolic ref from the repo.
-    #
-    # Args:
-    #     ref (str): A symbolic ref
-    #     basedir (str): Path of base directory the ref is in, defaults to
-    #                    CAS refs heads
-    #
-    def remove(self, ref, *, basedir=None):
-
-        if basedir is None:
-            basedir = os.path.join(self.casdir, "refs", "heads")
-        # Remove cache ref
-        self._remove_ref(ref, basedir)
-
     def update_tree_mtime(self, tree):
         reachable = set()
         self._reachable_refs_dir(reachable, tree, update_mtime=True)
@@ -645,60 +536,6 @@ class CASCache:
     #             Local Private Methods            #
     ################################################
 
-    def _refpath(self, ref):
-        return os.path.join(self.casdir, "refs", "heads", ref)
-
-    # _remove_ref()
-    #
-    # Removes a ref.
-    #
-    # This also takes care of pruning away directories which can
-    # be removed after having removed the given ref.
-    #
-    # Args:
-    #    ref (str): The ref to remove
-    #    basedir (str): Path of base directory the ref is in
-    #
-    # Raises:
-    #    (CASCacheError): If the ref didnt exist, or a system error
-    #                     occurred while removing it
-    #
-    def _remove_ref(self, ref, basedir):
-
-        # Remove the ref itself
-        refpath = os.path.join(basedir, ref)
-
-        try:
-            os.unlink(refpath)
-        except FileNotFoundError as e:
-            raise CASCacheError("Could not find ref '{}'".format(ref)) from e
-
-        # Now remove any leading directories
-
-        components = list(os.path.split(ref))
-        while components:
-            components.pop()
-            refdir = os.path.join(basedir, *components)
-
-            # Break out once we reach the base
-            if refdir == basedir:
-                break
-
-            try:
-                os.rmdir(refdir)
-            except FileNotFoundError:
-                # The parent directory did not exist, but it's
-                # parent directory might still be ready to prune
-                pass
-            except OSError as e:
-                if e.errno == errno.ENOTEMPTY:
-                    # The parent directory was not empty, so we
-                    # cannot prune directories beyond this point
-                    break
-
-                # Something went wrong here
-                raise CASCacheError("System error while removing ref '{}': {}".format(ref, e)) from e
-
     def _get_subdir(self, tree, subdir):
         head, name = os.path.split(subdir)
         if head:
diff --git a/src/buildstream/_cas/casdprocessmanager.py b/src/buildstream/_cas/casdprocessmanager.py
index e4a58d7d5..68bb88ef0 100644
--- a/src/buildstream/_cas/casdprocessmanager.py
+++ b/src/buildstream/_cas/casdprocessmanager.py
@@ -28,6 +28,7 @@ import grpc
 
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2_grpc
 from .._protos.build.buildgrid import local_cas_pb2_grpc
+from .._protos.google.bytestream import bytestream_pb2_grpc
 
 from .. import _signals, utils
 from .._exceptions import CASCacheError
@@ -177,6 +178,7 @@ class CASDChannel:
         self._connection_string = connection_string
         self._start_time = start_time
         self._casd_channel = None
+        self._bytestream = None
         self._casd_cas = None
         self._local_cas = None
@@ -192,6 +194,7 @@ class CASDChannel:
                 time.sleep(0.01)
 
         self._casd_channel = grpc.insecure_channel(self._connection_string)
+        self._bytestream = bytestream_pb2_grpc.ByteStreamStub(self._casd_channel)
         self._casd_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._casd_channel)
         self._local_cas = local_cas_pb2_grpc.LocalContentAddressableStorageStub(self._casd_channel)
@@ -213,6 +216,11 @@ class CASDChannel:
             self._establish_connection()
         return self._local_cas
 
+    def get_bytestream(self):
+        if self._casd_channel is None:
+            self._establish_connection()
+        return self._bytestream
+
     # is_closed():
     #
     # Return whether this connection is closed or not.
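With the get_bytestream() accessor added above, callers can stream blob contents straight from buildbox-casd. The sketch below shows how a client might read a whole blob through the new accessor; it is an assumption-laden illustration, not code from this commit. It assumes digest is a remote_execution_pb2.Digest and an empty instance name, and the resource-name format comes from the ByteStream/REAPI specification.

from buildstream._protos.google.bytestream import bytestream_pb2


def read_blob(casd_channel, digest):
    # ByteStream download resource names take the form "blobs/{hash}/{size}"
    request = bytestream_pb2.ReadRequest(
        resource_name="blobs/{}/{}".format(digest.hash, digest.size_bytes)
    )
    # Read() returns a stream of ReadResponse chunks; concatenate their payloads
    bytestream = casd_channel.get_bytestream()
    return b"".join(response.data for response in bytestream.Read(request))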
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index a2110d8a2..e4acbde55 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -18,21 +18,24 @@
 #        Jürg Billeter <juerg.billeter@codethink.co.uk>
 
 from concurrent import futures
-from contextlib import contextmanager
+from enum import Enum
+import contextlib
+import logging
 import os
 import signal
 import sys
-import tempfile
 import uuid
-import errno
 
 import grpc
 from google.protobuf.message import DecodeError
 import click
 
-from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
-from .._protos.google.rpc import code_pb2
+from .._protos.build.bazel.remote.execution.v2 import (
+    remote_execution_pb2,
+    remote_execution_pb2_grpc,
+)
+from .._protos.google.bytestream import bytestream_pb2_grpc
+from .._protos.build.buildgrid import local_cas_pb2
 from .._protos.buildstream.v2 import (
     buildstream_pb2,
     buildstream_pb2_grpc,
@@ -42,10 +45,15 @@ from .._protos.buildstream.v2 import (
     source_pb2_grpc,
 )
 
-from .. import utils
-from .._exceptions import CASError, CASCacheError
-
-from .cascache import CASCache
+# Note: We'd ideally like to avoid imports from the core codebase as
+# much as possible, since we're expecting to eventually split this
+# module off into its own project.
+#
+# Not enough that we'd like to duplicate code, but enough that we want
+# to make it very obvious what we're using, so in this case we import
+# the specific methods we'll be using.
+from ..utils import save_file_atomic, _remove_path_with_parents
+from .casdprocessmanager import CASDProcessManager
 
 
 # The default limit for gRPC messages is 4 MiB.
@@ -53,6 +61,37 @@ from .cascache import CASCache
 _MAX_PAYLOAD_BYTES = 1024 * 1024
 
 
+# LogLevel():
+#
+# Manage log level choices using click.
+#
+class LogLevel(click.Choice):
+    # Levels():
+    #
+    # Represents the actual buildbox-casd log level.
+    #
+    class Levels(Enum):
+        WARNING = "warning"
+        INFO = "info"
+        TRACE = "trace"
+
+    def __init__(self):
+        super().__init__([m.lower() for m in LogLevel.Levels._member_names_])  # pylint: disable=no-member
+
+    def convert(self, value, param, ctx) -> "LogLevel.Levels":
+        return LogLevel.Levels(super().convert(value, param, ctx))
+
+    @classmethod
+    def get_logging_equivalent(cls, level) -> int:
+        equivalents = {
+            cls.Levels.WARNING: logging.WARNING,
+            cls.Levels.INFO: logging.INFO,
+            cls.Levels.TRACE: logging.DEBUG,
+        }
+
+        return equivalents[level]
+
+
 # create_server():
 #
 # Create gRPC CAS artifact server as specified in the Remote Execution API.
@@ -62,13 +101,22 @@ _MAX_PAYLOAD_BYTES = 1024 * 1024
 #     enable_push (bool): Whether to allow blob uploads and artifact updates
 #     index_only (bool): Whether to store CAS blobs or only artifacts
 #
-@contextmanager
-def create_server(repo, *, enable_push, quota, index_only):
-    cas = CASCache(os.path.abspath(repo), cache_quota=quota, protect_session_blobs=False)
+@contextlib.contextmanager
+def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.Levels.WARNING):
+    logger = logging.getLogger("buildstream._cas.casserver")
+    logger.setLevel(LogLevel.get_logging_equivalent(log_level))
+    handler = logging.StreamHandler(sys.stderr)
+    handler.setFormatter(logging.Formatter(fmt="%(levelname)s: %(funcName)s: %(message)s"))
+    logger.addHandler(handler)
+
+    casd_manager = CASDProcessManager(
+        os.path.abspath(repo), os.path.join(os.path.abspath(repo), "logs"), log_level, quota, False
+    )
+    casd_channel = casd_manager.create_channel()
 
     try:
-        artifactdir = os.path.join(os.path.abspath(repo), "artifacts", "refs")
-        sourcedir = os.path.join(os.path.abspath(repo), "source_protos")
+        root = os.path.abspath(repo)
+        sourcedir = os.path.join(root, "source_protos")
 
         # Use max_workers default from Python 3.5+
         max_workers = (os.cpu_count() or 1) * 5
@@ -76,21 +124,21 @@ def create_server(repo, *, enable_push, quota, index_only):
 
         if not index_only:
             bytestream_pb2_grpc.add_ByteStreamServicer_to_server(
-                _ByteStreamServicer(cas, enable_push=enable_push), server
+                _ByteStreamServicer(casd_channel, enable_push=enable_push), server
             )
 
             remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
-                _ContentAddressableStorageServicer(cas, enable_push=enable_push), server
+                _ContentAddressableStorageServicer(casd_channel, enable_push=enable_push), server
             )
 
         remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(_CapabilitiesServicer(), server)
 
         buildstream_pb2_grpc.add_ReferenceStorageServicer_to_server(
-            _ReferenceStorageServicer(cas, enable_push=enable_push), server
+            _ReferenceStorageServicer(casd_channel, root, enable_push=enable_push), server
         )
 
         artifact_pb2_grpc.add_ArtifactServiceServicer_to_server(
-            _ArtifactServicer(cas, artifactdir, update_cas=not index_only), server
+            _ArtifactServicer(casd_channel, root, update_cas=not index_only), server
         )
 
         source_pb2_grpc.add_SourceServiceServicer_to_server(_SourceServicer(sourcedir), server)
@@ -105,7 +153,8 @@ def create_server(repo, *, enable_push, quota, index_only):
         yield server
 
     finally:
-        cas.release_resources()
+        casd_channel.close()
+        casd_manager.release_resources()
 
 
 @click.command(short_help="CAS Artifact Server")
@@ -120,14 +169,17 @@ def create_server(repo, *, enable_push, quota, index_only):
     is_flag=True,
     help='Only provide the BuildStream artifact and source services ("index"), not the CAS ("storage")',
 )
+@click.option("--log-level", type=LogLevel(), help="The log level to launch with", default="warning")
 @click.argument("repo")
-def server_main(repo, port, server_key, server_cert, client_certs, enable_push, quota, index_only):
+def server_main(repo, port, server_key, server_cert, client_certs, enable_push, quota, index_only, log_level):
     # Handle SIGTERM by calling sys.exit(0), which will raise a SystemExit exception,
     # properly executing cleanup code in `finally` clauses and context managers.
     # This is required to terminate buildbox-casd on SIGTERM.
     signal.signal(signal.SIGTERM, lambda signalnum, frame: sys.exit(0))
 
-    with create_server(repo, quota=quota, enable_push=enable_push, index_only=index_only) as server:
+    with create_server(
+        repo, quota=quota, enable_push=enable_push, index_only=index_only, log_level=log_level
+    ) as server:
 
         use_tls = bool(server_key)
@@ -171,216 +223,49 @@ def server_main(repo, port, server_key, server_cert, client_certs, enable_push,
 
 
 class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
-    def __init__(self, cas, *, enable_push):
+    def __init__(self, casd, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.bytestream = casd.get_bytestream()
         self.enable_push = enable_push
+        self.logger = logging.getLogger("buildstream._cas.casserver")
 
     def Read(self, request, context):
-        resource_name = request.resource_name
-        client_digest = _digest_from_download_resource_name(resource_name)
-        if client_digest is None:
-            context.set_code(grpc.StatusCode.NOT_FOUND)
-            return
-
-        if request.read_offset > client_digest.size_bytes:
-            context.set_code(grpc.StatusCode.OUT_OF_RANGE)
-            return
-
-        try:
-            with open(self.cas.objpath(client_digest), "rb") as f:
-                if os.fstat(f.fileno()).st_size != client_digest.size_bytes:
-                    context.set_code(grpc.StatusCode.NOT_FOUND)
-                    return
-
-                os.utime(f.fileno())
-
-                if request.read_offset > 0:
-                    f.seek(request.read_offset)
-
-                remaining = client_digest.size_bytes - request.read_offset
-                while remaining > 0:
-                    chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
-                    remaining -= chunk_size
-
-                    response = bytestream_pb2.ReadResponse()
-                    # max. 64 kB chunks
-                    response.data = f.read(chunk_size)
-                    yield response
-        except FileNotFoundError:
-            context.set_code(grpc.StatusCode.NOT_FOUND)
+        self.logger.debug("Reading %s", request.resource_name)
+        return self.bytestream.Read(request)
 
     def Write(self, request_iterator, context):
-        response = bytestream_pb2.WriteResponse()
-
-        if not self.enable_push:
-            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
-            return response
-
-        offset = 0
-        finished = False
-        resource_name = None
-        with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
-            for request in request_iterator:
-                if finished or request.write_offset != offset:
-                    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                    return response
-
-                if resource_name is None:
-                    # First request
-                    resource_name = request.resource_name
-                    client_digest = _digest_from_upload_resource_name(resource_name)
-                    if client_digest is None:
-                        context.set_code(grpc.StatusCode.NOT_FOUND)
-                        return response
-
-                    while True:
-                        if client_digest.size_bytes == 0:
-                            break
-
-                        try:
-                            os.posix_fallocate(out.fileno(), 0, client_digest.size_bytes)
-                            break
-                        except OSError as e:
-                            # Multiple upload can happen in the same time
-                            if e.errno != errno.ENOSPC:
-                                raise
-
-                elif request.resource_name:
-                    # If it is set on subsequent calls, it **must** match the value of the first request.
-                    if request.resource_name != resource_name:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-
-                if (offset + len(request.data)) > client_digest.size_bytes:
-                    context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                    return response
-
-                out.write(request.data)
-                offset += len(request.data)
-
-                if request.finish_write:
-                    if client_digest.size_bytes != offset:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-
-                    out.flush()
-
-                    try:
-                        digest = self.cas.add_object(path=out.name, link_directly=True)
-                    except CASCacheError as e:
-                        if e.reason == "cache-too-full":
-                            context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
-                        else:
-                            context.set_code(grpc.StatusCode.INTERNAL)
-                        return response
-
-                    if digest.hash != client_digest.hash:
-                        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
-                        return response
-
-                    finished = True
-
-        assert finished
-
-        response.committed_size = offset
-        return response
+        # Note that we can't easily give more information because the
+        # data is stuck in an iterator that will be consumed if read.
+        self.logger.debug("Writing data")
+        return self.bytestream.Write(request_iterator)
 
 
 class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
-    def __init__(self, cas, *, enable_push):
+    def __init__(self, casd, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.cas = casd.get_cas()
         self.enable_push = enable_push
+        self.logger = logging.getLogger("buildstream._cas.casserver")
 
     def FindMissingBlobs(self, request, context):
-        response = remote_execution_pb2.FindMissingBlobsResponse()
-        for digest in request.blob_digests:
-            objpath = self.cas.objpath(digest)
-            try:
-                os.utime(objpath)
-            except OSError as e:
-                if e.errno != errno.ENOENT:
-                    raise
-
-                d = response.missing_blob_digests.add()
-                d.hash = digest.hash
-                d.size_bytes = digest.size_bytes
-
-        return response
+        self.logger.info("Finding '%s'", request.blob_digests)
+        return self.cas.FindMissingBlobs(request)
 
     def BatchReadBlobs(self, request, context):
-        response = remote_execution_pb2.BatchReadBlobsResponse()
-        batch_size = 0
-
-        for digest in request.digests:
-            batch_size += digest.size_bytes
-            if batch_size > _MAX_PAYLOAD_BYTES:
-                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
-                return response
-
-            blob_response = response.responses.add()
-            blob_response.digest.hash = digest.hash
-            blob_response.digest.size_bytes = digest.size_bytes
-            try:
-                objpath = self.cas.objpath(digest)
-                with open(objpath, "rb") as f:
-                    if os.fstat(f.fileno()).st_size != digest.size_bytes:
-                        blob_response.status.code = code_pb2.NOT_FOUND
-                        continue
-
-                    os.utime(f.fileno())
-
-                    blob_response.data = f.read(digest.size_bytes)
-            except FileNotFoundError:
-                blob_response.status.code = code_pb2.NOT_FOUND
-
-        return response
+        self.logger.info("Reading '%s'", request.digests)
+        return self.cas.BatchReadBlobs(request)
 
     def BatchUpdateBlobs(self, request, context):
-        response = remote_execution_pb2.BatchUpdateBlobsResponse()
-
-        if not self.enable_push:
-            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
-            return response
-
-        batch_size = 0
-
-        for blob_request in request.requests:
-            digest = blob_request.digest
-
-            batch_size += digest.size_bytes
-            if batch_size > _MAX_PAYLOAD_BYTES:
-                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
-                return response
-
-            blob_response = response.responses.add()
-            blob_response.digest.hash = digest.hash
-            blob_response.digest.size_bytes = digest.size_bytes
-
-            if len(blob_request.data) != digest.size_bytes:
-                blob_response.status.code = code_pb2.FAILED_PRECONDITION
-                continue
-
-            with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
-                out.write(blob_request.data)
-                out.flush()
-
-                try:
-                    server_digest = self.cas.add_object(path=out.name)
-                except CASCacheError as e:
-                    if e.reason == "cache-too-full":
-                        blob_response.status.code = code_pb2.RESOURCE_EXHAUSTED
-                    else:
-                        blob_response.status.code = code_pb2.INTERNAL
-                    continue
-
-                if server_digest.hash != digest.hash:
-                    blob_response.status.code = code_pb2.FAILED_PRECONDITION
-
-        return response
+        self.logger.info("Updating: '%s'", [request.digest for request in request.requests])
+        return self.cas.BatchUpdateBlobs(request)
 
 
 class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
+    def __init__(self):
+        self.logger = logging.getLogger("buildstream._cas.casserver")
+
     def GetCapabilities(self, request, context):
+        self.logger.info("Retrieving capabilities")
         response = remote_execution_pb2.ServerCapabilities()
 
         cache_capabilities = response.cache_capabilities
@@ -397,31 +282,85 @@ class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
 
 
 class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
-    def __init__(self, cas, *, enable_push):
+    def __init__(self, casd, cas_root, *, enable_push):
         super().__init__()
-        self.cas = cas
+        self.cas = casd.get_cas()
+        self.root = cas_root
         self.enable_push = enable_push
+        self.logger = logging.getLogger("buildstream._cas.casserver")
+
+        self.tmpdir = os.path.join(self.root, "tmp")
+        self.casdir = os.path.join(self.root, "cas")
+        self.refdir = os.path.join(self.casdir, "refs", "heads")
+        os.makedirs(self.tmpdir, exist_ok=True)
+
+    # ref_path():
+    #
+    # Get the path to a digest's file.
+    #
+    # Args:
+    #     ref - The ref of the digest.
+    #
+    # Returns:
+    #     str - The path to the digest's file.
+    #
+    def ref_path(self, ref: str) -> str:
+        return os.path.join(self.refdir, ref)
+
+    # set_ref():
+    #
+    # Create or update a ref with a new digest.
+    #
+    # Args:
+    #     ref - The ref of the digest.
+    #     tree - The digest to write.
+    #
+    def set_ref(self, ref: str, tree):
+        ref_path = self.ref_path(ref)
+
+        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
+        with save_file_atomic(ref_path, "wb", tempdir=self.tmpdir) as f:
+            f.write(tree.SerializeToString())
+
+    # resolve_ref():
+    #
+    # Resolve a ref to a digest.
+    #
+    # Args:
+    #     ref (str): The name of the ref
+    #
+    # Returns:
+    #     (Digest): The digest stored in the ref
+    #
+    def resolve_ref(self, ref):
+        ref_path = self.ref_path(ref)
+
+        with open(ref_path, "rb") as f:
+            os.utime(ref_path)
+
+            digest = remote_execution_pb2.Digest()
+            digest.ParseFromString(f.read())
+            return digest
 
     def GetReference(self, request, context):
+        self.logger.debug("'%s'", request.key)
         response = buildstream_pb2.GetReferenceResponse()
 
         try:
-            tree = self.cas.resolve_ref(request.key, update_mtime=True)
-            try:
-                self.cas.update_tree_mtime(tree)
-            except FileNotFoundError:
-                self.cas.remove(request.key)
-                context.set_code(grpc.StatusCode.NOT_FOUND)
-                return response
-
-            response.digest.hash = tree.hash
-            response.digest.size_bytes = tree.size_bytes
-        except CASError:
+            digest = self.resolve_ref(request.key)
+        except FileNotFoundError:
+            with contextlib.suppress(FileNotFoundError):
+                _remove_path_with_parents(self.refdir, request.key)
             context.set_code(grpc.StatusCode.NOT_FOUND)
+            return response
+
+        response.digest.hash = digest.hash
+        response.digest.size_bytes = digest.size_bytes
 
         return response
 
     def UpdateReference(self, request, context):
+        self.logger.debug("%s -> %s", request.keys, request.digest)
         response = buildstream_pb2.UpdateReferenceResponse()
 
         if not self.enable_push:
@@ -429,11 +368,12 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
             return response
 
         for key in request.keys:
-            self.cas.set_ref(key, request.digest)
+            self.set_ref(key, request.digest)
 
         return response
 
     def Status(self, request, context):
+        self.logger.debug("Retrieving status")
         response = buildstream_pb2.StatusResponse()
 
         response.allow_updates = self.enable_push
@@ -442,14 +382,48 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
 
 
 class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
-    def __init__(self, cas, artifactdir, *, update_cas=True):
+    def __init__(self, casd, root, *, update_cas=True):
         super().__init__()
-        self.cas = cas
-        self.artifactdir = artifactdir
+        self.cas = casd.get_cas()
+        self.local_cas = casd.get_local_cas()
+        self.root = root
+        self.artifactdir = os.path.join(root, "artifacts", "refs")
         self.update_cas = update_cas
-        os.makedirs(artifactdir, exist_ok=True)
+        self.logger = logging.getLogger("buildstream._cas.casserver")
+
+    # object_path():
+    #
+    # Get the path to an object's file.
+    #
+    # Args:
+    #     digest - The digest of the object.
+    #
+    # Returns:
+    #     str - The path to the object's file.
+    #
+    def object_path(self, digest) -> str:
+        return os.path.join(self.root, "cas", "objects", digest.hash[:2], digest.hash[2:])
+
+    # resolve_digest():
+    #
+    # Read the directory corresponding to a digest.
+    #
+    # Args:
+    #     digest - The digest corresponding to a directory.
+    #
+    # Returns:
+    #     remote_execution_pb2.Directory - The directory.
+    #
+    # Raises:
+    #     FileNotFoundError - If the digest object doesn't exist.
+    #
+    def resolve_digest(self, digest):
+        directory = remote_execution_pb2.Directory()
+        with open(self.object_path(digest), "rb") as f:
+            directory.ParseFromString(f.read())
+        return directory
 
     def GetArtifact(self, request, context):
+        self.logger.info("'%s'", request.cache_key)
         artifact_path = os.path.join(self.artifactdir, request.cache_key)
         if not os.path.exists(artifact_path):
             context.abort(grpc.StatusCode.NOT_FOUND, "Artifact proto not found")
@@ -458,6 +432,8 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
         with open(artifact_path, "rb") as f:
             artifact.ParseFromString(f.read())
 
+        os.utime(artifact_path)
+
         # Artifact-only servers will not have blobs on their system,
         # so we can't reasonably update their mtimes. Instead, we exit
         # early, and let the CAS server deal with its blobs.
@@ -476,30 +452,45 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
 
         try:
             if str(artifact.files):
-                self.cas.update_tree_mtime(artifact.files)
+                request = local_cas_pb2.FetchTreeRequest()
+                request.root_digest.CopyFrom(artifact.files)
+                request.fetch_file_blobs = True
+                self.local_cas.FetchTree(request)
 
             if str(artifact.buildtree):
-                # buildtrees might not be there
                 try:
-                    self.cas.update_tree_mtime(artifact.buildtree)
-                except FileNotFoundError:
-                    pass
+                    request = local_cas_pb2.FetchTreeRequest()
+                    request.root_digest.CopyFrom(artifact.buildtree)
+                    request.fetch_file_blobs = True
+                    self.local_cas.FetchTree(request)
+                except grpc.RpcError as err:
+                    # buildtrees might not be there
+                    if err.code() != grpc.StatusCode.NOT_FOUND:
+                        raise
 
             if str(artifact.public_data):
-                os.utime(self.cas.objpath(artifact.public_data))
+                request = remote_execution_pb2.FindMissingBlobsRequest()
+                d = request.blob_digests.add()
+                d.CopyFrom(artifact.public_data)
+                self.cas.FindMissingBlobs(request)
 
+            request = remote_execution_pb2.FindMissingBlobsRequest()
             for log_file in artifact.logs:
-                os.utime(self.cas.objpath(log_file.digest))
-
-        except FileNotFoundError:
-            os.unlink(artifact_path)
-            context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files incomplete")
-        except DecodeError:
-            context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files not valid")
+                d = request.blob_digests.add()
+                d.CopyFrom(log_file.digest)
+            self.cas.FindMissingBlobs(request)
+
+        except grpc.RpcError as err:
+            if err.code() == grpc.StatusCode.NOT_FOUND:
+                os.unlink(artifact_path)
+                context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files incomplete")
+            else:
+                context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files not valid")
 
         return artifact
 
     def UpdateArtifact(self, request, context):
+        self.logger.info("'%s' -> '%s'", request.artifact, request.cache_key)
         artifact = request.artifact
 
         if self.update_cas:
@@ -518,28 +509,29 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
         # Add the artifact proto to the cas
         artifact_path = os.path.join(self.artifactdir, request.cache_key)
         os.makedirs(os.path.dirname(artifact_path), exist_ok=True)
-        with utils.save_file_atomic(artifact_path, mode="wb") as f:
+        with save_file_atomic(artifact_path, mode="wb") as f:
             f.write(artifact.SerializeToString())
 
         return artifact
 
     def ArtifactStatus(self, request, context):
+        self.logger.info("Retrieving status")
         return artifact_pb2.ArtifactStatusResponse()
 
     def _check_directory(self, name, digest, context):
         try:
-            directory = remote_execution_pb2.Directory()
-            with open(self.cas.objpath(digest), "rb") as f:
-                directory.ParseFromString(f.read())
+            self.resolve_digest(digest)
         except FileNotFoundError:
+            self.logger.warning("Artifact %s specified but no files found", name)
             context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but no files found".format(name))
         except DecodeError:
+            self.logger.warning("Artifact %s specified but directory not found", name)
             context.abort(
                 grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but directory not found".format(name)
             )
 
     def _check_file(self, name, digest, context):
-        if not os.path.exists(self.cas.objpath(digest)):
+        if not os.path.exists(self.object_path(digest)):
             context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but not found".format(name))
@@ -558,8 +550,10 @@ class _BuildStreamCapabilitiesServicer(buildstream_pb2_grpc.CapabilitiesServicer
 class _SourceServicer(source_pb2_grpc.SourceServiceServicer):
     def __init__(self, sourcedir):
         self.sourcedir = sourcedir
+        self.logger = logging.getLogger("buildstream._cas.casserver")
 
     def GetSource(self, request, context):
+        self.logger.info("'%s'", request.cache_key)
         try:
             source_proto = self._get_source(request.cache_key)
         except FileNotFoundError:
@@ -570,6 +564,7 @@ class _SourceServicer(source_pb2_grpc.SourceServiceServicer):
         return source_proto
 
     def UpdateSource(self, request, context):
+        self.logger.info("'%s' -> '%s'", request.source, request.cache_key)
         self._set_source(request.cache_key, request.source)
         return request.source
@@ -584,7 +579,7 @@ class _SourceServicer(source_pb2_grpc.SourceServiceServicer):
     def _set_source(self, cache_key, source_proto):
         path = os.path.join(self.sourcedir, cache_key)
         os.makedirs(os.path.dirname(path), exist_ok=True)
-        with utils.save_file_atomic(path, "w+b") as f:
+        with save_file_atomic(path, "w+b") as f:
             f.write(source_proto.SerializeToString())
diff --git a/src/buildstream/_exceptions.py b/src/buildstream/_exceptions.py
index ca17577f7..51f542783 100644
--- a/src/buildstream/_exceptions.py
+++ b/src/buildstream/_exceptions.py
@@ -273,6 +273,15 @@ class SandboxError(BstError):
         super().__init__(message, detail=detail, domain=ErrorDomain.SANDBOX, reason=reason)
 
 
+# CacheError
+#
+# Raised when errors are encountered in either type of cache
+#
+class CacheError(BstError):
+    def __init__(self, message, detail=None, reason=None):
+        super().__init__(message, detail=detail, domain=ErrorDomain.SANDBOX, reason=reason)
+
+
 # SourceCacheError
 #
 # Raised when errors are encountered in the source caches
diff --git a/src/buildstream/_sourcecache.py b/src/buildstream/_sourcecache.py
index 03e2d1830..221694e94 100644
--- a/src/buildstream/_sourcecache.py
+++ b/src/buildstream/_sourcecache.py
@@ -129,8 +129,8 @@ class SourceCache(BaseCache):
     def __init__(self, context):
         super().__init__(context)
 
-        self.sourcerefdir = os.path.join(context.cachedir, "source_protos")
-        os.makedirs(self.sourcerefdir, exist_ok=True)
+        self._basedir = os.path.join(context.cachedir, "source_protos")
+        os.makedirs(self._basedir, exist_ok=True)
 
     # list_sources()
     #
@@ -140,7 +140,7 @@ class SourceCache(BaseCache):
     #    ([str]): iterable over all source refs
     #
     def list_sources(self):
-        return [ref for _, ref in self._list_refs_mtimes(self.sourcerefdir)]
+        return [ref for _, ref in self._list_refs_mtimes(self._basedir)]
 
     # contains()
     #
@@ -326,7 +326,7 @@ class SourceCache(BaseCache):
         return pushed_index and pushed_storage
 
     def _remove_source(self, ref, *, defer_prune=False):
-        return self.cas.remove(ref, basedir=self.sourcerefdir, defer_prune=defer_prune)
+        return self.cas.remove(ref, basedir=self._basedir, defer_prune=defer_prune)
 
     def _store_source(self, ref, digest):
         source_proto = source_pb2.Source()
@@ -351,10 +351,10 @@ class SourceCache(BaseCache):
             raise SourceCacheError("Attempted to access unavailable source: {}".format(e)) from e
 
     def _source_path(self, ref):
-        return os.path.join(self.sourcerefdir, ref)
+        return os.path.join(self._basedir, ref)
 
     def _reachable_directories(self):
-        for root, _, files in os.walk(self.sourcerefdir):
+        for root, _, files in os.walk(self._basedir):
             for source_file in files:
                 source = source_pb2.Source()
                 with open(os.path.join(root, source_file), "r+b") as f:
diff --git a/src/buildstream/utils.py b/src/buildstream/utils.py
index 0307b8470..5009be338 100644
--- a/src/buildstream/utils.py
+++ b/src/buildstream/utils.py
@@ -779,6 +779,44 @@ def _is_main_process():
     return os.getpid() == _MAIN_PID
 
 
+# Remove a path and any empty directories leading up to it.
+#
+# Args:
+#     basedir - The basedir at which to stop pruning even if
+#               it is empty.
+#     path - A path relative to basedir that should be pruned.
+#
+# Raises:
+#     FileNotFoundError - if the path itself doesn't exist.
+#     OSError - if something else goes wrong
+#
+def _remove_path_with_parents(basedir: Union[Path, str], path: Union[Path, str]):
+    assert not os.path.isabs(path), "The path ({}) should be relative to basedir ({})".format(path, basedir)
+    path = os.path.join(basedir, path)
+
+    # Start by removing the path itself
+    os.unlink(path)
+
+    # Now walk up the directory tree and delete any empty directories
+    path = os.path.dirname(path)
+    while path != basedir:
+        try:
+            os.rmdir(path)
+        except FileNotFoundError:
+            # The parent directory did not exist (race conditions can
+            # cause this), but it's parent directory might still be
+            # ready to prune
+            pass
+        except OSError as e:
+            if e.errno == errno.ENOTEMPTY:
+                # The parent directory was not empty, so we
+                # cannot prune directories beyond this point
+                break
+            raise
+
+        path = os.path.dirname(path)
+
+
 # Recursively remove directories, ignoring file permissions as much as
 # possible.
 def _force_rmtree(rootpath, **kwargs):
diff --git a/tests/sourcecache/fetch.py b/tests/sourcecache/fetch.py
index 0c347ebbf..4096b56b8 100644
--- a/tests/sourcecache/fetch.py
+++ b/tests/sourcecache/fetch.py
@@ -37,6 +37,7 @@ DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "project")
 
 def move_local_cas_to_remote_source_share(local, remote):
     shutil.rmtree(os.path.join(remote, "repo", "cas"))
+    shutil.rmtree(os.path.join(remote, "repo", "source_protos"))
     shutil.move(os.path.join(local, "source_protos"), os.path.join(remote, "repo"))
     shutil.move(os.path.join(local, "cas"), os.path.join(remote, "repo"))
     shutil.rmtree(os.path.join(local, "sources"))
@@ -85,8 +86,7 @@ def test_source_fetch(cli, tmpdir, datafiles):
             assert not element._source_cached()
             source = list(element.sources())[0]
 
-            cas = context.get_cascache()
-            assert not cas.contains(source._get_source_name())
+            assert not share.get_source_proto(source._get_source_name())
 
             # Just check that we sensibly fetch and build the element
             res = cli.run(project=project_dir, args=["build", element_name])
@@ -132,8 +132,7 @@ def test_fetch_fallback(cli, tmpdir, datafiles):
             assert not element._source_cached()
             source = list(element.sources())[0]
 
-            cas = context.get_cascache()
-            assert not cas.contains(source._get_source_name())
+            assert not share.get_source_proto(source._get_source_name())
             assert not os.path.exists(os.path.join(cache_dir, "sources"))
 
             # Now check if it falls back to the source fetch method.
@@ -195,8 +194,7 @@ def test_source_pull_partial_fallback_fetch(cli, tmpdir, datafiles):
             assert not element._source_cached()
             source = list(element.sources())[0]
 
-            cas = context.get_cascache()
-            assert not cas.contains(source._get_source_name())
+            assert not share.get_artifact_proto(source._get_source_name())
 
             # Just check that we sensibly fetch and build the element
             res = cli.run(project=project_dir, args=["build", element_name])
diff --git a/tests/sourcecache/push.py b/tests/sourcecache/push.py
index 719860425..0b7bb9954 100644
--- a/tests/sourcecache/push.py
+++ b/tests/sourcecache/push.py
@@ -89,8 +89,7 @@ def test_source_push_split(cli, tmpdir, datafiles):
             source = list(element.sources())[0]
 
             # check we don't have it in the current cache
-            cas = context.get_cascache()
-            assert not cas.contains(source._get_source_name())
+            assert not index.get_source_proto(source._get_source_name())
 
             # build the element, this should fetch and then push the source to the
             # remote
@@ -139,8 +138,7 @@ def test_source_push(cli, tmpdir, datafiles):
             source = list(element.sources())[0]
 
             # check we don't have it in the current cache
-            cas = context.get_cascache()
-            assert not cas.contains(source._get_source_name())
+            assert not share.get_source_proto(source._get_source_name())
 
             # build the element, this should fetch and then push the source to the
             # remote
diff --git a/tests/testutils/artifactshare.py b/tests/testutils/artifactshare.py
index 8d0448f8a..19c19131a 100644
--- a/tests/testutils/artifactshare.py
+++ b/tests/testutils/artifactshare.py
@@ -13,7 +13,7 @@ from buildstream._cas import CASCache
 from buildstream._cas.casserver import create_server
 from buildstream._exceptions import CASError
 from buildstream._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
-from buildstream._protos.buildstream.v2 import artifact_pb2
+from buildstream._protos.buildstream.v2 import artifact_pb2, source_pb2
 
 
 class BaseArtifactShare:
@@ -120,6 +120,8 @@ class ArtifactShare(BaseArtifactShare):
         os.makedirs(self.repodir)
         self.artifactdir = os.path.join(self.repodir, "artifacts", "refs")
         os.makedirs(self.artifactdir)
+        self.sourcedir = os.path.join(self.repodir, "source_protos", "refs")
+        os.makedirs(self.sourcedir)
 
         self.cas = CASCache(self.repodir, casd=casd)
@@ -160,6 +162,18 @@ class ArtifactShare(BaseArtifactShare):
 
         return artifact_proto
 
+    def get_source_proto(self, source_name):
+        source_proto = source_pb2.Source()
+        source_path = os.path.join(self.sourcedir, source_name)
+
+        try:
+            with open(source_path, "rb") as f:
+                source_proto.ParseFromString(f.read())
+        except FileNotFoundError:
+            return None
+
+        return source_proto
+
     def get_cas_files(self, artifact_proto):
         reachable = set()
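The ref-pruning behaviour that several of the hunks above rely on is concentrated in the new utils._remove_path_with_parents() helper: remove a ref file, then walk upward removing parents that became empty, stopping at the base directory. A small usage sketch of its semantics follows; the directory layout is a hypothetical example, not taken from this commit.

import os
import tempfile

from buildstream.utils import _remove_path_with_parents

# Hypothetical layout: basedir/a/b/ref
basedir = tempfile.mkdtemp()
os.makedirs(os.path.join(basedir, "a", "b"))
open(os.path.join(basedir, "a", "b", "ref"), "w").close()

_remove_path_with_parents(basedir, os.path.join("a", "b", "ref"))

# The ref and its now-empty parents are pruned, but basedir itself survives
assert not os.path.exists(os.path.join(basedir, "a"))
assert os.path.isdir(basedir)

Centralising this in utils.py lets both ArtifactCache and the standalone casserver share one pruning implementation instead of each carrying a copy of the walk-up loop.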