author    Tristan Maat <tm@tlater.net>  2019-12-03 11:38:54 +0000
committer Tristan Maat <tm@tlater.net>  2019-12-03 11:38:54 +0000
commit    476a53975286364088576df270a3be13fe1bec4a (patch)
tree      6534dd727c43f20a2381ae0df9d430716e4b93c7
parent    c35a8eb80e4a069f3ba6fe1f3719c7c74ebf7fd1 (diff)
parent    845a2fdb4ff786c9e9c7441ba321387124a25354 (diff)
download  buildstream-476a53975286364088576df270a3be13fe1bec4a.tar.gz
Merge branch 'tlater/artifactserver-casd' into 'master'
Refactor casserver.py: Stop relying on the buildstream-internal `CASCache` implementation

Closes #1167

See merge request BuildStream/buildstream!1645
-rw-r--r--  src/buildstream/_artifactcache.py            26
-rw-r--r--  src/buildstream/_basecache.py                 26
-rw-r--r--  src/buildstream/_cas/cascache.py             165
-rw-r--r--  src/buildstream/_cas/casdprocessmanager.py     8
-rw-r--r--  src/buildstream/_cas/casserver.py            491
-rw-r--r--  src/buildstream/_exceptions.py                 9
-rw-r--r--  src/buildstream/_sourcecache.py               12
-rw-r--r--  src/buildstream/utils.py                      38
-rw-r--r--  tests/sourcecache/fetch.py                    10
-rw-r--r--  tests/sourcecache/push.py                      6
-rw-r--r--  tests/testutils/artifactshare.py              16
11 files changed, 364 insertions, 443 deletions
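The core of the change is that casserver.py no longer drives the BuildStream-internal CASCache; it spawns buildbox-casd through CASDProcessManager and each servicer forwards requests over gRPC stubs obtained from the resulting channel. The following is a minimal, illustrative sketch of that wiring, assuming buildbox-casd is installed; the constructor arguments mirror the call made in the patched create_server() below, and the repository path and quota are placeholders, not values from the patch.

# Illustrative sketch only: how the refactored server obtains its gRPC stubs.
import os

from buildstream._cas.casdprocessmanager import CASDProcessManager
from buildstream._cas.casserver import LogLevel

repo = os.path.abspath("/srv/artifacts/repo")  # hypothetical repository path

manager = CASDProcessManager(
    repo,                              # buildbox-casd working directory
    os.path.join(repo, "logs"),        # buildbox-casd log directory
    LogLevel.Levels.WARNING,           # log level forwarded to buildbox-casd
    10 * 1024 ** 3,                    # cache quota in bytes (assumed value)
    False,                             # presumably protect_session_blobs, as in the old CASCache call
)
channel = manager.create_channel()

# Servicers now hold stubs rather than a CASCache instance:
bytestream = channel.get_bytestream()  # google.bytestream ByteStream stub
cas = channel.get_cas()                # remote-execution CAS stub
local_cas = channel.get_local_cas()    # buildgrid LocalCAS stub

channel.close()
manager.release_resources()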
diff --git a/src/buildstream/_artifactcache.py b/src/buildstream/_artifactcache.py
index 10ccf1527..02dd21d41 100644
--- a/src/buildstream/_artifactcache.py
+++ b/src/buildstream/_artifactcache.py
@@ -22,7 +22,7 @@ import os
import grpc
from ._basecache import BaseCache
-from ._exceptions import ArtifactError, CASError, CASCacheError, CASRemoteError, RemoteError
+from ._exceptions import ArtifactError, CASError, CacheError, CASRemoteError, RemoteError
from ._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc, artifact_pb2, artifact_pb2_grpc
from ._remote import BaseRemote
@@ -146,12 +146,12 @@ class ArtifactCache(BaseCache):
super().__init__(context)
# create artifact directory
- self.artifactdir = context.artifactdir
- os.makedirs(self.artifactdir, exist_ok=True)
+ self._basedir = context.artifactdir
+ os.makedirs(self._basedir, exist_ok=True)
def update_mtime(self, ref):
try:
- os.utime(os.path.join(self.artifactdir, ref))
+ os.utime(os.path.join(self._basedir, ref))
except FileNotFoundError as e:
raise ArtifactError("Couldn't find artifact: {}".format(ref)) from e
@@ -176,7 +176,7 @@ class ArtifactCache(BaseCache):
def contains(self, element, key):
ref = element.get_artifact_name(key)
- return os.path.exists(os.path.join(self.artifactdir, ref))
+ return os.path.exists(os.path.join(self._basedir, ref))
# list_artifacts():
#
@@ -189,7 +189,7 @@ class ArtifactCache(BaseCache):
# ([str]) - A list of artifact names as generated in LRU order
#
def list_artifacts(self, *, glob=None):
- return [ref for _, ref in sorted(list(self._list_refs_mtimes(self.artifactdir, glob_expr=glob)))]
+ return [ref for _, ref in sorted(list(self._list_refs_mtimes(self._basedir, glob_expr=glob)))]
# remove():
#
@@ -202,8 +202,8 @@ class ArtifactCache(BaseCache):
#
def remove(self, ref):
try:
- self.cas.remove(ref, basedir=self.artifactdir)
- except CASCacheError as e:
+ self._remove_ref(ref)
+ except CacheError as e:
raise ArtifactError("{}".format(e)) from e
# diff():
@@ -410,8 +410,8 @@ class ArtifactCache(BaseCache):
oldref = element.get_artifact_name(oldkey)
newref = element.get_artifact_name(newkey)
- if not os.path.exists(os.path.join(self.artifactdir, newref)):
- os.link(os.path.join(self.artifactdir, oldref), os.path.join(self.artifactdir, newref))
+ if not os.path.exists(os.path.join(self._basedir, newref)):
+ os.link(os.path.join(self._basedir, oldref), os.path.join(self._basedir, newref))
# get_artifact_logs():
#
@@ -514,7 +514,7 @@ class ArtifactCache(BaseCache):
# (iter): Iterator over directories digests available from artifacts.
#
def _reachable_directories(self):
- for root, _, files in os.walk(self.artifactdir):
+ for root, _, files in os.walk(self._basedir):
for artifact_file in files:
artifact = artifact_pb2.Artifact()
with open(os.path.join(root, artifact_file), "r+b") as f:
@@ -532,7 +532,7 @@ class ArtifactCache(BaseCache):
# (iter): Iterator over single file digests in artifacts
#
def _reachable_digests(self):
- for root, _, files in os.walk(self.artifactdir):
+ for root, _, files in os.walk(self._basedir):
for artifact_file in files:
artifact = artifact_pb2.Artifact()
with open(os.path.join(root, artifact_file), "r+b") as f:
@@ -707,7 +707,7 @@ class ArtifactCache(BaseCache):
return None
# Write the artifact proto to cache
- artifact_path = os.path.join(self.artifactdir, artifact_name)
+ artifact_path = os.path.join(self._basedir, artifact_name)
os.makedirs(os.path.dirname(artifact_path), exist_ok=True)
with utils.save_file_atomic(artifact_path, mode="wb") as f:
f.write(artifact.SerializeToString())
diff --git a/src/buildstream/_basecache.py b/src/buildstream/_basecache.py
index 15b1d5389..dff7742e7 100644
--- a/src/buildstream/_basecache.py
+++ b/src/buildstream/_basecache.py
@@ -25,7 +25,7 @@ from . import utils
from . import _yaml
from ._cas import CASRemote
from ._message import Message, MessageType
-from ._exceptions import LoadError, RemoteError
+from ._exceptions import LoadError, RemoteError, CacheError
from ._remote import RemoteSpec, RemoteType
@@ -62,6 +62,8 @@ class BaseCache:
self._has_fetch_remotes = False
self._has_push_remotes = False
+ self._basedir = None
+
# has_open_grpc_channels():
#
# Return whether there are gRPC channel instances. This is used to safeguard
@@ -429,3 +431,25 @@ class BaseCache:
if not glob_expr or fnmatch(relative_path, glob_expr):
# Obtain the mtime (the time a file was last modified)
yield (os.path.getmtime(ref_path), relative_path)
+
+ # _remove_ref()
+ #
+ # Removes a ref.
+ #
+ # This also takes care of pruning away directories which can
+ # be removed after having removed the given ref.
+ #
+ # Args:
+ # ref (str): The ref to remove
+ #
+ # Raises:
+ # (CASCacheError): If the ref didnt exist, or a system error
+ # occurred while removing it
+ #
+ def _remove_ref(self, ref):
+ try:
+ utils._remove_path_with_parents(self._basedir, ref)
+ except FileNotFoundError as e:
+ raise CacheError("Could not find ref '{}'".format(ref)) from e
+ except OSError as e:
+ raise CacheError("System error while removing ref '{}': {}".format(ref, e)) from e
diff --git a/src/buildstream/_cas/cascache.py b/src/buildstream/_cas/cascache.py
index 98581d351..c45a199fe 100644
--- a/src/buildstream/_cas/cascache.py
+++ b/src/buildstream/_cas/cascache.py
@@ -21,7 +21,6 @@
import itertools
import os
import stat
-import errno
import contextlib
import ctypes
import multiprocessing
@@ -69,7 +68,6 @@ class CASCache:
):
self.casdir = os.path.join(path, "cas")
self.tmpdir = os.path.join(path, "tmp")
- os.makedirs(os.path.join(self.casdir, "refs", "heads"), exist_ok=True)
os.makedirs(os.path.join(self.casdir, "objects"), exist_ok=True)
os.makedirs(self.tmpdir, exist_ok=True)
@@ -134,9 +132,7 @@ class CASCache:
# Preflight check.
#
def preflight(self):
- headdir = os.path.join(self.casdir, "refs", "heads")
- objdir = os.path.join(self.casdir, "objects")
- if not (os.path.isdir(headdir) and os.path.isdir(objdir)):
+ if not os.path.join(self.casdir, "objects"):
raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir))
# close_grpc_channels():
@@ -160,21 +156,6 @@ class CASCache:
self._casd_process_manager.release_resources(messenger)
self._casd_process_manager = None
- # contains():
- #
- # Check whether the specified ref is already available in the local CAS cache.
- #
- # Args:
- # ref (str): The ref to check
- #
- # Returns: True if the ref is in the cache, False otherwise
- #
- def contains(self, ref):
- refpath = self._refpath(ref)
-
- # This assumes that the repository doesn't have any dangling pointers
- return os.path.exists(refpath)
-
# contains_file():
#
# Check whether a digest corresponds to a file which exists in CAS
@@ -261,28 +242,6 @@ class CASCache:
fullpath = os.path.join(dest, symlinknode.name)
os.symlink(symlinknode.target, fullpath)
- # diff():
- #
- # Return a list of files that have been added or modified between
- # the refs described by ref_a and ref_b.
- #
- # Args:
- # ref_a (str): The first ref
- # ref_b (str): The second ref
- # subdir (str): A subdirectory to limit the comparison to
- #
- def diff(self, ref_a, ref_b):
- tree_a = self.resolve_ref(ref_a)
- tree_b = self.resolve_ref(ref_b)
-
- added = []
- removed = []
- modified = []
-
- self.diff_trees(tree_a, tree_b, added=added, removed=removed, modified=modified)
-
- return modified, removed, added
-
# pull_tree():
#
# Pull a single Tree rather than a ref.
@@ -409,74 +368,6 @@ class CASCache:
return utils._message_digest(root_directory)
- # set_ref():
- #
- # Create or replace a ref.
- #
- # Args:
- # ref (str): The name of the ref
- #
- def set_ref(self, ref, tree):
- refpath = self._refpath(ref)
- os.makedirs(os.path.dirname(refpath), exist_ok=True)
- with utils.save_file_atomic(refpath, "wb", tempdir=self.tmpdir) as f:
- f.write(tree.SerializeToString())
-
- # resolve_ref():
- #
- # Resolve a ref to a digest.
- #
- # Args:
- # ref (str): The name of the ref
- # update_mtime (bool): Whether to update the mtime of the ref
- #
- # Returns:
- # (Digest): The digest stored in the ref
- #
- def resolve_ref(self, ref, *, update_mtime=False):
- refpath = self._refpath(ref)
-
- try:
- with open(refpath, "rb") as f:
- if update_mtime:
- os.utime(refpath)
-
- digest = remote_execution_pb2.Digest()
- digest.ParseFromString(f.read())
- return digest
-
- except FileNotFoundError as e:
- raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
-
- # update_mtime()
- #
- # Update the mtime of a ref.
- #
- # Args:
- # ref (str): The ref to update
- #
- def update_mtime(self, ref):
- try:
- os.utime(self._refpath(ref))
- except FileNotFoundError as e:
- raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
-
- # remove():
- #
- # Removes the given symbolic ref from the repo.
- #
- # Args:
- # ref (str): A symbolic ref
- # basedir (str): Path of base directory the ref is in, defaults to
- # CAS refs heads
- #
- def remove(self, ref, *, basedir=None):
-
- if basedir is None:
- basedir = os.path.join(self.casdir, "refs", "heads")
- # Remove cache ref
- self._remove_ref(ref, basedir)
-
def update_tree_mtime(self, tree):
reachable = set()
self._reachable_refs_dir(reachable, tree, update_mtime=True)
@@ -645,60 +536,6 @@ class CASCache:
# Local Private Methods #
################################################
- def _refpath(self, ref):
- return os.path.join(self.casdir, "refs", "heads", ref)
-
- # _remove_ref()
- #
- # Removes a ref.
- #
- # This also takes care of pruning away directories which can
- # be removed after having removed the given ref.
- #
- # Args:
- # ref (str): The ref to remove
- # basedir (str): Path of base directory the ref is in
- #
- # Raises:
- # (CASCacheError): If the ref didnt exist, or a system error
- # occurred while removing it
- #
- def _remove_ref(self, ref, basedir):
-
- # Remove the ref itself
- refpath = os.path.join(basedir, ref)
-
- try:
- os.unlink(refpath)
- except FileNotFoundError as e:
- raise CASCacheError("Could not find ref '{}'".format(ref)) from e
-
- # Now remove any leading directories
-
- components = list(os.path.split(ref))
- while components:
- components.pop()
- refdir = os.path.join(basedir, *components)
-
- # Break out once we reach the base
- if refdir == basedir:
- break
-
- try:
- os.rmdir(refdir)
- except FileNotFoundError:
- # The parent directory did not exist, but it's
- # parent directory might still be ready to prune
- pass
- except OSError as e:
- if e.errno == errno.ENOTEMPTY:
- # The parent directory was not empty, so we
- # cannot prune directories beyond this point
- break
-
- # Something went wrong here
- raise CASCacheError("System error while removing ref '{}': {}".format(ref, e)) from e
-
def _get_subdir(self, tree, subdir):
head, name = os.path.split(subdir)
if head:
diff --git a/src/buildstream/_cas/casdprocessmanager.py b/src/buildstream/_cas/casdprocessmanager.py
index e4a58d7d5..68bb88ef0 100644
--- a/src/buildstream/_cas/casdprocessmanager.py
+++ b/src/buildstream/_cas/casdprocessmanager.py
@@ -28,6 +28,7 @@ import grpc
from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2_grpc
from .._protos.build.buildgrid import local_cas_pb2_grpc
+from .._protos.google.bytestream import bytestream_pb2_grpc
from .. import _signals, utils
from .._exceptions import CASCacheError
@@ -177,6 +178,7 @@ class CASDChannel:
self._connection_string = connection_string
self._start_time = start_time
self._casd_channel = None
+ self._bytestream = None
self._casd_cas = None
self._local_cas = None
@@ -192,6 +194,7 @@ class CASDChannel:
time.sleep(0.01)
self._casd_channel = grpc.insecure_channel(self._connection_string)
+ self._bytestream = bytestream_pb2_grpc.ByteStreamStub(self._casd_channel)
self._casd_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._casd_channel)
self._local_cas = local_cas_pb2_grpc.LocalContentAddressableStorageStub(self._casd_channel)
@@ -213,6 +216,11 @@ class CASDChannel:
self._establish_connection()
return self._local_cas
+ def get_bytestream(self):
+ if self._casd_channel is None:
+ self._establish_connection()
+ return self._bytestream
+
# is_closed():
#
# Return whether this connection is closed or not.
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index a2110d8a2..e4acbde55 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -18,21 +18,24 @@
# Jürg Billeter <juerg.billeter@codethink.co.uk>
from concurrent import futures
-from contextlib import contextmanager
+from enum import Enum
+import contextlib
+import logging
import os
import signal
import sys
-import tempfile
import uuid
-import errno
import grpc
from google.protobuf.message import DecodeError
import click
-from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
-from .._protos.google.rpc import code_pb2
+from .._protos.build.bazel.remote.execution.v2 import (
+ remote_execution_pb2,
+ remote_execution_pb2_grpc,
+)
+from .._protos.google.bytestream import bytestream_pb2_grpc
+from .._protos.build.buildgrid import local_cas_pb2
from .._protos.buildstream.v2 import (
buildstream_pb2,
buildstream_pb2_grpc,
@@ -42,10 +45,15 @@ from .._protos.buildstream.v2 import (
source_pb2_grpc,
)
-from .. import utils
-from .._exceptions import CASError, CASCacheError
-
-from .cascache import CASCache
+# Note: We'd ideally like to avoid imports from the core codebase as
+# much as possible, since we're expecting to eventually split this
+# module off into its own project.
+#
+# Not enough that we'd like to duplicate code, but enough that we want
+# to make it very obvious what we're using, so in this case we import
+# the specific methods we'll be using.
+from ..utils import save_file_atomic, _remove_path_with_parents
+from .casdprocessmanager import CASDProcessManager
# The default limit for gRPC messages is 4 MiB.
@@ -53,6 +61,37 @@ from .cascache import CASCache
_MAX_PAYLOAD_BYTES = 1024 * 1024
+# LogLevel():
+#
+# Manage log level choices using click.
+#
+class LogLevel(click.Choice):
+ # Levels():
+ #
+ # Represents the actual buildbox-casd log level.
+ #
+ class Levels(Enum):
+ WARNING = "warning"
+ INFO = "info"
+ TRACE = "trace"
+
+ def __init__(self):
+ super().__init__([m.lower() for m in LogLevel.Levels._member_names_]) # pylint: disable=no-member
+
+ def convert(self, value, param, ctx) -> "LogLevel.Levels":
+ return LogLevel.Levels(super().convert(value, param, ctx))
+
+ @classmethod
+ def get_logging_equivalent(cls, level) -> int:
+ equivalents = {
+ cls.Levels.WARNING: logging.WARNING,
+ cls.Levels.INFO: logging.INFO,
+ cls.Levels.TRACE: logging.DEBUG,
+ }
+
+ return equivalents[level]
+
+
# create_server():
#
# Create gRPC CAS artifact server as specified in the Remote Execution API.
@@ -62,13 +101,22 @@ _MAX_PAYLOAD_BYTES = 1024 * 1024
# enable_push (bool): Whether to allow blob uploads and artifact updates
# index_only (bool): Whether to store CAS blobs or only artifacts
#
-@contextmanager
-def create_server(repo, *, enable_push, quota, index_only):
- cas = CASCache(os.path.abspath(repo), cache_quota=quota, protect_session_blobs=False)
+@contextlib.contextmanager
+def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.Levels.WARNING):
+ logger = logging.getLogger("buildstream._cas.casserver")
+ logger.setLevel(LogLevel.get_logging_equivalent(log_level))
+ handler = logging.StreamHandler(sys.stderr)
+ handler.setFormatter(logging.Formatter(fmt="%(levelname)s: %(funcName)s: %(message)s"))
+ logger.addHandler(handler)
+
+ casd_manager = CASDProcessManager(
+ os.path.abspath(repo), os.path.join(os.path.abspath(repo), "logs"), log_level, quota, False
+ )
+ casd_channel = casd_manager.create_channel()
try:
- artifactdir = os.path.join(os.path.abspath(repo), "artifacts", "refs")
- sourcedir = os.path.join(os.path.abspath(repo), "source_protos")
+ root = os.path.abspath(repo)
+ sourcedir = os.path.join(root, "source_protos")
# Use max_workers default from Python 3.5+
max_workers = (os.cpu_count() or 1) * 5
@@ -76,21 +124,21 @@ def create_server(repo, *, enable_push, quota, index_only):
if not index_only:
bytestream_pb2_grpc.add_ByteStreamServicer_to_server(
- _ByteStreamServicer(cas, enable_push=enable_push), server
+ _ByteStreamServicer(casd_channel, enable_push=enable_push), server
)
remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
- _ContentAddressableStorageServicer(cas, enable_push=enable_push), server
+ _ContentAddressableStorageServicer(casd_channel, enable_push=enable_push), server
)
remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(_CapabilitiesServicer(), server)
buildstream_pb2_grpc.add_ReferenceStorageServicer_to_server(
- _ReferenceStorageServicer(cas, enable_push=enable_push), server
+ _ReferenceStorageServicer(casd_channel, root, enable_push=enable_push), server
)
artifact_pb2_grpc.add_ArtifactServiceServicer_to_server(
- _ArtifactServicer(cas, artifactdir, update_cas=not index_only), server
+ _ArtifactServicer(casd_channel, root, update_cas=not index_only), server
)
source_pb2_grpc.add_SourceServiceServicer_to_server(_SourceServicer(sourcedir), server)
@@ -105,7 +153,8 @@ def create_server(repo, *, enable_push, quota, index_only):
yield server
finally:
- cas.release_resources()
+ casd_channel.close()
+ casd_manager.release_resources()
@click.command(short_help="CAS Artifact Server")
@@ -120,14 +169,17 @@ def create_server(repo, *, enable_push, quota, index_only):
is_flag=True,
help='Only provide the BuildStream artifact and source services ("index"), not the CAS ("storage")',
)
+@click.option("--log-level", type=LogLevel(), help="The log level to launch with", default="warning")
@click.argument("repo")
-def server_main(repo, port, server_key, server_cert, client_certs, enable_push, quota, index_only):
+def server_main(repo, port, server_key, server_cert, client_certs, enable_push, quota, index_only, log_level):
# Handle SIGTERM by calling sys.exit(0), which will raise a SystemExit exception,
# properly executing cleanup code in `finally` clauses and context managers.
# This is required to terminate buildbox-casd on SIGTERM.
signal.signal(signal.SIGTERM, lambda signalnum, frame: sys.exit(0))
- with create_server(repo, quota=quota, enable_push=enable_push, index_only=index_only) as server:
+ with create_server(
+ repo, quota=quota, enable_push=enable_push, index_only=index_only, log_level=log_level
+ ) as server:
use_tls = bool(server_key)
@@ -171,216 +223,49 @@ def server_main(repo, port, server_key, server_cert, client_certs, enable_push,
class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
- def __init__(self, cas, *, enable_push):
+ def __init__(self, casd, *, enable_push):
super().__init__()
- self.cas = cas
+ self.bytestream = casd.get_bytestream()
self.enable_push = enable_push
+ self.logger = logging.getLogger("buildstream._cas.casserver")
def Read(self, request, context):
- resource_name = request.resource_name
- client_digest = _digest_from_download_resource_name(resource_name)
- if client_digest is None:
- context.set_code(grpc.StatusCode.NOT_FOUND)
- return
-
- if request.read_offset > client_digest.size_bytes:
- context.set_code(grpc.StatusCode.OUT_OF_RANGE)
- return
-
- try:
- with open(self.cas.objpath(client_digest), "rb") as f:
- if os.fstat(f.fileno()).st_size != client_digest.size_bytes:
- context.set_code(grpc.StatusCode.NOT_FOUND)
- return
-
- os.utime(f.fileno())
-
- if request.read_offset > 0:
- f.seek(request.read_offset)
-
- remaining = client_digest.size_bytes - request.read_offset
- while remaining > 0:
- chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
- remaining -= chunk_size
-
- response = bytestream_pb2.ReadResponse()
- # max. 64 kB chunks
- response.data = f.read(chunk_size)
- yield response
- except FileNotFoundError:
- context.set_code(grpc.StatusCode.NOT_FOUND)
+ self.logger.debug("Reading %s", request.resource_name)
+ return self.bytestream.Read(request)
def Write(self, request_iterator, context):
- response = bytestream_pb2.WriteResponse()
-
- if not self.enable_push:
- context.set_code(grpc.StatusCode.PERMISSION_DENIED)
- return response
-
- offset = 0
- finished = False
- resource_name = None
- with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
- for request in request_iterator:
- if finished or request.write_offset != offset:
- context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
- return response
-
- if resource_name is None:
- # First request
- resource_name = request.resource_name
- client_digest = _digest_from_upload_resource_name(resource_name)
- if client_digest is None:
- context.set_code(grpc.StatusCode.NOT_FOUND)
- return response
-
- while True:
- if client_digest.size_bytes == 0:
- break
-
- try:
- os.posix_fallocate(out.fileno(), 0, client_digest.size_bytes)
- break
- except OSError as e:
- # Multiple upload can happen in the same time
- if e.errno != errno.ENOSPC:
- raise
-
- elif request.resource_name:
- # If it is set on subsequent calls, it **must** match the value of the first request.
- if request.resource_name != resource_name:
- context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
- return response
-
- if (offset + len(request.data)) > client_digest.size_bytes:
- context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
- return response
-
- out.write(request.data)
- offset += len(request.data)
- if request.finish_write:
- if client_digest.size_bytes != offset:
- context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
- return response
- out.flush()
-
- try:
- digest = self.cas.add_object(path=out.name, link_directly=True)
- except CASCacheError as e:
- if e.reason == "cache-too-full":
- context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
- else:
- context.set_code(grpc.StatusCode.INTERNAL)
- return response
-
- if digest.hash != client_digest.hash:
- context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
- return response
-
- finished = True
-
- assert finished
-
- response.committed_size = offset
- return response
+ # Note that we can't easily give more information because the
+ # data is stuck in an iterator that will be consumed if read.
+ self.logger.debug("Writing data")
+ return self.bytestream.Write(request_iterator)
class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
- def __init__(self, cas, *, enable_push):
+ def __init__(self, casd, *, enable_push):
super().__init__()
- self.cas = cas
+ self.cas = casd.get_cas()
self.enable_push = enable_push
+ self.logger = logging.getLogger("buildstream._cas.casserver")
def FindMissingBlobs(self, request, context):
- response = remote_execution_pb2.FindMissingBlobsResponse()
- for digest in request.blob_digests:
- objpath = self.cas.objpath(digest)
- try:
- os.utime(objpath)
- except OSError as e:
- if e.errno != errno.ENOENT:
- raise
-
- d = response.missing_blob_digests.add()
- d.hash = digest.hash
- d.size_bytes = digest.size_bytes
-
- return response
+ self.logger.info("Finding '%s'", request.blob_digests)
+ return self.cas.FindMissingBlobs(request)
def BatchReadBlobs(self, request, context):
- response = remote_execution_pb2.BatchReadBlobsResponse()
- batch_size = 0
-
- for digest in request.digests:
- batch_size += digest.size_bytes
- if batch_size > _MAX_PAYLOAD_BYTES:
- context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
- return response
-
- blob_response = response.responses.add()
- blob_response.digest.hash = digest.hash
- blob_response.digest.size_bytes = digest.size_bytes
- try:
- objpath = self.cas.objpath(digest)
- with open(objpath, "rb") as f:
- if os.fstat(f.fileno()).st_size != digest.size_bytes:
- blob_response.status.code = code_pb2.NOT_FOUND
- continue
-
- os.utime(f.fileno())
-
- blob_response.data = f.read(digest.size_bytes)
- except FileNotFoundError:
- blob_response.status.code = code_pb2.NOT_FOUND
-
- return response
+ self.logger.info("Reading '%s'", request.digests)
+ return self.cas.BatchReadBlobs(request)
def BatchUpdateBlobs(self, request, context):
- response = remote_execution_pb2.BatchUpdateBlobsResponse()
-
- if not self.enable_push:
- context.set_code(grpc.StatusCode.PERMISSION_DENIED)
- return response
-
- batch_size = 0
-
- for blob_request in request.requests:
- digest = blob_request.digest
-
- batch_size += digest.size_bytes
- if batch_size > _MAX_PAYLOAD_BYTES:
- context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
- return response
-
- blob_response = response.responses.add()
- blob_response.digest.hash = digest.hash
- blob_response.digest.size_bytes = digest.size_bytes
-
- if len(blob_request.data) != digest.size_bytes:
- blob_response.status.code = code_pb2.FAILED_PRECONDITION
- continue
-
- with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
- out.write(blob_request.data)
- out.flush()
-
- try:
- server_digest = self.cas.add_object(path=out.name)
- except CASCacheError as e:
- if e.reason == "cache-too-full":
- blob_response.status.code = code_pb2.RESOURCE_EXHAUSTED
- else:
- blob_response.status.code = code_pb2.INTERNAL
- continue
-
- if server_digest.hash != digest.hash:
- blob_response.status.code = code_pb2.FAILED_PRECONDITION
-
- return response
+ self.logger.info("Updating: '%s'", [request.digest for request in request.requests])
+ return self.cas.BatchUpdateBlobs(request)
class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
+ def __init__(self):
+ self.logger = logging.getLogger("buildstream._cas.casserver")
+
def GetCapabilities(self, request, context):
+ self.logger.info("Retrieving capabilities")
response = remote_execution_pb2.ServerCapabilities()
cache_capabilities = response.cache_capabilities
@@ -397,31 +282,85 @@ class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
- def __init__(self, cas, *, enable_push):
+ def __init__(self, casd, cas_root, *, enable_push):
super().__init__()
- self.cas = cas
+ self.cas = casd.get_cas()
+ self.root = cas_root
self.enable_push = enable_push
+ self.logger = logging.getLogger("buildstream._cas.casserver")
+ self.tmpdir = os.path.join(self.root, "tmp")
+ self.casdir = os.path.join(self.root, "cas")
+ self.refdir = os.path.join(self.casdir, "refs", "heads")
+ os.makedirs(self.tmpdir, exist_ok=True)
+
+ # ref_path():
+ #
+ # Get the path to a digest's file.
+ #
+ # Args:
+ # ref - The ref of the digest.
+ #
+ # Returns:
+ # str - The path to the digest's file.
+ #
+ def ref_path(self, ref: str) -> str:
+ return os.path.join(self.refdir, ref)
+
+ # set_ref():
+ #
+ # Create or update a ref with a new digest.
+ #
+ # Args:
+ # ref - The ref of the digest.
+ # tree - The digest to write.
+ #
+ def set_ref(self, ref: str, tree):
+ ref_path = self.ref_path(ref)
+
+ os.makedirs(os.path.dirname(ref_path), exist_ok=True)
+ with save_file_atomic(ref_path, "wb", tempdir=self.tmpdir) as f:
+ f.write(tree.SerializeToString())
+
+ # resolve_ref():
+ #
+ # Resolve a ref to a digest.
+ #
+ # Args:
+ # ref (str): The name of the ref
+ #
+ # Returns:
+ # (Digest): The digest stored in the ref
+ #
+ def resolve_ref(self, ref):
+ ref_path = self.ref_path(ref)
+
+ with open(ref_path, "rb") as f:
+ os.utime(ref_path)
+
+ digest = remote_execution_pb2.Digest()
+ digest.ParseFromString(f.read())
+ return digest
def GetReference(self, request, context):
+ self.logger.debug("'%s'", request.key)
response = buildstream_pb2.GetReferenceResponse()
try:
- tree = self.cas.resolve_ref(request.key, update_mtime=True)
- try:
- self.cas.update_tree_mtime(tree)
- except FileNotFoundError:
- self.cas.remove(request.key)
- context.set_code(grpc.StatusCode.NOT_FOUND)
- return response
-
- response.digest.hash = tree.hash
- response.digest.size_bytes = tree.size_bytes
- except CASError:
+ digest = self.resolve_ref(request.key)
+ except FileNotFoundError:
+ with contextlib.suppress(FileNotFoundError):
+ _remove_path_with_parents(self.refdir, request.key)
+
context.set_code(grpc.StatusCode.NOT_FOUND)
+ return response
+
+ response.digest.hash = digest.hash
+ response.digest.size_bytes = digest.size_bytes
return response
def UpdateReference(self, request, context):
+ self.logger.debug("%s -> %s", request.keys, request.digest)
response = buildstream_pb2.UpdateReferenceResponse()
if not self.enable_push:
@@ -429,11 +368,12 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
return response
for key in request.keys:
- self.cas.set_ref(key, request.digest)
+ self.set_ref(key, request.digest)
return response
def Status(self, request, context):
+ self.logger.debug("Retrieving status")
response = buildstream_pb2.StatusResponse()
response.allow_updates = self.enable_push
@@ -442,14 +382,48 @@ class _ReferenceStorageServicer(buildstream_pb2_grpc.ReferenceStorageServicer):
class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
- def __init__(self, cas, artifactdir, *, update_cas=True):
+ def __init__(self, casd, root, *, update_cas=True):
super().__init__()
- self.cas = cas
- self.artifactdir = artifactdir
+ self.cas = casd.get_cas()
+ self.local_cas = casd.get_local_cas()
+ self.root = root
+ self.artifactdir = os.path.join(root, "artifacts", "refs")
self.update_cas = update_cas
- os.makedirs(artifactdir, exist_ok=True)
+ self.logger = logging.getLogger("buildstream._cas.casserver")
+
+ # object_path():
+ #
+ # Get the path to an object's file.
+ #
+ # Args:
+ # digest - The digest of the object.
+ #
+ # Returns:
+ # str - The path to the object's file.
+ #
+ def object_path(self, digest) -> str:
+ return os.path.join(self.root, "cas", "objects", digest.hash[:2], digest.hash[2:])
+
+ # resolve_digest():
+ #
+ # Read the directory corresponding to a digest.
+ #
+ # Args:
+ # digest - The digest corresponding to a directory.
+ #
+ # Returns:
+ # remote_execution_pb2.Directory - The directory.
+ #
+ # Raises:
+ # FileNotFoundError - If the digest object doesn't exist.
+ def resolve_digest(self, digest):
+ directory = remote_execution_pb2.Directory()
+ with open(self.object_path(digest), "rb") as f:
+ directory.ParseFromString(f.read())
+ return directory
def GetArtifact(self, request, context):
+ self.logger.info("'%s'", request.cache_key)
artifact_path = os.path.join(self.artifactdir, request.cache_key)
if not os.path.exists(artifact_path):
context.abort(grpc.StatusCode.NOT_FOUND, "Artifact proto not found")
@@ -458,6 +432,8 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
with open(artifact_path, "rb") as f:
artifact.ParseFromString(f.read())
+ os.utime(artifact_path)
+
# Artifact-only servers will not have blobs on their system,
# so we can't reasonably update their mtimes. Instead, we exit
# early, and let the CAS server deal with its blobs.
@@ -476,30 +452,45 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
try:
if str(artifact.files):
- self.cas.update_tree_mtime(artifact.files)
+ request = local_cas_pb2.FetchTreeRequest()
+ request.root_digest.CopyFrom(artifact.files)
+ request.fetch_file_blobs = True
+ self.local_cas.FetchTree(request)
if str(artifact.buildtree):
- # buildtrees might not be there
try:
- self.cas.update_tree_mtime(artifact.buildtree)
- except FileNotFoundError:
- pass
+ request = local_cas_pb2.FetchTreeRequest()
+ request.root_digest.CopyFrom(artifact.buildtree)
+ request.fetch_file_blobs = True
+ self.local_cas.FetchTree(request)
+ except grpc.RpcError as err:
+ # buildtrees might not be there
+ if err.code() != grpc.StatusCode.NOT_FOUND:
+ raise
if str(artifact.public_data):
- os.utime(self.cas.objpath(artifact.public_data))
+ request = remote_execution_pb2.FindMissingBlobsRequest()
+ d = request.blob_digests.add()
+ d.CopyFrom(artifact.public_data)
+ self.cas.FindMissingBlobs(request)
+ request = remote_execution_pb2.FindMissingBlobsRequest()
for log_file in artifact.logs:
- os.utime(self.cas.objpath(log_file.digest))
-
- except FileNotFoundError:
- os.unlink(artifact_path)
- context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files incomplete")
- except DecodeError:
- context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files not valid")
+ d = request.blob_digests.add()
+ d.CopyFrom(log_file.digest)
+ self.cas.FindMissingBlobs(request)
+
+ except grpc.RpcError as err:
+ if err.code() == grpc.StatusCode.NOT_FOUND:
+ os.unlink(artifact_path)
+ context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files incomplete")
+ else:
+ context.abort(grpc.StatusCode.NOT_FOUND, "Artifact files not valid")
return artifact
def UpdateArtifact(self, request, context):
+ self.logger.info("'%s' -> '%s'", request.artifact, request.cache_key)
artifact = request.artifact
if self.update_cas:
@@ -518,28 +509,29 @@ class _ArtifactServicer(artifact_pb2_grpc.ArtifactServiceServicer):
# Add the artifact proto to the cas
artifact_path = os.path.join(self.artifactdir, request.cache_key)
os.makedirs(os.path.dirname(artifact_path), exist_ok=True)
- with utils.save_file_atomic(artifact_path, mode="wb") as f:
+ with save_file_atomic(artifact_path, mode="wb") as f:
f.write(artifact.SerializeToString())
return artifact
def ArtifactStatus(self, request, context):
+ self.logger.info("Retrieving status")
return artifact_pb2.ArtifactStatusResponse()
def _check_directory(self, name, digest, context):
try:
- directory = remote_execution_pb2.Directory()
- with open(self.cas.objpath(digest), "rb") as f:
- directory.ParseFromString(f.read())
+ self.resolve_digest(digest)
except FileNotFoundError:
+ self.logger.warning("Artifact %s specified but no files found", name)
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but no files found".format(name))
except DecodeError:
+ self.logger.warning("Artifact %s specified but directory not found", name)
context.abort(
grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but directory not found".format(name)
)
def _check_file(self, name, digest, context):
- if not os.path.exists(self.cas.objpath(digest)):
+ if not os.path.exists(self.object_path(digest)):
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Artifact {} specified but not found".format(name))
@@ -558,8 +550,10 @@ class _BuildStreamCapabilitiesServicer(buildstream_pb2_grpc.CapabilitiesServicer
class _SourceServicer(source_pb2_grpc.SourceServiceServicer):
def __init__(self, sourcedir):
self.sourcedir = sourcedir
+ self.logger = logging.getLogger("buildstream._cas.casserver")
def GetSource(self, request, context):
+ self.logger.info("'%s'", request.cache_key)
try:
source_proto = self._get_source(request.cache_key)
except FileNotFoundError:
@@ -570,6 +564,7 @@ class _SourceServicer(source_pb2_grpc.SourceServiceServicer):
return source_proto
def UpdateSource(self, request, context):
+ self.logger.info("'%s' -> '%s'", request.source, request.cache_key)
self._set_source(request.cache_key, request.source)
return request.source
@@ -584,7 +579,7 @@ class _SourceServicer(source_pb2_grpc.SourceServiceServicer):
def _set_source(self, cache_key, source_proto):
path = os.path.join(self.sourcedir, cache_key)
os.makedirs(os.path.dirname(path), exist_ok=True)
- with utils.save_file_atomic(path, "w+b") as f:
+ with save_file_atomic(path, "w+b") as f:
f.write(source_proto.SerializeToString())
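Taken together, create_server() can still be driven the way server_main() drives it; below is a small hypothetical sketch of launching the refactored server in-process with the new log_level option. The path, port and quota are made up, and the serve loop is simplified relative to the real server_main().

# Hypothetical usage of the refactored create_server(); not the shipped CLI path.
import signal

from buildstream._cas.casserver import create_server, LogLevel

with create_server(
    "/srv/artifacts/repo",            # placeholder repository path
    quota=10 * 1024 ** 3,             # placeholder quota in bytes
    enable_push=True,
    index_only=False,
    log_level=LogLevel.Levels.INFO,
) as server:
    server.add_insecure_port("localhost:50052")  # placeholder port
    server.start()
    try:
        while True:
            signal.pause()            # serve until interrupted
    except KeyboardInterrupt:
        server.stop(0)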
diff --git a/src/buildstream/_exceptions.py b/src/buildstream/_exceptions.py
index ca17577f7..51f542783 100644
--- a/src/buildstream/_exceptions.py
+++ b/src/buildstream/_exceptions.py
@@ -273,6 +273,15 @@ class SandboxError(BstError):
super().__init__(message, detail=detail, domain=ErrorDomain.SANDBOX, reason=reason)
+# CacheError
+#
+# Raised when errors are encountered in either type of cache
+#
+class CacheError(BstError):
+ def __init__(self, message, detail=None, reason=None):
+ super().__init__(message, detail=detail, domain=ErrorDomain.SANDBOX, reason=reason)
+
+
# SourceCacheError
#
# Raised when errors are encountered in the source caches
diff --git a/src/buildstream/_sourcecache.py b/src/buildstream/_sourcecache.py
index 03e2d1830..221694e94 100644
--- a/src/buildstream/_sourcecache.py
+++ b/src/buildstream/_sourcecache.py
@@ -129,8 +129,8 @@ class SourceCache(BaseCache):
def __init__(self, context):
super().__init__(context)
- self.sourcerefdir = os.path.join(context.cachedir, "source_protos")
- os.makedirs(self.sourcerefdir, exist_ok=True)
+ self._basedir = os.path.join(context.cachedir, "source_protos")
+ os.makedirs(self._basedir, exist_ok=True)
# list_sources()
#
@@ -140,7 +140,7 @@ class SourceCache(BaseCache):
# ([str]): iterable over all source refs
#
def list_sources(self):
- return [ref for _, ref in self._list_refs_mtimes(self.sourcerefdir)]
+ return [ref for _, ref in self._list_refs_mtimes(self._basedir)]
# contains()
#
@@ -326,7 +326,7 @@ class SourceCache(BaseCache):
return pushed_index and pushed_storage
def _remove_source(self, ref, *, defer_prune=False):
- return self.cas.remove(ref, basedir=self.sourcerefdir, defer_prune=defer_prune)
+ return self.cas.remove(ref, basedir=self._basedir, defer_prune=defer_prune)
def _store_source(self, ref, digest):
source_proto = source_pb2.Source()
@@ -351,10 +351,10 @@ class SourceCache(BaseCache):
raise SourceCacheError("Attempted to access unavailable source: {}".format(e)) from e
def _source_path(self, ref):
- return os.path.join(self.sourcerefdir, ref)
+ return os.path.join(self._basedir, ref)
def _reachable_directories(self):
- for root, _, files in os.walk(self.sourcerefdir):
+ for root, _, files in os.walk(self._basedir):
for source_file in files:
source = source_pb2.Source()
with open(os.path.join(root, source_file), "r+b") as f:
diff --git a/src/buildstream/utils.py b/src/buildstream/utils.py
index 0307b8470..5009be338 100644
--- a/src/buildstream/utils.py
+++ b/src/buildstream/utils.py
@@ -779,6 +779,44 @@ def _is_main_process():
return os.getpid() == _MAIN_PID
+# Remove a path and any empty directories leading up to it.
+#
+# Args:
+# basedir - The basedir at which to stop pruning even if
+# it is empty.
+# path - A path relative to basedir that should be pruned.
+#
+# Raises:
+# FileNotFoundError - if the path itself doesn't exist.
+# OSError - if something else goes wrong
+#
+def _remove_path_with_parents(basedir: Union[Path, str], path: Union[Path, str]):
+ assert not os.path.isabs(path), "The path ({}) should be relative to basedir ({})".format(path, basedir)
+ path = os.path.join(basedir, path)
+
+ # Start by removing the path itself
+ os.unlink(path)
+
+ # Now walk up the directory tree and delete any empty directories
+ path = os.path.dirname(path)
+ while path != basedir:
+ try:
+ os.rmdir(path)
+ except FileNotFoundError:
+ # The parent directory did not exist (race conditions can
+ # cause this), but it's parent directory might still be
+ # ready to prune
+ pass
+ except OSError as e:
+ if e.errno == errno.ENOTEMPTY:
+ # The parent directory was not empty, so we
+ # cannot prune directories beyond this point
+ break
+ raise
+
+ path = os.path.dirname(path)
+
+
# Recursively remove directories, ignoring file permissions as much as
# possible.
def _force_rmtree(rootpath, **kwargs):
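The new _remove_path_with_parents() helper generalises the pruning that previously lived in CASCache._remove_ref(), so the caches and the server can share it. A small hypothetical demonstration of its semantics, with a made-up directory layout:

# Hypothetical demonstration of _remove_path_with_parents() semantics.
import os
import tempfile

from buildstream.utils import _remove_path_with_parents

base = tempfile.mkdtemp()
ref = os.path.join("project", "element", "cache-key")  # made-up ref layout

os.makedirs(os.path.join(base, os.path.dirname(ref)))
open(os.path.join(base, ref), "w").close()

_remove_path_with_parents(base, ref)

# The ref and its now-empty parent directories are pruned,
# but the base directory itself is left in place.
assert not os.path.exists(os.path.join(base, "project"))
assert os.path.isdir(base)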
diff --git a/tests/sourcecache/fetch.py b/tests/sourcecache/fetch.py
index 0c347ebbf..4096b56b8 100644
--- a/tests/sourcecache/fetch.py
+++ b/tests/sourcecache/fetch.py
@@ -37,6 +37,7 @@ DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "project")
def move_local_cas_to_remote_source_share(local, remote):
shutil.rmtree(os.path.join(remote, "repo", "cas"))
+ shutil.rmtree(os.path.join(remote, "repo", "source_protos"))
shutil.move(os.path.join(local, "source_protos"), os.path.join(remote, "repo"))
shutil.move(os.path.join(local, "cas"), os.path.join(remote, "repo"))
shutil.rmtree(os.path.join(local, "sources"))
@@ -85,8 +86,7 @@ def test_source_fetch(cli, tmpdir, datafiles):
assert not element._source_cached()
source = list(element.sources())[0]
- cas = context.get_cascache()
- assert not cas.contains(source._get_source_name())
+ assert not share.get_source_proto(source._get_source_name())
# Just check that we sensibly fetch and build the element
res = cli.run(project=project_dir, args=["build", element_name])
@@ -132,8 +132,7 @@ def test_fetch_fallback(cli, tmpdir, datafiles):
assert not element._source_cached()
source = list(element.sources())[0]
- cas = context.get_cascache()
- assert not cas.contains(source._get_source_name())
+ assert not share.get_source_proto(source._get_source_name())
assert not os.path.exists(os.path.join(cache_dir, "sources"))
# Now check if it falls back to the source fetch method.
@@ -195,8 +194,7 @@ def test_source_pull_partial_fallback_fetch(cli, tmpdir, datafiles):
assert not element._source_cached()
source = list(element.sources())[0]
- cas = context.get_cascache()
- assert not cas.contains(source._get_source_name())
+ assert not share.get_artifact_proto(source._get_source_name())
# Just check that we sensibly fetch and build the element
res = cli.run(project=project_dir, args=["build", element_name])
diff --git a/tests/sourcecache/push.py b/tests/sourcecache/push.py
index 719860425..0b7bb9954 100644
--- a/tests/sourcecache/push.py
+++ b/tests/sourcecache/push.py
@@ -89,8 +89,7 @@ def test_source_push_split(cli, tmpdir, datafiles):
source = list(element.sources())[0]
# check we don't have it in the current cache
- cas = context.get_cascache()
- assert not cas.contains(source._get_source_name())
+ assert not index.get_source_proto(source._get_source_name())
# build the element, this should fetch and then push the source to the
# remote
@@ -139,8 +138,7 @@ def test_source_push(cli, tmpdir, datafiles):
source = list(element.sources())[0]
# check we don't have it in the current cache
- cas = context.get_cascache()
- assert not cas.contains(source._get_source_name())
+ assert not share.get_source_proto(source._get_source_name())
# build the element, this should fetch and then push the source to the
# remote
diff --git a/tests/testutils/artifactshare.py b/tests/testutils/artifactshare.py
index 8d0448f8a..19c19131a 100644
--- a/tests/testutils/artifactshare.py
+++ b/tests/testutils/artifactshare.py
@@ -13,7 +13,7 @@ from buildstream._cas import CASCache
from buildstream._cas.casserver import create_server
from buildstream._exceptions import CASError
from buildstream._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
-from buildstream._protos.buildstream.v2 import artifact_pb2
+from buildstream._protos.buildstream.v2 import artifact_pb2, source_pb2
class BaseArtifactShare:
@@ -120,6 +120,8 @@ class ArtifactShare(BaseArtifactShare):
os.makedirs(self.repodir)
self.artifactdir = os.path.join(self.repodir, "artifacts", "refs")
os.makedirs(self.artifactdir)
+ self.sourcedir = os.path.join(self.repodir, "source_protos", "refs")
+ os.makedirs(self.sourcedir)
self.cas = CASCache(self.repodir, casd=casd)
@@ -160,6 +162,18 @@ class ArtifactShare(BaseArtifactShare):
return artifact_proto
+ def get_source_proto(self, source_name):
+ source_proto = source_pb2.Source()
+ source_path = os.path.join(self.sourcedir, source_name)
+
+ try:
+ with open(source_path, "rb") as f:
+ source_proto.ParseFromString(f.read())
+ except FileNotFoundError:
+ return None
+
+ return source_proto
+
def get_cas_files(self, artifact_proto):
reachable = set()