diff options
-rw-r--r-- | src/buildstream/_cas/casserver.py | 7 | ||||
-rw-r--r-- | src/buildstream/_signals.py | 13 |
2 files changed, 12 insertions, 8 deletions
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py index dd822d53b..43cc131aa 100644 --- a/src/buildstream/_cas/casserver.py +++ b/src/buildstream/_cas/casserver.py @@ -29,6 +29,7 @@ import grpc from google.protobuf.message import DecodeError import click +from .. import _signals from .._protos.build.bazel.remote.execution.v2 import ( remote_execution_pb2, remote_execution_pb2_grpc, @@ -149,7 +150,11 @@ def create_server(repo, *, enable_push, quota, index_only, log_level=LogLevel.Le _BuildStreamCapabilitiesServicer(artifact_capabilities, source_capabilities), server ) - yield server + # Ensure we have the signal handler set for SIGTERM + # This allows threads from GRPC to call our methods that do register + # handlers at exit. + with _signals.terminator(lambda: None): + yield server finally: casd_channel.close() diff --git a/src/buildstream/_signals.py b/src/buildstream/_signals.py index 03b55b052..1edd445b6 100644 --- a/src/buildstream/_signals.py +++ b/src/buildstream/_signals.py @@ -80,13 +80,10 @@ def terminator_handler(signal_, frame): def terminator(terminate_func): global terminator_stack # pylint: disable=global-statement - # Signal handling only works in the main thread - if threading.current_thread() != threading.main_thread(): - yield - return - outermost = bool(not terminator_stack) + assert threading.current_thread() == threading.main_thread() or not outermost + terminator_stack.append(terminate_func) if outermost: original_handler = signal.signal(signal.SIGTERM, terminator_handler) @@ -96,7 +93,7 @@ def terminator(terminate_func): finally: if outermost: signal.signal(signal.SIGTERM, original_handler) - terminator_stack.pop() + terminator_stack.remove(terminate_func) # Just a simple object for holding on to two callbacks @@ -146,6 +143,8 @@ def suspendable(suspend_callback, resume_callback): global suspendable_stack # pylint: disable=global-statement outermost = bool(not suspendable_stack) + assert threading.current_thread() == threading.main_thread() or not outermost + suspender = Suspender(suspend_callback, resume_callback) suspendable_stack.append(suspender) @@ -158,7 +157,7 @@ def suspendable(suspend_callback, resume_callback): if outermost: signal.signal(signal.SIGTSTP, original_stop) - suspendable_stack.pop() + suspendable_stack.remove(suspender) # blocked() |